
rollout function is incompatible with onehot_obs Agent? #277

@Arun-Niranjan

Description

With an initial setup like so:

from jax import random as jr
from pymdp.agent import Agent
from pymdp.envs import TMaze, rollout

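key = jr.PRNGKey(0)
batch_size = 1
T = 5  # number of rollout timesteps (also used as the policy length)

# Build a batched T-maze environment and take its A (observation) and B (transition) arrays
env = TMaze(
    batch_size=batch_size,
)
A, A_dependencies = env.generate_A()
B, B_dependencies = env.generate_B()

# onehot_obs=True is the setting that triggers the failure
agent = Agent(
    A=A,
    B=B,
    A_dependencies=A_dependencies,
    B_dependencies=B_dependencies,
    batch_size=batch_size,
    policy_len=T,
    onehot_obs=True,
)
_, info, _ = rollout(agent, env, num_timesteps=T, rng_key=key)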

Inside rollout, the call to agent.infer_states(...) raises an exception.

This does not occur when setting onehot_obs=False in the Agent constructor. We should either:

  1. fix rollout to handle one-hot encoded observations (preferable)
  2. remove the onehot_obs flag from the Agent constructor
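As a rough sketch of option 1, rollout could convert the environment's integer observations into one-hot vectors before passing them to agent.infer_states whenever the agent was built with onehot_obs=True. The helper to_onehot_obs below is hypothetical (it is not part of pymdp). It assumes the environment returns one integer array per modality with shape (batch_size,) or (batch_size, 1), and that the agent expects one float array of shape (batch_size, num_obs[m]) per modality. The exact shape infer_states expects in the one-hot case is an assumption and should be checked against the Agent code.

```python
import jax.numpy as jnp
from jax import nn


def to_onehot_obs(obs, num_obs):
    """Hypothetical helper: convert per-modality integer observations to one-hot.

    obs: list of integer arrays, one per modality, each of shape
         (batch_size,) or (batch_size, 1) (assumed env output format).
    num_obs: list of ints, the number of possible outcomes in each modality.
    Returns a list of float arrays, each of shape (batch_size, num_obs[m]).
    """
    onehot = []
    for o_m, n_m in zip(obs, num_obs):
        # Flatten any trailing singleton dimension so idx has shape (batch_size,)
        idx = jnp.reshape(jnp.asarray(o_m), (-1,))
        # jax.nn.one_hot appends a new last axis of length n_m
        onehot.append(nn.one_hot(idx, n_m))
    return onehot


# Example: two modalities with 2 and 3 outcomes, batch_size = 1
example = to_onehot_obs([jnp.array([[0]]), jnp.array([[2]])], [2, 3])
print([o.shape for o in example])  # [(1, 2), (1, 3)]
```

Keeping the conversion in a single helper means rollout only needs one extra branch on the agent's onehot_obs setting, and the conversion can be unit-tested on its own.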

If we keep support for one-hot encoded observations, we should add unit tests that cover this option.

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
