From c0829e5196035e81edc1e249e8ac59c6a913adbb Mon Sep 17 00:00:00 2001 From: CheeseLee888 Date: Wed, 8 Oct 2025 14:48:44 -0500 Subject: [PATCH 1/2] Split metamodels into univariate/multivariate version --- contextualized/regression/metamodels.py | 106 ++++++++++++++++++++---- 1 file changed, 88 insertions(+), 18 deletions(-) diff --git a/contextualized/regression/metamodels.py b/contextualized/regression/metamodels.py index b65ae32d..a443eca5 100644 --- a/contextualized/regression/metamodels.py +++ b/contextualized/regression/metamodels.py @@ -8,8 +8,9 @@ from contextualized.modules import ENCODERS, Explainer, SoftSelect from contextualized.functions import LINK_FUNCTIONS +# Multivariate variants (explicit classes below) -class NaiveMetamodel(nn.Module): +class NaiveMultivariateMetamodel(nn.Module): """Probabilistic assumptions as a graphical model (observed) {unobserved}: (C) --> {beta, mu} --> (X, Y) @@ -21,7 +22,6 @@ def __init__( context_dim: int, x_dim: int, y_dim: int, - univariate: bool = False, encoder_type: str = "mlp", width: int = 25, layers: int = 1, @@ -46,7 +46,8 @@ def __init__( self.y_dim = y_dim encoder = ENCODERS[encoder_type] - self.mu_dim = x_dim if univariate else 1 + # multivariate: mu is scalar per y + self.mu_dim = 1 out_dim = (x_dim + self.mu_dim) * y_dim if encoder_type == "linear": self.context_encoder = encoder(context_dim, out_dim) @@ -68,7 +69,7 @@ def forward(self, C): return beta, mu -class SubtypeMetamodel(nn.Module): +class SubtypeMultivariateMetamodel(nn.Module): """Probabilistic assumptions as a graphical model (observed) {unobserved}: (C) <-- {Z} --> {beta, mu} --> (X) @@ -82,7 +83,6 @@ def __init__( context_dim: int, x_dim: int, y_dim: int, - univariate: bool = False, num_archetypes: int = 10, encoder_type: str = "mlp", width: int = 25, @@ -109,7 +109,8 @@ def __init__( self.y_dim = y_dim encoder = ENCODERS[encoder_type] - out_shape = (y_dim, x_dim * 2, 1) if univariate else (y_dim, x_dim + 1) + # multivariate: out_shape is 
(y_dim, x_dim + 1) + out_shape = (y_dim, x_dim + 1) if encoder_type == "linear": self.context_encoder = encoder(context_dim, num_archetypes) else: @@ -131,7 +132,7 @@ def forward(self, C): return beta, mu -class MultitaskMetamodel(nn.Module): +class MultitaskMultivariateMetamodel(nn.Module): """Probabilistic assumptions as a graphical model (observed) {unobserved}: (C) <-- {Z} --> {beta, mu} --> (X) (T) <---/ @@ -146,7 +147,6 @@ def __init__( context_dim: int, x_dim: int, y_dim: int, - univariate: bool = False, num_archetypes: int = 10, encoder_type: str = "mlp", width: int = 25, @@ -173,8 +173,9 @@ def __init__( self.y_dim = y_dim encoder = ENCODERS[encoder_type] - beta_dim = 1 if univariate else x_dim - task_dim = y_dim + x_dim if univariate else y_dim + # multivariate multitask: beta_dim = x_dim, task_dim = y_dim + beta_dim = x_dim + task_dim = y_dim if encoder_type == "linear": self.context_encoder = encoder(context_dim + task_dim, num_archetypes) else: @@ -202,7 +203,7 @@ def forward(self, C, T): return beta, mu -class TasksplitMetamodel(nn.Module): +class TasksplitMultivariateMetamodel(nn.Module): """Probabilistic assumptions as a graphical model (observed) {unobserved}: (C) <-- {Z_c} --> {beta, mu} --> (X) (T) <-- {Z_t} ----^ @@ -218,7 +219,6 @@ def __init__( context_dim: int, x_dim: int, y_dim: int, - univariate: bool = False, context_archetypes: int = 10, task_archetypes: int = 10, context_encoder_type: str = "mlp", @@ -256,8 +256,9 @@ def __init__( context_encoder = ENCODERS[context_encoder_type] task_encoder = ENCODERS[task_encoder_type] - beta_dim = 1 if univariate else x_dim - task_dim = y_dim + x_dim if univariate else y_dim + # multivariate tasksplit: beta_dim = x_dim, task_dim = y_dim + beta_dim = x_dim + task_dim = y_dim self.context_encoder = context_encoder( context_dim, context_archetypes, @@ -292,11 +293,80 @@ def forward(self, C, T): SINGLE_TASK_METAMODELS = { - "naive": NaiveMetamodel, - "subtype": SubtypeMetamodel, + "naive": 
NaiveMultivariateMetamodel, + "subtype": SubtypeMultivariateMetamodel, } MULTITASK_METAMODELS = { - "multitask": MultitaskMetamodel, - "tasksplit": TasksplitMetamodel, + "multitask": MultitaskMultivariateMetamodel, + "tasksplit": TasksplitMultivariateMetamodel, +} + +# Backwards compatible aliases for the original class names +NaiveMetamodel = NaiveMultivariateMetamodel +SubtypeMetamodel = SubtypeMultivariateMetamodel +MultitaskMetamodel = MultitaskMultivariateMetamodel +TasksplitMetamodel = TasksplitMultivariateMetamodel + +# Univariate variants (explicit classes below) + + +class NaiveUnivariateMetamodel(NaiveMetamodel): + """Univariate version of NaiveMetamodel where mu has dimension x_dim.""" + + def __init__(self, context_dim, x_dim, y_dim, encoder_type="mlp", width=25, layers=1, link_fn=LINK_FUNCTIONS["identity"]): + super().__init__(context_dim, x_dim, y_dim, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn) + # override mu_dim and reshaping behavior + self.mu_dim = x_dim + out_dim = (x_dim + self.mu_dim) * y_dim + # rebuild encoder to match new out_dim + encoder = ENCODERS[encoder_type] + if encoder_type == "linear": + self.context_encoder = encoder(context_dim, out_dim) + else: + self.context_encoder = encoder(context_dim, out_dim, width=width, layers=layers, link_fn=link_fn) + + +class SubtypeUnivariateMetamodel(SubtypeMetamodel): + def __init__(self, context_dim, x_dim, y_dim, num_archetypes=10, encoder_type="mlp", width=25, layers=1, link_fn=LINK_FUNCTIONS["identity"]): + super().__init__(context_dim, x_dim, y_dim, num_archetypes=num_archetypes, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn) + # adjust explainer out shape for univariate behavior + out_shape = (y_dim, x_dim * 2, 1) + self.explainer = Explainer(num_archetypes, out_shape) + + +class MultitaskUnivariateMetamodel(MultitaskMetamodel): + def __init__(self, context_dim, x_dim, y_dim, num_archetypes=10, encoder_type="mlp", width=25, layers=1, 
link_fn=LINK_FUNCTIONS["identity"]): + # For univariate multitask, beta is scalar per task and task_dim = y_dim + x_dim + super().__init__(context_dim, x_dim, y_dim, num_archetypes=num_archetypes, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn) + beta_dim = 1 + task_dim = y_dim + x_dim + encoder = ENCODERS[encoder_type] + if encoder_type == "linear": + self.context_encoder = encoder(context_dim + task_dim, num_archetypes) + else: + self.context_encoder = encoder(context_dim + task_dim, num_archetypes, width=width, layers=layers, link_fn=link_fn) + self.explainer = Explainer(num_archetypes, (beta_dim + 1,)) + + +class TasksplitUnivariateMetamodel(TasksplitMetamodel): + def __init__(self, context_dim, x_dim, y_dim, context_archetypes=10, task_archetypes=10, context_encoder_type="mlp", context_width=25, context_layers=1, context_link_fn=LINK_FUNCTIONS["softmax"], task_encoder_type="mlp", task_width=25, task_layers=1, task_link_fn=LINK_FUNCTIONS["identity"]): + super().__init__(context_dim, x_dim, y_dim, context_archetypes=context_archetypes, task_archetypes=task_archetypes, context_encoder_type=context_encoder_type, context_width=context_width, context_layers=context_layers, context_link_fn=context_link_fn, task_encoder_type=task_encoder_type, task_width=task_width, task_layers=task_layers, task_link_fn=task_link_fn) + beta_dim = 1 + task_dim = y_dim + x_dim + context_encoder = ENCODERS[context_encoder_type] + task_encoder = ENCODERS[task_encoder_type] + self.context_encoder = context_encoder(context_dim, context_archetypes, width=context_width, layers=context_layers, link_fn=context_link_fn) + self.task_encoder = task_encoder(task_dim, task_archetypes, width=task_width, layers=task_layers, link_fn=task_link_fn) + self.explainer = SoftSelect((context_archetypes, task_archetypes), (beta_dim + 1,)) + + +SINGLE_TASK_UNIVARIATE_METAMODELS = { + "naive": NaiveUnivariateMetamodel, + "subtype": SubtypeUnivariateMetamodel, +} + 
+MULTITASK_UNIVARIATE_METAMODELS = { + "multitask": MultitaskUnivariateMetamodel, + "tasksplit": TasksplitUnivariateMetamodel, } From c2ef2ad18e0cc201a1cdcb7cc596f1f275c3c7f1 Mon Sep 17 00:00:00 2001 From: CheeseLee888 Date: Wed, 8 Oct 2025 15:38:26 -0500 Subject: [PATCH 2/2] Lightning modules explicitly call uni/multi variate model with their names without setting univariate parameters --- contextualized/regression/lightning_modules.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 4cfd0bb6..7093bd8b 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -26,8 +26,11 @@ SubtypeMetamodel, MultitaskMetamodel, TasksplitMetamodel, + TasksplitUnivariateMetamodel, SINGLE_TASK_METAMODELS, MULTITASK_METAMODELS, + SINGLE_TASK_UNIVARIATE_METAMODELS, + MULTITASK_UNIVARIATE_METAMODELS, ) from contextualized.regression.datasets import ( DataIterable, @@ -269,7 +272,6 @@ def _build_metamodel( context_dim, x_dim, y_dim, - univariate=False, encoder_type=encoder_type, width=width, layers=layers, @@ -372,7 +374,6 @@ def _build_metamodel( context_dim, x_dim, y_dim, - univariate=False, encoder_type=encoder_type, width=width, layers=layers, @@ -481,7 +482,6 @@ def _build_metamodel( context_dim, x_dim, y_dim, - univariate=False, encoder_type=encoder_type, width=width, layers=layers, @@ -595,7 +595,6 @@ def _build_metamodel( context_dim, x_dim, y_dim, - univariate=False, context_archetypes=context_archetypes, task_archetypes=task_archetypes, context_encoder_type=context_encoder_type, @@ -700,11 +699,11 @@ def _build_metamodel( :param **kwargs: Additional keyword arguments for the metamodel """ - self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type]( + # Use explicit univariate metamodel mapping + self.metamodel = SINGLE_TASK_UNIVARIATE_METAMODELS[self.metamodel_type]( context_dim, 
x_dim, y_dim, - univariate=True, encoder_type=encoder_type, width=width, layers=layers, @@ -790,11 +789,11 @@ def _build_metamodel( :param task_link_fn: Link function to use for the task (default is identity) """ - self.metamodel = TasksplitMetamodel( + # Use the explicit Tasksplit univariate metamodel + self.metamodel = TasksplitUnivariateMetamodel( context_dim, x_dim, y_dim, - univariate=True, context_archetypes=context_archetypes, task_archetypes=task_archetypes, context_encoder_type=context_encoder_type,