Create batch_size compatible versions of Model and Distribution classes #316

@conorheins

Description

Raised by @riddhipits, who was running into awkward construction of multiple agents when using the `Model` class with different generative model parameters for each agent.

# these are both instances of the `Model` class from `pymdp.distribution`
model_agent0 = compile_model(agent0_args)
model_agent1 = compile_model(agent1_args)

We have to instantiate separate `Agent` instances, because `Agent` doesn't accept `Model` instances with a batch shape:

agent0 = Agent(model_agent0.A, model_agent0.B, ...,  batch_size=1)
agent1 = Agent(model_agent1.A, model_agent1.B, ...,  batch_size=1)

Then we have to concatenate the two `Agent`s along the batch axis so that they form a single batched `Agent`:

import jax.numpy as jnp
import jax.tree_util as jtu

agents = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), agent0, agent1)
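A minimal, self-contained sketch of the concatenation workaround, assuming each agent is a pytree whose array leaves all carry a leading batch axis of size 1. Plain dicts of arrays stand in for the two `Agent` instances here (hypothetical data, not pymdp objects). The variadic form at the end generalizes the same pattern from two agents to any number.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

# Hypothetical stand-ins for two batch_size=1 agents: every array leaf has a leading axis of size 1.
agent0 = {"A": [jnp.ones((1, 3, 2))], "B": [jnp.ones((1, 2, 2, 4))], "gamma": jnp.array([16.0])}
agent1 = {"A": [jnp.zeros((1, 3, 2))], "B": [jnp.zeros((1, 2, 2, 4))], "gamma": jnp.array([8.0])}

# Two agents: concatenate matching leaves along axis 0 (the batch axis).
agents = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), agent0, agent1)

# N agents: the variadic lambda accepts any number of pytrees with matching structure.
agent_list = [agent0, agent1, agent0]
agents_n = jtu.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), *agent_list)

print(agents["A"][0].shape)    # (2, 3, 2)
print(agents_n["A"][0].shape)  # (3, 3, 2)
```

This only works when every leaf already has the batch axis; any leaf without it would be concatenated along its own first dimension instead, which is part of why the construction feels awkward.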

To fix this:

  • allow the `Agent` constructor to take a `Model` as an input argument, or add unpacking logic to `Model` so that you can call `Agent(**model, ...)` and have it unpack the `Distribution` attributes of the `Model` instance
  • allow `Model` and its `Distribution` attributes to have a `batch_size`, so that a batch of models / distributions can be passed to the `Agent` constructor
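One way the two proposals could fit together, sketched under loud assumptions: `Model` below is a hypothetical dataclass stand-in for `pymdp.distribution.Model` that holds plain arrays rather than `Distribution` objects, and `Agent` and `batch_models` are toy, hypothetical names, not pymdp's API. Registering `Model` as a JAX pytree lets a list of unbatched models be stacked leaf-by-leaf into one batched `Model`, and giving it `keys()` and `__getitem__` makes `Agent(**model, ...)` unpack its fields by name.

```python
from dataclasses import dataclass
from typing import List

import jax.numpy as jnp
import jax.tree_util as jtu


# Hypothetical stand-in for pymdp.distribution.Model: plain arrays instead of Distribution objects.
@jtu.register_pytree_node_class
@dataclass
class Model:
    A: List[jnp.ndarray]  # one observation likelihood array per modality
    B: List[jnp.ndarray]  # one transition array per hidden-state factor

    # Pytree protocol: lets jtu.tree_map reach every array inside the Model.
    def tree_flatten(self):
        return (self.A, self.B), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    # Mapping protocol: `**model` calls keys(), then model[key] for each key.
    def keys(self):
        return ("A", "B")

    def __getitem__(self, key):
        return getattr(self, key)


def batch_models(models):
    """Hypothetical helper: stack unbatched Models leaf-by-leaf, adding a leading batch axis."""
    return jtu.tree_map(lambda *xs: jnp.stack(xs, axis=0), *models)


# Hypothetical toy Agent: infers batch_size from the leading axis instead of taking it as an argument.
class Agent:
    def __init__(self, A, B, policy_len=1):
        self.A, self.B = A, B
        self.policy_len = policy_len
        self.batch_size = A[0].shape[0]


model_agent0 = Model(A=[jnp.eye(3)], B=[jnp.ones((3, 3, 2)) / 3])
model_agent1 = Model(A=[jnp.full((3, 3), 1 / 3)], B=[jnp.ones((3, 3, 2)) / 3])

batched_model = batch_models([model_agent0, model_agent1])
agents = Agent(**batched_model, policy_len=2)  # unpacks A and B by name

print(agents.batch_size)  # 2
print(agents.A[0].shape)  # (2, 3, 3)
```

`jnp.stack` creates the batch axis, so this assumes the per-agent models are unbatched; models that already carry a batch axis would be combined with `jnp.concatenate(xs, axis=0)` instead, as in the current workaround.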

Metadata

Labels

enhancement (New feature or request)