This repository was archived by the owner on Dec 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 51
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
BMGInference does not handle torch.stack applied to random variables #1565
Copy link
Copy link
Open
Description
Issue Description
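When random variables (RVs) are combined using torch.stack, BMGInference fails to
trace execution because it assumes that every element of the sequence passed to torch.stack is a Tensor; while the graph is being accumulated, however, each RV call evaluates to a graph node (a SampleNode), not a Tensor.
The example runs fine if torch.stack is replaced by torch.tensor, but torch.tensor is not differentiable with respect to its arguments, which precludes gradient-based methods such as VI and HMC.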
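Steps to Reproduce
(The snippets below assume `torch` is already imported and that `dist` refers to a distributions module such as `torch.distributions`.)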
import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference
foo = bm.random_variable(lambda: dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
BMGInference().infer(
queries=[foo()],
observations={},
num_samples=1,
)
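raises the following error. (Note: the traceback below was captured from a variant of this example in which `foo` is defined as `dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2))`; it fails with the same error at the same torch.stack call.)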
expected Tensor as element 0 in argument 0, but got SampleNode
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
7 queries=[foo()],
8 observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
262 # TODO: Add verbose level
263 # TODO: Add logging
--> 264 samples, _ = self._infer(
265 queries,
266 observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
182 self._pd = prof.ProfilerData()
183
--> 184 rt = self._accumulate_graph(queries, observations)
185 bmg = rt._bmg
186 report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
71 rt = BMGRuntime()
72 rt._pd = self._pd
---> 73 bmg = rt.accumulate_graph(queries, observations)
74 # TODO: Figure out a better way to pass this flag around
75 bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
719 self._bmg.add_observation(node, val)
720 for qrv in queries:
--> 721 node = self._rv_to_node(qrv)
722 q = self._bmg.add_query(node)
723 self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
583 # RVID, and if we're in the second situation, we will not.
584
--> 585 value = self._context.call(rewritten_function, rv.arguments)
586 if isinstance(value, RVIdentifier):
587 # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
92 self._stack.push(FunctionCall(func, args, kwargs))
93 try:
---> 94 return func(*args, **kwargs)
95 finally:
96 self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
510 function, arguments, kwargs
511 ):
--> 512 result = self._special_function_caller.do_special_call_maybe_stochastic(
513 function, arguments, kwargs
514 )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
629 new_args = (_get_ordinary_value(arg) for arg in args)
630 new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631 return func(*new_args, **new_kwargs)
632
633 if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
Expected Behavior
Successful execution, with the same behavior as the following version in which torch.stack is replaced by torch.tensor:
import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference
foo = bm.random_variable(lambda: dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
BMGInference().infer(
queries=[foo()],
observations={},
num_samples=1,
)
Metadata
Metadata
Assignees
Labels
No labels