Skip to content
Merged
Binary file added docs/assets/images/black_logo_description.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/assets/images/black_logo_illia.png
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to adjust the image resolution so that it is at the edge of the content?

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/assets/images/logo_black_small.png
Binary file not shown.
Binary file added docs/assets/images/white_logo_description.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/assets/images/white_logo_illia.png
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to adjust the image resolution so that it is at the edge of the content?

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 28 additions & 42 deletions docs/examples/Computer Vision/MNIST Bayesian CNN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "b3342740",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.2.1\n",
"0.2.4\n",
"torch\n"
]
}
Expand All @@ -39,7 +39,7 @@
"os.environ[\"ILLIA_BACKEND\"] = \"torch\"\n",
"\n",
"import illia\n",
"from illia.nn import Conv2d, Linear\n",
"from illia.nn import Conv2d, Linear, ReLU, MaxPool2d, Dropout\n",
"\n",
"import numpy as np\n",
"import torch\n",
Expand All @@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 13,
"id": "1d0963f5",
"metadata": {},
"outputs": [
Expand All @@ -92,7 +92,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 14,
"id": "83ba2071",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -125,7 +125,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "c96c969c",
"metadata": {},
"outputs": [
Expand All @@ -150,23 +150,27 @@
" self.fc1 = Linear(64 * 7 * 7, 128)\n",
" self.fc2 = Linear(128, 10)\n",
"\n",
" # Activation and pooling\n",
" self.relu = ReLU()\n",
" self.pool = MaxPool2d(2)\n",
"\n",
" # Dropout for regularization\n",
" self.dropout = nn.Dropout(0.25)\n",
" self.dropout = Dropout(0.25)\n",
"\n",
" def forward(self, x):\n",
" # First conv + ReLU + MaxPool\n",
" x = F.relu(self.conv1(x))\n",
" x = F.max_pool2d(x, 2)\n",
" x = self.relu(self.conv1(x))\n",
" x = self.pool(x)\n",
"\n",
" # Second conv + ReLU + MaxPool\n",
" x = F.relu(self.conv2(x))\n",
" x = F.max_pool2d(x, 2)\n",
" x = self.relu(self.conv2(x))\n",
" x = self.pool(x)\n",
"\n",
" # Flatten before fully connected layers\n",
" x = x.view(x.size(0), -1)\n",
"\n",
" # Fully connected + dropout\n",
" x = F.relu(self.fc1(x))\n",
" x = self.relu(self.fc1(x))\n",
" x = self.dropout(x)\n",
" x = self.fc2(x)\n",
"\n",
Expand All @@ -189,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 16,
"id": "4b0b9fed",
"metadata": {},
"outputs": [],
Expand All @@ -212,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"id": "b80d4785",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -285,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "f7021581",
"metadata": {},
"outputs": [
Expand All @@ -297,27 +301,9 @@
"\n",
"Epoch 1/2\n",
"--------------------------------------------------\n",
"Batch 0/938, Loss: 3.542201, Accuracy: 9.38%\n",
"Batch 200/938, Loss: 0.405245, Accuracy: 83.38%\n",
"Batch 400/938, Loss: 0.096597, Accuracy: 89.36%\n",
"Batch 600/938, Loss: 0.062778, Accuracy: 91.73%\n",
"Batch 800/938, Loss: 0.162707, Accuracy: 93.07%\n",
"\n",
"Test Loss: 0.043868, Test Accuracy: 98.52%\n",
"\n",
"Epoch 1 - Train Loss: 0.206314, Train Acc: 93.69%, Test Acc: 98.52%\n",
"Epoch 2/2\n",
"--------------------------------------------------\n",
"Batch 0/938, Loss: 0.052432, Accuracy: 98.44%\n",
"Batch 200/938, Loss: 0.013136, Accuracy: 98.00%\n",
"Batch 400/938, Loss: 0.008689, Accuracy: 98.06%\n",
"Batch 600/938, Loss: 0.035820, Accuracy: 98.03%\n",
"Batch 800/938, Loss: 0.024955, Accuracy: 98.05%\n",
"\n",
"Test Loss: 0.033828, Test Accuracy: 98.91%\n",
"\n",
"Epoch 2 - Train Loss: 0.060348, Train Acc: 98.10%, Test Acc: 98.91%\n",
"Training completed!\n"
"Batch 0/938, Loss: 5.569467, Accuracy: 12.50%\n",
"Batch 200/938, Loss: 0.297848, Accuracy: 82.00%\n",
"Batch 400/938, Loss: 0.080678, Accuracy: 88.43%\n"
]
}
],
Expand Down Expand Up @@ -365,7 +351,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "45bc2223",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -487,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "009a2b4d",
"metadata": {},
"outputs": [],
Expand All @@ -513,7 +499,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "e21d3a10",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -553,7 +539,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "21e1f116",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -596,7 +582,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "93e284d5",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -642,7 +628,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "f5c224c5",
"metadata": {},
"outputs": [
Expand Down
13 changes: 11 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<p align="center">
<img src="./assets/images/white_logo_illia.png" class="logo-white" height="200" width="200"/>
<img src="./assets/images/black_logo_illia.png" class="logo-black" height="200" width="200"/>
<img src="./assets/images/white_logo_description.png" class="logo-white" height="400" width="400"/>
<img src="./assets/images/black_logo_description.png" class="logo-black" height="400" width="400"/>
<br />
</p>

Expand Down Expand Up @@ -79,6 +79,15 @@ print(f"Output std: {outputs.std()}")
print(f"Output var: {outputs.var()}")
```

## Non-Parametric Definitions

**illia** provides non-parametric layers (pooling, activation, normalization, regularization, and utility layers) importable from `illia.nn` under PyTorch-style names. Note, however, that the initialization parameters and API usage are backend-specific: they are not standardized across PyTorch, TensorFlow, and JAX.

Key characteristics:
- **PyTorch Naming**: Layer names follow PyTorch conventions (e.g., MaxPool2d, ReLU, BatchNorm2d)
- **Backend-Specific Parameters**: Initialization arguments vary by backend
- **Non-Standard API**: JAX uses functional definitions, while PyTorch and TensorFlow use class/object-oriented layers

## Contributing

We welcome contributions from the community! Whether you're fixing bugs, adding
Expand Down
26 changes: 8 additions & 18 deletions illia/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,19 @@
from illia import BackendManager


# Obtain the library to import
def __getattr__(name: str) -> None:
def __getattr__(name: str) -> Any:
"""
Dynamically import a class from backend distributions.
Dynamically import a class from backend.

Args:
name: Name of the class to be imported.

Returns:
None.
The requested layer/module class.
"""
backend = BackendManager.get_backend()
module = BackendManager.get_backend_module(backend, "nn")
layer_class = BackendManager.get_class(backend, name, "nn", module)

# Obtain parameters for nn
module_type: str = "nn"
backend: str = BackendManager.get_backend()
module_path: Any | dict[str, Any] = BackendManager.get_backend_module(
backend, module_type
)

# Set class to global namespace
globals()[name] = BackendManager.get_class(
backend_name=backend,
class_name=name,
module_type=module_type,
module_path=module_path,
)
globals()[name] = layer_class
return layer_class
18 changes: 17 additions & 1 deletion illia/nn/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,35 @@
"""

# Own modules
from illia.nn.jax.activation import GELU, ReLU, Sigmoid, Tanh
from illia.nn.jax.base import BayesianModule
from illia.nn.jax.conv1d import Conv1d
from illia.nn.jax.conv2d import Conv2d
from illia.nn.jax.embedding import Embedding
from illia.nn.jax.linear import Linear
from illia.nn.jax.lstm import LSTM
from illia.nn.jax.normalization import BatchNorm1d, BatchNorm2d, LayerNorm
from illia.nn.jax.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from illia.nn.jax.regularization import Dropout


__all__: list[str] = [
"AvgPool1d",
"AvgPool2d",
"BatchNorm1d",
"BatchNorm2d",
"BayesianModule",
"Conv1d",
"Conv2d",
"Dropout",
"Embedding",
"Linear",
"GELU",
"LSTM",
"LayerNorm",
"Linear",
"MaxPool1d",
"MaxPool2d",
"ReLU",
"Sigmoid",
"Tanh",
]
35 changes: 35 additions & 0 deletions illia/nn/jax/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""JAX activation layer wrappers."""

# Standard libraries
from functools import partial

# 3pps
from flax import nnx


class ReLU:
    """Wrapper for JAX relu.

    Instantiating this class returns the functional form directly: any
    constructor arguments are pre-bound onto ``nnx.relu`` and the input
    array is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # No instance is ever created; construction yields a callable.
        activation = partial(nnx.relu, *args, **kwargs)
        return activation


class Sigmoid:
    """Wrapper for JAX sigmoid.

    Construction returns ``nnx.sigmoid`` with any constructor arguments
    pre-bound; the input array is passed when the result is called.
    """

    def __new__(cls, *args, **kwargs):
        # Hand back the partially-applied function instead of an instance.
        bound = partial(nnx.sigmoid, *args, **kwargs)
        return bound


class Tanh:
    """Wrapper for JAX tanh.

    ``Tanh()`` evaluates to a callable: ``nnx.tanh`` with any constructor
    arguments already bound via :func:`functools.partial`.
    """

    def __new__(cls, *args, **kwargs):
        # Return the callable itself; no Tanh instance is ever built.
        fn = partial(nnx.tanh, *args, **kwargs)
        return fn


class GELU:
    """Wrapper for JAX gelu.

    Construction yields ``nnx.gelu`` with constructor arguments (e.g. the
    keyword ``approximate``) pre-bound; the input is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # Instances are never created — __new__ returns the bound function.
        gelu_fn = partial(nnx.gelu, *args, **kwargs)
        return gelu_fn
25 changes: 25 additions & 0 deletions illia/nn/jax/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""JAX normalization layer wrappers."""

# 3pps
from flax import nnx


class BatchNorm1d:
    """Wrapper for JAX BatchNorm for 1D inputs.

    Instantiation hands back an ``nnx.BatchNorm`` module configured for
    ``num_features`` channels; all extra arguments (e.g. ``rngs``) are
    forwarded untouched.
    """

    def __new__(cls, num_features, *args, **kwargs):
        # Construction returns the flax module itself, not a wrapper object.
        norm_layer = nnx.BatchNorm(num_features, *args, **kwargs)
        return norm_layer


class BatchNorm2d:
    """Wrapper for JAX BatchNorm for 2D inputs.

    Returns an ``nnx.BatchNorm`` configured for ``num_features`` channels;
    remaining arguments (e.g. ``rngs``) pass through unchanged.
    """

    def __new__(cls, num_features, *args, **kwargs):
        # No BatchNorm2d instance exists — the flax module is returned directly.
        module = nnx.BatchNorm(num_features, *args, **kwargs)
        return module


class LayerNorm:
    """Wrapper for JAX LayerNorm.

    ``LayerNorm(...)`` evaluates to an ``nnx.LayerNorm`` module built with
    the given arguments; this class never produces instances of itself.
    """

    def __new__(cls, *args, **kwargs):
        # Delegate construction entirely to flax.
        layer = nnx.LayerNorm(*args, **kwargs)
        return layer
35 changes: 35 additions & 0 deletions illia/nn/jax/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""JAX pooling layer wrappers."""

# Standard libraries
from functools import partial

# 3pps
import flax.nnx as nnx


class MaxPool1d:
    """Wrapper for JAX max_pool with 1D window.

    Construction returns a callable that applies ``nnx.max_pool`` to its
    input; pooling configuration is bound at construction time and the
    input array is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # Bug fix: `partial(nnx.max_pool, *args)` binds positional args to
        # the leading `inputs` parameter of nnx.max_pool, so PyTorch-style
        # `MaxPool1d(2)` would silently misconfigure the pool. Map leading
        # positional args onto the config parameters instead.
        for name, value in zip(("window_shape", "strides", "padding"), args):
            if name in kwargs:
                raise TypeError(f"MaxPool1d() got multiple values for '{name}'")
            kwargs[name] = value
        # Accept PyTorch-style scalar sizes and expand them to a 1D window.
        for name in ("window_shape", "strides"):
            if isinstance(kwargs.get(name), int):
                kwargs[name] = (kwargs[name],)
        return partial(nnx.max_pool, **kwargs)


class MaxPool2d:
    """Wrapper for JAX max_pool with 2D window.

    Construction returns a callable that applies ``nnx.max_pool`` to its
    input; pooling configuration is bound at construction time and the
    input array is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # Bug fix: `partial(nnx.max_pool, *args)` binds positional args to
        # the leading `inputs` parameter of nnx.max_pool, so PyTorch-style
        # `MaxPool2d(2)` would silently misconfigure the pool. Map leading
        # positional args onto the config parameters instead.
        for name, value in zip(("window_shape", "strides", "padding"), args):
            if name in kwargs:
                raise TypeError(f"MaxPool2d() got multiple values for '{name}'")
            kwargs[name] = value
        # Accept PyTorch-style scalar sizes and expand them to a 2D window.
        for name in ("window_shape", "strides"):
            if isinstance(kwargs.get(name), int):
                kwargs[name] = (kwargs[name],) * 2
        return partial(nnx.max_pool, **kwargs)


class AvgPool1d:
    """Wrapper for JAX avg_pool with 1D window.

    Construction returns a callable that applies ``nnx.avg_pool`` to its
    input; pooling configuration is bound at construction time and the
    input array is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # Bug fix: `partial(nnx.avg_pool, *args)` binds positional args to
        # the leading `inputs` parameter of nnx.avg_pool, so PyTorch-style
        # `AvgPool1d(2)` would silently misconfigure the pool. Map leading
        # positional args onto the config parameters instead.
        for name, value in zip(("window_shape", "strides", "padding"), args):
            if name in kwargs:
                raise TypeError(f"AvgPool1d() got multiple values for '{name}'")
            kwargs[name] = value
        # Accept PyTorch-style scalar sizes and expand them to a 1D window.
        for name in ("window_shape", "strides"):
            if isinstance(kwargs.get(name), int):
                kwargs[name] = (kwargs[name],)
        return partial(nnx.avg_pool, **kwargs)


class AvgPool2d:
    """Wrapper for JAX avg_pool with 2D window.

    Construction returns a callable that applies ``nnx.avg_pool`` to its
    input; pooling configuration is bound at construction time and the
    input array is supplied at call time.
    """

    def __new__(cls, *args, **kwargs):
        # Bug fix: `partial(nnx.avg_pool, *args)` binds positional args to
        # the leading `inputs` parameter of nnx.avg_pool, so PyTorch-style
        # `AvgPool2d(2)` would silently misconfigure the pool. Map leading
        # positional args onto the config parameters instead.
        for name, value in zip(("window_shape", "strides", "padding"), args):
            if name in kwargs:
                raise TypeError(f"AvgPool2d() got multiple values for '{name}'")
            kwargs[name] = value
        # Accept PyTorch-style scalar sizes and expand them to a 2D window.
        for name in ("window_shape", "strides"):
            if isinstance(kwargs.get(name), int):
                kwargs[name] = (kwargs[name],) * 2
        return partial(nnx.avg_pool, **kwargs)
Loading
Loading