ndlinear.py (23 additions & 57 deletions)
@@ -3,79 +3,45 @@
 import torch.nn as nn
 import torch.optim as optim


 class NdLinear(nn.Module):
-    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True):
+    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True, bias=True):
         """
-        NdLinear: A PyTorch layer for projecting tensors into multi-space representations.
-
-        Unlike conventional embedding layers that map into a single vector space, NdLinear
-        transforms tensors across a collection of vector spaces, capturing multivariate structure
-        and topical information that standard deep learning architectures typically lose.
+        NdLinear: Multidimensional Linear Projection Layer.
+        Projects tensor inputs across multiple dimensions independently.

         Args:
-            input_dims (tuple): Shape of input tensor (excluding batch dimension).
-            hidden_size (tuple): Target hidden dimensions after transformation.
+            input_dims (tuple): Shape of the input excluding batch dimension.
+            hidden_size (tuple): Shape after transformation (must be same length as input_dims).
             transform_outer (bool): Whether to transform outer dims first (True) or inner (False).
+            bias (bool): Whether to use bias in each Linear layer.
         """
-        super(NdLinear, self).__init__()
-
+        super().__init__()
         if len(input_dims) != len(hidden_size):
-            raise Exception("Input shape and hidden shape do not match.")
-
+            raise ValueError("Input and hidden shape must have the same rank.")
         self.input_dims = input_dims
         self.hidden_size = hidden_size
-        self.num_layers = len(input_dims)  # Must match since dims are equal
+        self.num_layers = len(input_dims)
         self.transform_outer = transform_outer

-        # Define transformation layers per dimension
         self.align_layers = nn.ModuleList([
-            nn.Linear(input_dims[i], hidden_size[i]) for i in range(self.num_layers)
+            nn.Linear(input_dims[i], hidden_size[i], bias=bias)
+            for i in range(self.num_layers)
         ])

-
     def forward(self, X):
-        """
-        Forward pass to project input tensor into a new multi-space representation.
-        - Incrementally transposes, flattens, applies linear layers, and restores shape.
-
-        Expected Input Shape: [batch_size, *input_dims]
-        Output Shape: [batch_size, *hidden_size]
-
-        Args:
-            X (torch.Tensor): Input tensor with shape [batch_size, *input_dims]
-
-        Returns:
-            torch.Tensor: Output tensor with shape [batch_size, *hidden_size]
-        """
-        num_transforms = self.num_layers  # Number of transformations
-
-        # Define iteration order
-        # transform_indices = range(num_transforms) if transform_outer else reversed(range(num_transforms))
-
-        for i in range(num_transforms):
-            if self.transform_outer:
-                layer = self.align_layers[i]
-                transpose_dim = i + 1
-            else:
-                layer = self.align_layers[num_transforms - (i+1)]
-                transpose_dim = num_transforms - i
-
-            # Transpose the selected dimension to the last position
-            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()
-
-            # Store original shape before transformation
-            X_size = X.shape[:-1]
-
-            # Flatten everything except the last dimension
-            X = X.view(-1, X.shape[-1])
-
-            # Apply transformation
-            X = layer(X)
-
-            # Reshape back to the original spatial structure (with new embedding dim)
-            X = X.view(*X_size, X.shape[-1])
-
-            # Transpose the dimension back to its original position
-            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()
-
-        return X
+        transform_indices = range(self.num_layers) if self.transform_outer else reversed(range(self.num_layers))
+
+        for i in transform_indices:
+            layer = self.align_layers[i]
+            transpose_dim = i + 1  # +1 to skip batch
+
+            X = torch.transpose(X, transpose_dim, self.num_layers).contiguous()
+            orig_shape = X.shape[:-1]
+            X = X.view(-1, X.shape[-1])
+            X = layer(X)
+            X = X.view(*orig_shape, X.shape[-1])
+            X = torch.transpose(X, transpose_dim, self.num_layers).contiguous()
+
+        assert X.shape[1:] == self.hidden_size, f"Expected shape {self.hidden_size}, got {X.shape[1:]}"
+        return X
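For reference, a minimal usage sketch of the revised layer (shapes are illustrative, and the import assumes the module is importable as ndlinear):

import torch
from ndlinear import NdLinear

# Each axis gets its own nn.Linear: 28 -> 64, 28 -> 64, 3 -> 8.
layer = NdLinear(input_dims=(28, 28, 3), hidden_size=(64, 64, 8), bias=False)
x = torch.randn(32, 28, 28, 3)   # [batch, *input_dims]
y = layer(x)                     # the assert in forward() checks y.shape[1:]
print(y.shape)                   # torch.Size([32, 64, 64, 8])

Because every axis is projected by its own Linear layer, the parameter count scales with the sum of the per-axis input/output sizes rather than the product of the fully flattened shapes.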
ndlinear_gated.py (67 additions & 101 deletions)
@@ -1,38 +1,22 @@
 import torch
 import torch.nn as nn
-from typing import Literal, Tuple, Optional
+from typing import Literal, Optional, Tuple


 class NdLinearGated(nn.Module):
-    """
-    NdLinearGated: A PyTorch layer for projecting tensors into multi-space representations with gating mechanisms.
-
-    Extends the NdLinear concept by incorporating gating mechanisms that control information flow.
-    This allows the model to selectively utilize transformations based on input characteristics,
-    enabling more adaptive and context-dependent multi-space representations.
-    """
     def __init__(self,
-                 input_dims: tuple,
-                 hidden_size: tuple,
+                 input_dims: Tuple[int, ...],
+                 hidden_size: Tuple[int, ...],
                  transform_outer: bool = True,
                  gating_mode: Literal["soft", "hard"] = "soft",
                  gating_hidden_dim: int = 16,
-                 gated_modes: Literal["all", "first", "topk"] = "all") -> None:
-        """
-        Initialize the NdLinearGated layer.
-
-        Args:
-            input_dims: Shape of input tensor (excluding batch dimension).
-            hidden_size: Target hidden dimensions after transformation.
-            transform_outer: If True, transforms from outer to inner dimensions.
-            gating_mode: Type of gating mechanism - "soft" uses continuous values, "hard" uses binary.
-            gating_hidden_dim: Hidden dimension size for the gating networks.
-            gated_modes: Specifies which dimensions to apply gating to.
-        """
+                 gated_modes: Literal["all", "first", "topk", "none"] = "all",
+                 topk: int = 2,
+                 dropout_p: float = 0.1) -> None:
         super(NdLinearGated, self).__init__()

         if len(input_dims) != len(hidden_size):
-            raise Exception("Input shape and hidden shape do not match.")
+            raise ValueError("Input shape and hidden shape must have the same number of dimensions.")

         self.input_dims = input_dims
         self.hidden_size = hidden_size
@@ -41,11 +25,14 @@ def __init__(self,
         self.gating_mode = gating_mode
         self.gating_hidden_dim = gating_hidden_dim
         self.gated_modes = gated_modes
+        self.topk = topk
+        self.topk_modes = None
+        self.first_batch_processed = False

         self.align_layers = nn.ModuleList([
             nn.Linear(input_dims[i], hidden_size[i]) for i in range(self.num_layers)
         ])

         self.gate_networks = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(input_dims[i], gating_hidden_dim),
@@ -54,105 +41,84 @@ def __init__(self,
                 nn.Sigmoid()
             ) for i in range(self.num_layers)
         ])

         self.identity_projections = nn.ModuleList([
             nn.Linear(input_dims[i], hidden_size[i]) if input_dims[i] != hidden_size[i] else nn.Identity()
             for i in range(self.num_layers)
         ])
-        self.topk_modes = None
-        self.first_batch_processed = False

-    def _compute_topk_modes(self, X: torch.Tensor) -> list:
+        self.dropout = nn.Dropout(dropout_p)
+
+    def reset_topk(self, X: torch.Tensor):
+        """Recomputes top-k dimension indices based on std of mean values."""
         mode_stds = []
         for i in range(self.num_layers):
-            transpose_dim = i + 1 if self.transform_outer else self.num_layers - i
-            X_transposed = torch.transpose(X, transpose_dim, self.num_layers)
-            X_mean = X_transposed.mean(dim=tuple(range(len(X_transposed.shape) - 1)))
+            dim = i + 1 if self.transform_outer else self.num_layers - i
+            Xt = torch.transpose(X, dim, self.num_layers)
+            X_mean = Xt.mean(dim=tuple(range(len(Xt.shape) - 1)))
             mode_stds.append(X_mean.std().item())

-        sorted_modes = sorted(range(len(mode_stds)), key=lambda i: mode_stds[i], reverse=True)
-        return sorted_modes[:2]
+        self.topk_modes = sorted(range(self.num_layers), key=lambda i: mode_stds[i], reverse=True)[:self.topk]
+        self.first_batch_processed = True
+
+    def _transform_step(self, X: torch.Tensor, layer_idx: int, transpose_dim: int, apply_gating: bool) -> torch.Tensor:
+        X_identity = X  # no clone
+        X = torch.transpose(X, transpose_dim, self.num_layers).contiguous()
+        X_size = X.shape[:-1]
+        X_flat = X.view(-1, X.shape[-1])
+        X_transformed = self.align_layers[layer_idx](X_flat).view(*X_size, -1)
+
+        if not apply_gating:
+            return torch.transpose(X_transformed, transpose_dim, self.num_layers).contiguous()
+
+        # Gating value
+        gate = self.gate_networks[layer_idx](X_flat.mean(dim=0, keepdim=True))  # shape [1, 1]
+        gate = self.dropout(gate)
+
+        X_identity_transposed = torch.transpose(X_identity, transpose_dim, self.num_layers).contiguous()
+        X_identity_flat = X_identity_transposed.view(-1, X_identity_transposed.shape[-1])
+
+        identity_flat = (self.identity_projections[layer_idx](X_identity_flat)
+                         if self.input_dims[layer_idx] != self.hidden_size[layer_idx]
+                         else X_identity_flat)
+
+        if self.gating_mode == "soft":
+            out_flat = gate * X_transformed.view(-1, X_transformed.shape[-1]) + \
+                       (1 - gate) * identity_flat
+        else:
+            out_flat = torch.where(gate > 0.5,
+                                   X_transformed.view(-1, X_transformed.shape[-1]),
+                                   identity_flat)
+
+        X_out = out_flat.view(*X_size, -1)
+        return torch.transpose(X_out, transpose_dim, self.num_layers).contiguous()

     def forward(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass to project input tensor into a new multi-space representation with gating.
-
-        Applies dimensional transformations with selective gating based on the configured mode.
-        The gating mechanism allows the network to adaptively choose between transformed
-        representations and identity mappings.
-
-        Args:
-            X (torch.Tensor): Input tensor with shape [batch_size, *input_dims]
-
-        Returns:
-            torch.Tensor: Output tensor with shape [batch_size, *hidden_size]
-        """
-        num_transforms = self.num_layers
-
         if self.gated_modes == "topk" and not self.first_batch_processed:
-            self.topk_modes = self._compute_topk_modes(X)
-            self.first_batch_processed = True
+            self.reset_topk(X)

-        for i in range(num_transforms):
+        for i in range(self.num_layers):
             if self.transform_outer:
                 layer_idx = i
                 transpose_dim = i + 1
             else:
-                layer_idx = num_transforms - (i+1)
-                transpose_dim = num_transforms - i
+                layer_idx = self.num_layers - i - 1
+                transpose_dim = self.num_layers - i

             # Determine if gating is applied
             apply_gating = False
             if self.gated_modes == "all":
                 apply_gating = True
             elif self.gated_modes == "first" and i == 0:
                 apply_gating = True
             elif self.gated_modes == "topk" and self.topk_modes and layer_idx in self.topk_modes:
                 apply_gating = True

-            X_original = X.clone()
-
-            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()
-
-            X_size = X.shape[:-1]
-
-            X_flat = X.view(-1, X.shape[-1])
-
-            X_transformed = self.align_layers[layer_idx](X_flat)
-
-            if apply_gating:
-                X_mean = X_flat.mean(dim=0, keepdim=True)
-
-                gate = self.gate_networks[layer_idx](X_mean)
-
-                X_transformed = X_transformed.view(*X_size, X_transformed.shape[-1])
-
-                X_identity = torch.transpose(X_original, transpose_dim, num_transforms).contiguous()
-
-                X_identity_flat = X_identity.view(-1, X_identity.shape[-1])
-
-                if X_transformed.shape[-1] != X_identity_flat.shape[-1]:
-                    identity_flat = self.identity_projections[layer_idx](X_identity_flat)
-                else:
-                    identity_flat = X_identity_flat
-
-                if self.gating_mode == "soft":
-                    X_flat = gate * X_transformed.view(-1, X_transformed.shape[-1]) + (1 - gate) * identity_flat
-                else:
-                    X_flat = torch.where(gate > 0.5,
-                                         X_transformed.view(-1, X_transformed.shape[-1]),
-                                         identity_flat)
-
-                X = X_flat.view(*X_size, X_flat.shape[-1])
-            else:
-                X = X_transformed.view(*X_size, X_transformed.shape[-1])
-
-            X = torch.transpose(X, transpose_dim, num_transforms).contiguous()
-
+            X = self._transform_step(X, layer_idx, transpose_dim, apply_gating)

         return X

     def __repr__(self) -> str:
         return (f"{self.__class__.__name__}(input_dims={self.input_dims}, "
                 f"hidden_size={self.hidden_size}, transform_outer={self.transform_outer}, "
                 f"gating_mode={self.gating_mode}, gating_hidden_dim={self.gating_hidden_dim}, "
-                f"gated_modes={self.gated_modes})")
+                f"gated_modes={self.gated_modes}, topk={self.topk})")
pyproject.toml (1 addition & 1 deletion)
@@ -3,7 +3,7 @@ name = "NdLinear"
 version = "0.1.0"
 description = "An environment setup to support NdLinear. "
 readme = "README.md"
-requires-python = ">=3.12"
+requires-python = ">=3.11"
 dependencies = [
     "accelerate>=1.5.2",
     "einops>=0.8.1",