Add Student's t RandomVariable#1211
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #1211 +/- ##
=======================================
Coverage 74.10% 74.11%
=======================================
Files 174 174
Lines 48624 48636 +12
Branches 10351 10351
=======================================
+ Hits 36035 36047 +12
Misses 10301 10301
Partials 2288 2288
|
RandomVariableRandomVariable
aesara/tensor/random/basic.py
Outdated
| gamma = GammaRV() | ||
|
|
||
|
|
||
| class StandardGammaRV(GammaRV): |
There was a problem hiding this comment.
If we only intend to add standard_gamma for NumPy interface compatibility, then the GammaRV Op alone should suffice, no? For instance, standard_gamma could be an Op constructor function that effectively removes the shape and rate arguments from GammaRV.
As always, we want to avoid adding new Ops whenever we reasonably can.
N.B. This is an example of the concern stated in aesara-devs/aemcmc#67 (comment).
There was a problem hiding this comment.
I think the issue is that you can't use RandomStream if you don't have a specific Op. That's why the standard_nornal was made a subclass as well IIRC
There was a problem hiding this comment.
If we only intend to add
standard_gammafor NumPy interface compatibility, then theGammaRVOpalone should suffice, no? For instance,standard_gammacould be anOpconstructor function that effectively removes the shape and rate arguments fromGammaRV.As always, we want to avoid adding new
Ops whenever we reasonably can.N.B. This is an example of the concern stated in aesara-devs/aemcmc#67 (comment).
I agree with all this.
@ricardoV94 we can monkey patch the base classes. I'm not a big fan of this approach, but it should work:
import aesara.tensor as at
from aesara.tensor.random.basic import NormalRV
def create_standard_normal():
def standard_call(self, shape, size=None, **kwargs):
return self.general_call(0., 1., size, **kwargs)
RV = NormalRV
RV.general_call = RV.__call__
RV.__call__ = standard_call
return RV()
standard_normal = create_standard_normal()
print(type(normal))
# <class 'aesara.tensor.random.basic.NormalRV'>
print(type(standard_normal))
# <class 'aesara.tensor.random.basic.NormalRV'>There was a problem hiding this comment.
Does it work with RandomStream like that?
There was a problem hiding this comment.
I think it will, RandomStream checks that the attribute is both in aesara.tensor.random.basic and is an instance of RandomVariable. We'll soon know for sure.
There was a problem hiding this comment.
This quickly gets complicated, monkey patching affects the class globally so i would need to replace the method of the instance directly.
By the way, I've noticed that NumPy's Generator has a a method for every RV
There was a problem hiding this comment.
Maybe we can add a canonical rewrite that replaces the standard versions immediately?
Or can we tweak RandomStream perhaps?
There was a problem hiding this comment.
Maybe we can add a canonical rewrite that replaces the standard versions immediately?
That would certainly work. But I'd rather not rely on rewrites to "fix" a problem in the representation of objects in the IR, which is also my concern in #1213. It would be better if they were represented by the same type in the original graph.
There was a problem hiding this comment.
Adding / Replacing most methods in class instances works fine, for instance:
import aesara.tensor as at
from aesara.tensor.random.basic import t, StudentTRV
def create_standard_t():
C = StudentTRV()
def new_call(self, df, size=None, **kwargs):
return self.__call__(df, 0, 1, size, **kwargs)
C.call = new_call.__get__(C)
return C
standard_t = create_standard_t()
print(standard_t.call(2.).owner.inputs)
# [RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F618591A420>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2.0}, TensorConstant{0}, TensorConstant{1}]However, special methods like __call__ are looked up with respect to the class of the object, and not its instance. So if we monkey patch the method it will end up affecting all the instances of the class, and thus instances of the non-standard RandomVariable. So I think that's a dead end.
There was a problem hiding this comment.
Yeah, I think we should be able to update RandomStream to handle these cases without too much difficulty and/or compromise.
ad83911 to
252a4b5
Compare
|
Removed the commits related to |
Numpy only defines
standard_t(it has inconsistent naming with the location-scale library), so I defined it as aScipyRandomVariablesinstead. I followed what was done for the other members of the location-scale family.Here are a few important guidelines and requirements to check before your PR can be merged:
pre-commitis installed and set up.Ticks one off of #1093