Open
Labels: bug (Something isn't working)
Description
With an initial setup like so:
from jax import random as jr
from pymdp.agent import Agent
from pymdp.envs import TMaze, rollout

key = jr.PRNGKey(0)
batch_size = 1
T = 5

env = TMaze(
    batch_size=batch_size,
)
A, A_dependencies = env.generate_A()
B, B_dependencies = env.generate_B()

agent = Agent(
    A=A,
    B=B,
    A_dependencies=A_dependencies,
    B_dependencies=B_dependencies,
    batch_size=batch_size,
    policy_len=T,
    onehot_obs=True,
)

_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key)

rollout raises an exception when trying to call agent.infer_states(...).
This does not occur when onehot_obs=False is set in the Agent constructor. We should either:
- fix rollout to handle one-hot encoded observations (preferable), or
- remove the flag from the Agent constructor.
If we keep support for one-hot encoded observations, we should add unit tests to cover this option.
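For the first option, one possible direction is to convert index-valued observations to one-hot vectors before they reach agent.infer_states(...). The sketch below only illustrates that conversion. It assumes observations arrive as a list with one integer array per modality, which may not match what rollout actually passes around. The helper names to_onehot and to_indices and the example sizes are hypothetical and not part of pymdp.

```python
import jax.numpy as jnp
from jax import nn


def to_onehot(obs_indices, num_obs):
    """Convert per-modality integer observations to one-hot vectors.

    obs_indices: list of integer arrays, one per modality, shape (batch,)
    num_obs: list of ints, number of possible outcomes per modality
    (Hypothetical helper; the observation layout is an assumption.)
    """
    return [nn.one_hot(o, n) for o, n in zip(obs_indices, num_obs)]


def to_indices(obs_onehot):
    """Inverse of to_onehot: recover integer observations per modality."""
    return [jnp.argmax(o, axis=-1) for o in obs_onehot]


# Made-up example: two modalities with 5 and 3 outcomes, batch size 1.
obs_indices = [jnp.array([2]), jnp.array([0])]
num_obs = [5, 3]

obs_onehot = to_onehot(obs_indices, num_obs)
roundtrip = to_indices(obs_onehot)
```

A unit test for the onehot_obs option could check a round trip like this one, and also run the same rollout with onehot_obs=True and onehot_obs=False to check that both complete.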
conorheins