Separate observation processing from state inference #330

@conorheins

Following the landing of #315 in v1.0.0_alpha, and following up on @nikolamilovic-ft's comment there: to separate the optimized inference execution itself from the overhead introduced by padding observations, it makes sense to separate the processing of observations from the inference algorithm that runs on them.

The idea is to add a separate process_obs() method to the Agent class. It could be overridden in a custom way (or configured flexibly via some kind of InferenceConfig, if and when we get around to adding that; see e.g. #275) for advanced inference algorithms where the padding/preprocessing overhead can be segregated from running the inference algorithm on the preprocessed observations. This is based on conversations with @praesc and @nikolamilovic-ft, where it came up that it can be useful to preprocess observations separately (even in a different runtime, or at least in a different jitted context) before running inference on them.

The default behavior of process_obs() would match what currently happens at the beginning of infer_states(), namely:

def process_obs(self, observations, categorical_obs=False):
    # per-call flag or agent-level setting: observations are already categorical vectors
    use_categorical = categorical_obs or self.categorical_obs

    if not use_categorical:
        # discrete index observations: one-hot encode each modality m
        o_vec = [nn.one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]
    else:
        o_vec = observations

    return o_vec

When using optimized inference algorithms that require advanced padding techniques, this method could be overridden in an algorithm-specific way.
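As a rough illustration of such an override, here is a minimal, dependency-free sketch. Everything in it is an assumption made for the example: the `Agent` class is a stripped-down stub (not the real Agent class), `PaddedAgent` and the pure-Python `one_hot` helper are hypothetical, plain lists stand in for JAX arrays, and zero-padding every modality to the largest observation dimension is just one possible padding scheme.

```python
def one_hot(index, size):
    """Pure-Python stand-in for a one-hot encoder (hypothetical helper)."""
    return [1.0 if i == index else 0.0 for i in range(size)]


class Agent:
    """Hypothetical minimal stub; only the fields process_obs() needs."""

    def __init__(self, num_obs, categorical_obs=False):
        self.num_obs = num_obs                  # observation dimension per modality
        self.categorical_obs = categorical_obs  # agent-level default

    def process_obs(self, observations, categorical_obs=False):
        # Default behavior: one-hot encode index observations,
        # pass categorical observations through unchanged.
        use_categorical = categorical_obs or self.categorical_obs
        if use_categorical:
            return list(observations)
        return [one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]


class PaddedAgent(Agent):
    """Hypothetical subclass: pads every modality to a common length."""

    def process_obs(self, observations, categorical_obs=False):
        o_vec = super().process_obs(observations, categorical_obs=categorical_obs)
        max_dim = max(self.num_obs)
        # Zero-pad each vector on the right so all modalities share one shape.
        return [list(o) + [0.0] * (max_dim - len(o)) for o in o_vec]


agent = PaddedAgent(num_obs=[3, 2])
padded = agent.process_obs([2, 0])
print(padded)
```

The base class keeps the default encoding, so an algorithm-specific subclass only has to add its own padding step on top.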

Then, instead of calling only infer_states() in e.g. rollout.py::infer_and_plan(), we would do:

processed_obs = agent.process_obs(observations)
qs = agent.infer_states(processed_obs, ...)
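One consequence of the two-call form is that preprocessing can run ahead of, and apart from, the inference loop. The sketch below shows that pattern under loud assumptions: the `Agent` stub, the pure-Python `one_hot` helper, and the `trajectory` data are all hypothetical, plain lists stand in for JAX arrays, and `infer_states()` here is only a placeholder that normalizes each vector, not an actual inference algorithm.

```python
def one_hot(index, size):
    """Pure-Python stand-in for a one-hot encoder (hypothetical helper)."""
    return [1.0 if i == index else 0.0 for i in range(size)]


class Agent:
    """Hypothetical minimal stub with the two separated stages."""

    def __init__(self, num_obs, categorical_obs=False):
        self.num_obs = num_obs
        self.categorical_obs = categorical_obs

    def process_obs(self, observations, categorical_obs=False):
        use_categorical = categorical_obs or self.categorical_obs
        if use_categorical:
            return list(observations)
        return [one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]

    def infer_states(self, processed_obs):
        # Placeholder only: a real implementation would run the inference
        # algorithm here. Normalizing keeps the stub runnable.
        return [[x / sum(o) for x in o] for o in processed_obs]


agent = Agent(num_obs=[3, 2])
trajectory = [[0, 1], [2, 0], [1, 1]]  # made-up index observations per time step

# Stage 1: preprocess every time step up front (could live in another context).
processed = [agent.process_obs(obs) for obs in trajectory]

# Stage 2: run inference on already-processed observations only.
beliefs = [agent.infer_states(p) for p in processed]
print(beliefs[1])
```

Since stage 2 never touches raw observations, its cost can be measured or optimized without the preprocessing overhead mixed in.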

This way, the processing could in theory (in custom designs) be separated from the inference. If we want to keep the code in rollout streamlined, however, we could still provide a default fused version of infer_states that performs both steps and does exactly what infer_states() currently does, namely:

def infer_states_fused(self, observations, categorical_obs=False):
    # default behavior: only distinguishes categorical vs. discrete index observations
    processed_obs = self.process_obs(observations, categorical_obs=categorical_obs)
    return self.infer_states(processed_obs, ...)
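To make the intended equivalence concrete, the following sketch checks that a fused call and the explicit two-step call return the same result. It rests on assumptions: the `Agent` stub and the pure-Python `one_hot` helper are hypothetical, plain lists stand in for JAX arrays, and `infer_states()` is a placeholder that merely normalizes each vector rather than running a real inference algorithm.

```python
def one_hot(index, size):
    """Pure-Python stand-in for a one-hot encoder (hypothetical helper)."""
    return [1.0 if i == index else 0.0 for i in range(size)]


class Agent:
    """Hypothetical minimal stub exposing both the split and fused paths."""

    def __init__(self, num_obs, categorical_obs=False):
        self.num_obs = num_obs
        self.categorical_obs = categorical_obs

    def process_obs(self, observations, categorical_obs=False):
        use_categorical = categorical_obs or self.categorical_obs
        if use_categorical:
            return list(observations)
        return [one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]

    def infer_states(self, processed_obs):
        # Placeholder for the actual inference algorithm.
        return [[x / sum(o) for x in o] for o in processed_obs]

    def infer_states_fused(self, observations, categorical_obs=False):
        # Convenience wrapper: preprocessing followed by inference.
        processed_obs = self.process_obs(observations, categorical_obs=categorical_obs)
        return self.infer_states(processed_obs)


agent = Agent(num_obs=[3, 2])

# Index observations: fused and two-step paths should agree.
fused = agent.infer_states_fused([1, 0])
two_step = agent.infer_states(agent.process_obs([1, 0]))

# Categorical observations (unnormalized weights) skip the one-hot step.
fused_cat = agent.infer_states_fused([[1.0, 3.0, 0.0], [1.0, 1.0]], categorical_obs=True)
print(fused == two_step, fused_cat)
```

Keeping the fused method as a thin wrapper means callers such as rollout code can stay on one line, while custom designs can still call the two stages individually.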

Labels

enhancement (New feature or request)
