Open
Labels: enhancement (New feature or request)
Description
Raised by @riddhipits, who ran into awkward construction of multiple agents when using the `Model` class with different generative model parameters.
# these are both instances of the `Model` class from `pymdp.distribution`
model_agent0 = compile_model(agent0_args)
model_agent1 = compile_model(agent1_args)

We have to instantiate separate Agent instances, because Agent doesn't accept Model instances with a batch shape:

agent0 = Agent(model_agent0.A, model_agent0.B, ..., batch_size=1)
agent1 = Agent(model_agent1.A, model_agent1.B, ..., batch_size=1)

Then we have to concatenate the two Agents so that they are batched:

agents = jtu.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), agent0, agent1)

To fix this:
- allow the `Agent` constructor to take a `Model` as an input argument, or perhaps add some unpacking logic for `Model` so that you can pass in `Agent(**model, ...)` and it will unpack the `Distribution` attributes of the `Model` instance
- allow the `Model` and its `Distribution` attributes to have a batch_size, so we can pass in a batch of models / distributions to the `Agent` constructor
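
For the first option, here is a rough sketch of what the `**model` unpacking could look like. It is not pymdp's actual API: `SketchModel` and `make_agent` are hypothetical stand-ins for `Model` and the `Agent` constructor, and plain lists stand in for the `Distribution` attributes. The only assumption about Python itself is that `**obj` works on any object that provides `keys()` and `__getitem__`.

```python
from dataclasses import dataclass, fields


@dataclass
class SketchModel:
    """Hypothetical stand-in for Model; not the pymdp class."""
    A: list
    B: list

    def keys(self):
        # Expose the field names so that `**model` knows what to unpack.
        return [f.name for f in fields(self)]

    def __getitem__(self, name):
        # Look up a field by name; reject anything that is not a field.
        if name not in self.keys():
            raise KeyError(name)
        return getattr(self, name)


def make_agent(A, B, batch_size=1):
    """Hypothetical stand-in for the Agent constructor."""
    return {"A": A, "B": B, "batch_size": batch_size}


model = SketchModel(A=[[0.9, 0.1], [0.1, 0.9]], B=[[1.0, 0.0], [0.0, 1.0]])

# The model's fields are passed as keyword arguments A=..., B=...
agent = make_agent(**model, batch_size=1)
```

With something like this, the call site no longer has to spell out `model.A, model.B, ...` by hand. The second option (a batch dimension on the model itself) would still be needed to avoid the per-agent construction and concatenation step.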