Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions contextualized/regression/lightning_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
SubtypeMetamodel,
MultitaskMetamodel,
TasksplitMetamodel,
TasksplitUnivariateMetamodel,
SINGLE_TASK_METAMODELS,
MULTITASK_METAMODELS,
SINGLE_TASK_UNIVARIATE_METAMODELS,
MULTITASK_UNIVARIATE_METAMODELS,
)
from contextualized.regression.datasets import (
DataIterable,
Expand Down Expand Up @@ -269,7 +272,6 @@ def _build_metamodel(
context_dim,
x_dim,
y_dim,
univariate=False,
encoder_type=encoder_type,
width=width,
layers=layers,
Expand Down Expand Up @@ -372,7 +374,6 @@ def _build_metamodel(
context_dim,
x_dim,
y_dim,
univariate=False,
encoder_type=encoder_type,
width=width,
layers=layers,
Expand Down Expand Up @@ -481,7 +482,6 @@ def _build_metamodel(
context_dim,
x_dim,
y_dim,
univariate=False,
encoder_type=encoder_type,
width=width,
layers=layers,
Expand Down Expand Up @@ -595,7 +595,6 @@ def _build_metamodel(
context_dim,
x_dim,
y_dim,
univariate=False,
context_archetypes=context_archetypes,
task_archetypes=task_archetypes,
context_encoder_type=context_encoder_type,
Expand Down Expand Up @@ -700,11 +699,11 @@ def _build_metamodel(
:param **kwargs: Additional keyword arguments for the metamodel

"""
self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](
# Use explicit univariate metamodel mapping
self.metamodel = SINGLE_TASK_UNIVARIATE_METAMODELS[self.metamodel_type](
context_dim,
x_dim,
y_dim,
univariate=True,
encoder_type=encoder_type,
width=width,
layers=layers,
Expand Down Expand Up @@ -790,11 +789,11 @@ def _build_metamodel(
:param task_link_fn: Link function to use for the task (default is identity)

"""
self.metamodel = TasksplitMetamodel(
# Use the explicit Tasksplit univariate metamodel
self.metamodel = TasksplitUnivariateMetamodel(
context_dim,
x_dim,
y_dim,
univariate=True,
context_archetypes=context_archetypes,
task_archetypes=task_archetypes,
context_encoder_type=context_encoder_type,
Expand Down
106 changes: 88 additions & 18 deletions contextualized/regression/metamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from contextualized.modules import ENCODERS, Explainer, SoftSelect
from contextualized.functions import LINK_FUNCTIONS

# Multivariate variants (explicit classes below)

class NaiveMetamodel(nn.Module):
class NaiveMultivariateMetamodel(nn.Module):
"""Probabilistic assumptions as a graphical model (observed) {unobserved}:
(C) --> {beta, mu} --> (X, Y)

Expand All @@ -21,7 +22,6 @@ def __init__(
context_dim: int,
x_dim: int,
y_dim: int,
univariate: bool = False,
encoder_type: str = "mlp",
width: int = 25,
layers: int = 1,
Expand All @@ -46,7 +46,8 @@ def __init__(
self.y_dim = y_dim

encoder = ENCODERS[encoder_type]
self.mu_dim = x_dim if univariate else 1
# multivariate: mu is scalar per y
self.mu_dim = 1
out_dim = (x_dim + self.mu_dim) * y_dim
if encoder_type == "linear":
self.context_encoder = encoder(context_dim, out_dim)
Expand All @@ -68,7 +69,7 @@ def forward(self, C):
return beta, mu


class SubtypeMetamodel(nn.Module):
class SubtypeMultivariateMetamodel(nn.Module):
"""Probabilistic assumptions as a graphical model (observed) {unobserved}:
(C) <-- {Z} --> {beta, mu} --> (X)

Expand All @@ -82,7 +83,6 @@ def __init__(
context_dim: int,
x_dim: int,
y_dim: int,
univariate: bool = False,
num_archetypes: int = 10,
encoder_type: str = "mlp",
width: int = 25,
Expand All @@ -109,7 +109,8 @@ def __init__(
self.y_dim = y_dim

encoder = ENCODERS[encoder_type]
out_shape = (y_dim, x_dim * 2, 1) if univariate else (y_dim, x_dim + 1)
# multivariate: out_shape is (y_dim, x_dim + 1)
out_shape = (y_dim, x_dim + 1)
if encoder_type == "linear":
self.context_encoder = encoder(context_dim, num_archetypes)
else:
Expand All @@ -131,7 +132,7 @@ def forward(self, C):
return beta, mu


class MultitaskMetamodel(nn.Module):
class MultitaskMultivariateMetamodel(nn.Module):
"""Probabilistic assumptions as a graphical model (observed) {unobserved}:
(C) <-- {Z} --> {beta, mu} --> (X)
(T) <---/
Expand All @@ -146,7 +147,6 @@ def __init__(
context_dim: int,
x_dim: int,
y_dim: int,
univariate: bool = False,
num_archetypes: int = 10,
encoder_type: str = "mlp",
width: int = 25,
Expand All @@ -173,8 +173,9 @@ def __init__(
self.y_dim = y_dim

encoder = ENCODERS[encoder_type]
beta_dim = 1 if univariate else x_dim
task_dim = y_dim + x_dim if univariate else y_dim
# multivariate multitask: beta_dim = x_dim, task_dim = y_dim
beta_dim = x_dim
task_dim = y_dim
if encoder_type == "linear":
self.context_encoder = encoder(context_dim + task_dim, num_archetypes)
else:
Expand Down Expand Up @@ -202,7 +203,7 @@ def forward(self, C, T):
return beta, mu


class TasksplitMetamodel(nn.Module):
class TasksplitMultivariateMetamodel(nn.Module):
"""Probabilistic assumptions as a graphical model (observed) {unobserved}:
(C) <-- {Z_c} --> {beta, mu} --> (X)
(T) <-- {Z_t} ----^
Expand All @@ -218,7 +219,6 @@ def __init__(
context_dim: int,
x_dim: int,
y_dim: int,
univariate: bool = False,
context_archetypes: int = 10,
task_archetypes: int = 10,
context_encoder_type: str = "mlp",
Expand Down Expand Up @@ -256,8 +256,9 @@ def __init__(

context_encoder = ENCODERS[context_encoder_type]
task_encoder = ENCODERS[task_encoder_type]
beta_dim = 1 if univariate else x_dim
task_dim = y_dim + x_dim if univariate else y_dim
# multivariate tasksplit: beta_dim = x_dim, task_dim = y_dim
beta_dim = x_dim
task_dim = y_dim
self.context_encoder = context_encoder(
context_dim,
context_archetypes,
Expand Down Expand Up @@ -292,11 +293,80 @@ def forward(self, C, T):


SINGLE_TASK_METAMODELS = {
"naive": NaiveMetamodel,
"subtype": SubtypeMetamodel,
"naive": NaiveMultivariateMetamodel,
"subtype": SubtypeMultivariateMetamodel,
}

MULTITASK_METAMODELS = {
"multitask": MultitaskMetamodel,
"tasksplit": TasksplitMetamodel,
"multitask": MultitaskMultivariateMetamodel,
"tasksplit": TasksplitMultivariateMetamodel,
}

# Backwards compatible aliases for the original class names
NaiveMetamodel = NaiveMultivariateMetamodel
SubtypeMetamodel = SubtypeMultivariateMetamodel
MultitaskMetamodel = MultitaskMultivariateMetamodel
TasksplitMetamodel = TasksplitMultivariateMetamodel

# Univariate variants (explicit classes below)


class NaiveUnivariateMetamodel(NaiveMetamodel):
"""Univariate version of NaiveMetamodel where mu has dimension x_dim."""

def __init__(self, context_dim, x_dim, y_dim, encoder_type="mlp", width=25, layers=1, link_fn=LINK_FUNCTIONS["identity"]):
super().__init__(context_dim, x_dim, y_dim, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn)
# override mu_dim and reshaping behavior
self.mu_dim = x_dim
out_dim = (x_dim + self.mu_dim) * y_dim
# rebuild encoder to match new out_dim
encoder = ENCODERS[encoder_type]
if encoder_type == "linear":
self.context_encoder = encoder(context_dim, out_dim)
else:
self.context_encoder = encoder(context_dim, out_dim, width=width, layers=layers, link_fn=link_fn)


class SubtypeUnivariateMetamodel(SubtypeMetamodel):
def __init__(self, context_dim, x_dim, y_dim, num_archetypes=10, encoder_type="mlp", width=25, layers=1, link_fn=LINK_FUNCTIONS["identity"]):
super().__init__(context_dim, x_dim, y_dim, num_archetypes=num_archetypes, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn)
# adjust explainer out shape for univariate behavior
out_shape = (y_dim, x_dim * 2, 1)
self.explainer = Explainer(num_archetypes, out_shape)


class MultitaskUnivariateMetamodel(MultitaskMetamodel):
def __init__(self, context_dim, x_dim, y_dim, num_archetypes=10, encoder_type="mlp", width=25, layers=1, link_fn=LINK_FUNCTIONS["identity"]):
# For univariate multitask, beta is scalar per task and task_dim = y_dim + x_dim
super().__init__(context_dim, x_dim, y_dim, num_archetypes=num_archetypes, encoder_type=encoder_type, width=width, layers=layers, link_fn=link_fn)
beta_dim = 1
task_dim = y_dim + x_dim
encoder = ENCODERS[encoder_type]
if encoder_type == "linear":
self.context_encoder = encoder(context_dim + task_dim, num_archetypes)
else:
self.context_encoder = encoder(context_dim + task_dim, num_archetypes, width=width, layers=layers, link_fn=link_fn)
self.explainer = Explainer(num_archetypes, (beta_dim + 1,))


class TasksplitUnivariateMetamodel(TasksplitMetamodel):
def __init__(self, context_dim, x_dim, y_dim, context_archetypes=10, task_archetypes=10, context_encoder_type="mlp", context_width=25, context_layers=1, context_link_fn=LINK_FUNCTIONS["softmax"], task_encoder_type="mlp", task_width=25, task_layers=1, task_link_fn=LINK_FUNCTIONS["identity"]):
super().__init__(context_dim, x_dim, y_dim, context_archetypes=context_archetypes, task_archetypes=task_archetypes, context_encoder_type=context_encoder_type, context_width=context_width, context_layers=context_layers, context_link_fn=context_link_fn, task_encoder_type=task_encoder_type, task_width=task_width, task_layers=task_layers, task_link_fn=task_link_fn)
beta_dim = 1
task_dim = y_dim + x_dim
context_encoder = ENCODERS[context_encoder_type]
task_encoder = ENCODERS[task_encoder_type]
self.context_encoder = context_encoder(context_dim, context_archetypes, width=context_width, layers=context_layers, link_fn=context_link_fn)
self.task_encoder = task_encoder(task_dim, task_archetypes, width=task_width, layers=task_layers, link_fn=task_link_fn)
self.explainer = SoftSelect((context_archetypes, task_archetypes), (beta_dim + 1,))


SINGLE_TASK_UNIVARIATE_METAMODELS = {
"naive": NaiveUnivariateMetamodel,
"subtype": SubtypeUnivariateMetamodel,
}

MULTITASK_UNIVARIATE_METAMODELS = {
"multitask": MultitaskUnivariateMetamodel,
"tasksplit": TasksplitUnivariateMetamodel,
}