From 9746181951c824fb6c4b586837bb6f3d032a1cce Mon Sep 17 00:00:00 2001 From: Eraly Date: Sat, 14 Jun 2025 16:12:34 +0500 Subject: [PATCH 1/5] Create READ --- READ | 1 + 1 file changed, 1 insertion(+) create mode 100644 READ diff --git a/READ b/READ new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/READ @@ -0,0 +1 @@ + From b94fba14e5a5b845e0a92cc3b0b348b420a8d6e8 Mon Sep 17 00:00:00 2001 From: Eraly Date: Sat, 14 Jun 2025 16:13:41 +0500 Subject: [PATCH 2/5] Delete READ --- READ | 1 - 1 file changed, 1 deletion(-) delete mode 100644 READ diff --git a/READ b/READ deleted file mode 100644 index 8b13789..0000000 --- a/READ +++ /dev/null @@ -1 +0,0 @@ - From 89d7e0e6b0c734ec884defb72ca6b72f06a7ee1e Mon Sep 17 00:00:00 2001 From: Eraly Date: Sat, 14 Jun 2025 16:13:59 +0500 Subject: [PATCH 3/5] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7dd4a06..a5995b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "NdLinear" version = "0.1.0" description = "An environment setup to support NdLinear. " readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.11" dependencies = [ "accelerate>=1.5.2", "einops>=0.8.1", From afb80a8e55ef406f59dae7ae1619da641b771090 Mon Sep 17 00:00:00 2001 From: Eraly Date: Sat, 14 Jun 2025 16:16:26 +0500 Subject: [PATCH 4/5] Update ndlinear.py --- ndlinear.py | 80 +++++++++++++++-------------------------------------- 1 file changed, 23 insertions(+), 57 deletions(-) diff --git a/ndlinear.py b/ndlinear.py index dc9e68b..c10fe28 100644 --- a/ndlinear.py +++ b/ndlinear.py @@ -3,79 +3,45 @@ import torch.nn as nn import torch.optim as optim - class NdLinear(nn.Module): - def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True): + def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer=True, bias=True): """ - NdLinear: A PyTorch layer for projecting tensors into multi-space representations. - - Unlike conventional embedding layers that map into a single vector space, NdLinear - transforms tensors across a collection of vector spaces, capturing multivariate structure - and topical information that standard deep learning architectures typically lose. + NdLinear: Multidimensional Linear Projection Layer. + Projects tensor inputs across multiple dimensions independently. Args: - input_dims (tuple): Shape of input tensor (excluding batch dimension). - hidden_size (tuple): Target hidden dimensions after transformation. + input_dims (tuple): Shape of the input excluding batch dimension. + hidden_size (tuple): Shape after transformation (must be same length as input_dims). + transform_outer (bool): Whether to transform outer dims first (True) or inner (False). + bias (bool): Whether to use bias in each Linear layer. """ - super(NdLinear, self).__init__() - + super().__init__() if len(input_dims) != len(hidden_size): - raise Exception("Input shape and hidden shape do not match.") - + raise ValueError("Input and hidden shape must have the same rank.") + self.input_dims = input_dims self.hidden_size = hidden_size - self.num_layers = len(input_dims) # Must match since dims are equal + self.num_layers = len(input_dims) self.transform_outer = transform_outer - # Define transformation layers per dimension self.align_layers = nn.ModuleList([ - nn.Linear(input_dims[i], hidden_size[i]) for i in range(self.num_layers) + nn.Linear(input_dims[i], hidden_size[i], bias=bias) + for i in range(self.num_layers) ]) - def forward(self, X): - """ - Forward pass to project input tensor into a new multi-space representation. - - Incrementally transposes, flattens, applies linear layers, and restores shape. + transform_indices = range(self.num_layers) if self.transform_outer else reversed(range(self.num_layers)) - Expected Input Shape: [batch_size, *input_dims] - Output Shape: [batch_size, *hidden_size] - - Args: - X (torch.Tensor): Input tensor with shape [batch_size, *input_dims] + for i in transform_indices: + layer = self.align_layers[i] + transpose_dim = i + 1 # +1 to skip batch - Returns: - torch.Tensor: Output tensor with shape [batch_size, *hidden_size] - """ - num_transforms = self.num_layers # Number of transformations - - # Define iteration order - # transform_indices = range(num_transforms) if transform_outer else reversed(range(num_transforms)) - - for i in range(num_transforms): - if self.transform_outer: - layer = self.align_layers[i] - transpose_dim = i + 1 - else: - layer = self.align_layers[num_transforms - (i+1)] - transpose_dim = num_transforms - i - - # Transpose the selected dimension to the last position - X = torch.transpose(X, transpose_dim, num_transforms).contiguous() - - # Store original shape before transformation - X_size = X.shape[:-1] - - # Flatten everything except the last dimension + X = torch.transpose(X, transpose_dim, self.num_layers).contiguous() + orig_shape = X.shape[:-1] X = X.view(-1, X.shape[-1]) - - # Apply transformation X = layer(X) - - # Reshape back to the original spatial structure (with new embedding dim) - X = X.view(*X_size, X.shape[-1]) - - # Transpose the dimension back to its original position - X = torch.transpose(X, transpose_dim, num_transforms).contiguous() + X = X.view(*orig_shape, X.shape[-1]) + X = torch.transpose(X, transpose_dim, self.num_layers).contiguous() - return X \ No newline at end of file + assert X.shape[1:] == self.hidden_size, f"Expected shape {self.hidden_size}, got {X.shape[1:]}" + return X From 8708c53f2d50a46824af39f991a6cb30cbae1983 Mon Sep 17 00:00:00 2001 From: Eraly Date: Sat, 14 Jun 2025 16:18:11 +0500 Subject: [PATCH 5/5] Update ndlinear_gated.py --- ndlinear_gated.py | 168 ++++++++++++++++++---------------------------- 1 file changed, 67 insertions(+), 101 deletions(-) diff --git a/ndlinear_gated.py b/ndlinear_gated.py index 66ca1b2..ea3732c 100644 --- a/ndlinear_gated.py +++ b/ndlinear_gated.py @@ -1,38 +1,22 @@ import torch import torch.nn as nn -from typing import Literal, Tuple, Optional +from typing import Literal, Optional, Tuple class NdLinearGated(nn.Module): - """ - NdLinearGated: A PyTorch layer for projecting tensors into multi-space representations with gating mechanisms. - - Extends the NdLinear concept by incorporating gating mechanisms that control information flow. - This allows the model to selectively utilize transformations based on input characteristics, - enabling more adaptive and context-dependent multi-space representations. - """ def __init__(self, - input_dims: tuple, - hidden_size: tuple, + input_dims: Tuple[int, ...], + hidden_size: Tuple[int, ...], transform_outer: bool = True, gating_mode: Literal["soft", "hard"] = "soft", gating_hidden_dim: int = 16, - gated_modes: Literal["all", "first", "topk"] = "all") -> None: - """ - Initialize the NdLinearGated layer. - - Args: - input_dims: Shape of input tensor (excluding batch dimension). - hidden_size: Target hidden dimensions after transformation. - transform_outer: If True, transforms from outer to inner dimensions. - gating_mode: Type of gating mechanism - "soft" uses continuous values, "hard" uses binary. - gating_hidden_dim: Hidden dimension size for the gating networks. - gated_modes: Specifies which dimensions to apply gating to. - """ + gated_modes: Literal["all", "first", "topk", "none"] = "all", + topk: int = 2, + dropout_p: float = 0.1) -> None: super(NdLinearGated, self).__init__() if len(input_dims) != len(hidden_size): - raise Exception("Input shape and hidden shape do not match.") + raise ValueError("Input shape and hidden shape must have the same number of dimensions.") self.input_dims = input_dims self.hidden_size = hidden_size @@ -41,11 +25,14 @@ def __init__(self, self.gating_mode = gating_mode self.gating_hidden_dim = gating_hidden_dim self.gated_modes = gated_modes - + self.topk = topk + self.topk_modes = None + self.first_batch_processed = False + self.align_layers = nn.ModuleList([ nn.Linear(input_dims[i], hidden_size[i]) for i in range(self.num_layers) ]) - + self.gate_networks = nn.ModuleList([ nn.Sequential( nn.Linear(input_dims[i], gating_hidden_dim), @@ -54,54 +41,70 @@ def __init__(self, nn.Sigmoid() ) for i in range(self.num_layers) ]) - + self.identity_projections = nn.ModuleList([ nn.Linear(input_dims[i], hidden_size[i]) if input_dims[i] != hidden_size[i] else nn.Identity() for i in range(self.num_layers) ]) - - self.topk_modes = None - self.first_batch_processed = False - - def _compute_topk_modes(self, X: torch.Tensor) -> list: + + self.dropout = nn.Dropout(dropout_p) + + def reset_topk(self, X: torch.Tensor): + """Recomputes top-k dimension indices based on std of mean values.""" mode_stds = [] for i in range(self.num_layers): - transpose_dim = i + 1 if self.transform_outer else self.num_layers - i - X_transposed = torch.transpose(X, transpose_dim, self.num_layers) - X_mean = X_transposed.mean(dim=tuple(range(len(X_transposed.shape) - 1))) + dim = i + 1 if self.transform_outer else self.num_layers - i + Xt = torch.transpose(X, dim, self.num_layers) + X_mean = Xt.mean(dim=tuple(range(len(Xt.shape) - 1))) mode_stds.append(X_mean.std().item()) - - sorted_modes = sorted(range(len(mode_stds)), key=lambda i: mode_stds[i], reverse=True) - return sorted_modes[:2] - + self.topk_modes = sorted(range(self.num_layers), key=lambda i: mode_stds[i], reverse=True)[:self.topk] + self.first_batch_processed = True + + def _transform_step(self, X: torch.Tensor, layer_idx: int, transpose_dim: int, apply_gating: bool) -> torch.Tensor: + X_identity = X # no clone + X = torch.transpose(X, transpose_dim, self.num_layers).contiguous() + X_size = X.shape[:-1] + X_flat = X.view(-1, X.shape[-1]) + X_transformed = self.align_layers[layer_idx](X_flat).view(*X_size, -1) + + if not apply_gating: + return torch.transpose(X_transformed, transpose_dim, self.num_layers).contiguous() + + # Gating value + gate = self.gate_networks[layer_idx](X_flat.mean(dim=0, keepdim=True)) # shape [1, 1] + gate = self.dropout(gate) + + X_identity_transposed = torch.transpose(X_identity, transpose_dim, self.num_layers).contiguous() + X_identity_flat = X_identity_transposed.view(-1, X_identity_transposed.shape[-1]) + + identity_flat = (self.identity_projections[layer_idx](X_identity_flat) + if self.input_dims[layer_idx] != self.hidden_size[layer_idx] + else X_identity_flat) + + if self.gating_mode == "soft": + out_flat = gate * X_transformed.view(-1, X_transformed.shape[-1]) + \ + (1 - gate) * identity_flat + else: + out_flat = torch.where(gate > 0.5, + X_transformed.view(-1, X_transformed.shape[-1]), + identity_flat) + + X_out = out_flat.view(*X_size, -1) + return torch.transpose(X_out, transpose_dim, self.num_layers).contiguous() + def forward(self, X: torch.Tensor) -> torch.Tensor: - """ - Forward pass to project input tensor into a new multi-space representation with gating. - - Applies dimensional transformations with selective gating based on the configured mode. - The gating mechanism allows the network to adaptively choose between transformed - representations and identity mappings. - - Args: - X (torch.Tensor): Input tensor with shape [batch_size, *input_dims] - - Returns: - torch.Tensor: Output tensor with shape [batch_size, *hidden_size] - """ - num_transforms = self.num_layers - if self.gated_modes == "topk" and not self.first_batch_processed: - self.topk_modes = self._compute_topk_modes(X) - self.first_batch_processed = True - - for i in range(num_transforms): + self.reset_topk(X) + + for i in range(self.num_layers): if self.transform_outer: layer_idx = i transpose_dim = i + 1 else: - layer_idx = num_transforms - (i+1) - transpose_dim = num_transforms - i - + layer_idx = self.num_layers - i - 1 + transpose_dim = self.num_layers - i + + # Determine if gating is applied apply_gating = False if self.gated_modes == "all": apply_gating = True @@ -109,50 +112,13 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: apply_gating = True elif self.gated_modes == "topk" and self.topk_modes and layer_idx in self.topk_modes: apply_gating = True - - X_original = X.clone() - - X = torch.transpose(X, transpose_dim, num_transforms).contiguous() - - X_size = X.shape[:-1] - - X_flat = X.view(-1, X.shape[-1]) - - X_transformed = self.align_layers[layer_idx](X_flat) - - if apply_gating: - X_mean = X_flat.mean(dim=0, keepdim=True) - - gate = self.gate_networks[layer_idx](X_mean) - - X_transformed = X_transformed.view(*X_size, X_transformed.shape[-1]) - - X_identity = torch.transpose(X_original, transpose_dim, num_transforms).contiguous() - - X_identity_flat = X_identity.view(-1, X_identity.shape[-1]) - - if X_transformed.shape[-1] != X_identity_flat.shape[-1]: - identity_flat = self.identity_projections[layer_idx](X_identity_flat) - else: - identity_flat = X_identity_flat - - if self.gating_mode == "soft": - X_flat = gate * X_transformed.view(-1, X_transformed.shape[-1]) + (1 - gate) * identity_flat - else: - X_flat = torch.where(gate > 0.5, - X_transformed.view(-1, X_transformed.shape[-1]), - identity_flat) - - X = X_flat.view(*X_size, X_flat.shape[-1]) - else: - X = X_transformed.view(*X_size, X_transformed.shape[-1]) - - X = torch.transpose(X, transpose_dim, num_transforms).contiguous() - + + X = self._transform_step(X, layer_idx, transpose_dim, apply_gating) + return X def __repr__(self) -> str: return (f"{self.__class__.__name__}(input_dims={self.input_dims}, " f"hidden_size={self.hidden_size}, transform_outer={self.transform_outer}, " f"gating_mode={self.gating_mode}, gating_hidden_dim={self.gating_hidden_dim}, " - f"gated_modes={self.gated_modes})") \ No newline at end of file + f"gated_modes={self.gated_modes}, topk={self.topk})")