Following the landing of #315 to v1.0.0_alpha and following up on @nikolamilovic-ft's comment there: in order to separate the optimized inference execution itself from the overhead introduced by padding observations, it makes sense to separate the processing of observations from the inference algorithm that runs on them.
The idea would be to add a separate process_obs() method to the Agent class, which could be overridden in a custom way (or flexibly configured via some kind of InferenceConfig if and when we get around to adding that, see e.g. #275) when using advanced inference algorithms where padding/preprocessing overhead can be segregated from the execution of the inference algorithm on the preprocessed observations. This is based on conversations with @praesc and @nikolamilovic-ft, where we noted it can be useful to preprocess observations separately (even in a different runtime, or at least a different jitted context) before running inference on them.
The default behavior of process_obs() would mirror what currently happens at the beginning of infer_states(), namely:
```python
def process_obs(self, observations, categorical_obs=False):
    # `nn` is assumed to be jax.nn, as used by the current infer_states()
    use_categorical = categorical_obs or self.categorical_obs
    if not use_categorical:
        # convert discrete observation indices into one-hot vectors, per modality
        o_vec = [nn.one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]
    else:
        # observations are already categorical (one-hot / distributional) vectors
        o_vec = observations
    return o_vec
```

When using optimized inference algorithms that require advanced padding techniques, this method could be overridden in an algorithm-specific way.
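As a concrete illustration, here is a minimal sketch of what such an override might look like. The `PaddedAgent` subclass name and the padding scheme (padding every modality's one-hot vector out to a common maximum size) are purely illustrative assumptions, not existing pymdp API:

```python
# Hypothetical sketch only: a subclass that segregates padding overhead into
# process_obs(), so infer_states() can run on uniformly shaped arrays.
import jax.numpy as jnp
from jax import nn


class PaddedAgent(Agent):  # `Agent` is the existing Agent class
    def process_obs(self, observations, categorical_obs=False):
        use_categorical = categorical_obs or self.categorical_obs
        if not use_categorical:
            o_vec = [nn.one_hot(o, self.num_obs[m]) for m, o in enumerate(observations)]
        else:
            o_vec = observations

        # pad every modality's last axis out to the largest observation dimension,
        # so all modalities share a single array shape (illustrative scheme only)
        max_num_obs = max(self.num_obs)
        return [
            jnp.pad(o, [(0, 0)] * (o.ndim - 1) + [(0, max_num_obs - o.shape[-1])])
            for o in o_vec
        ]
```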
Then, instead of calling only infer_states() in the context of, e.g., rollout.py::infer_and_plan(), we would do:
```python
processed_obs = agent.process_obs(observations)
qs = agent.infer_states(processed_obs, ...)
```

This way, the processing could in theory (in custom designs) be separated out from the inference.
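As a rough illustration of the "different jitted context" idea, the two stages could in principle be compiled separately. The sketch below is hypothetical: it assumes the agent object can be passed through jax.jit (i.e. it is registered as a pytree), and it elides whatever additional arguments infer_states() takes, as in the snippet above:

```python
import jax

# Hypothetical sketch: compile preprocessing and inference as separate jitted
# functions, so the padding / one-hot overhead is segregated from inference.
@jax.jit
def preprocess(agent, observations):
    # only the observation processing (one-hot conversion, padding, ...)
    return agent.process_obs(observations)

@jax.jit
def run_inference(agent, processed_obs):
    # only the optimized inference algorithm itself
    return agent.infer_states(processed_obs)

processed_obs = preprocess(agent, observations)
qs = run_inference(agent, processed_obs)
```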
If we want to keep the lines of code in rollout streamlined, however, we could still provide a default method that includes both steps and is essentially what currently happens in infer_states(), namely:

```python
def infer_states_fused(self, observations, categorical_obs=False):
    # default behavior, e.g. just categorical vs. discrete index observations
    processed_obs = self.process_obs(observations, categorical_obs=categorical_obs)
    return self.infer_states(processed_obs, ...)
```
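With that in place, rollout-style code could keep its current one-call shape; just a usage sketch, with infer_states_fused being the hypothetical method proposed above:

```python
# drop-in replacement for the current single infer_states() call in rollout.py
qs = agent.infer_states_fused(observations)
```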