A PyTorch implementation of Factorization Machines (FM) with a custom autograd function for efficient training.
Factorization Machines (FM) are a class of models that capture feature interactions using factorized parameters. This implementation offers:
- Efficient computation of second-order feature interactions
- Custom PyTorch autograd function for optimized backward pass
- Simple API similar to standard PyTorch modules
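For an input $\boldsymbol{x} \in \mathbb{R}^n$, the model is defined as:

$$
\hat{y}(\boldsymbol{x}) = b + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \boldsymbol{v}_i, \boldsymbol{v}_j \rangle \, x_i x_j
$$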
Where:

- $b$ is the bias term
- $w_i$ are the weights of the linear terms
- $\boldsymbol{v}_i$ are $k$-dimensional factorized vectors
- $\langle \cdot, \cdot \rangle$ denotes the dot product
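To make the equation concrete, here is a naive reference evaluation in plain Python. It is only an illustrative sketch: `fm_predict_naive` is a hypothetical helper that is not part of `fm_torch`, and because it loops over every pair $i < j$ it costs $O(kn^2)$.

```python
# Hypothetical reference helper (not part of fm_torch): evaluates the FM
# equation term by term with plain Python lists, for illustration only.

def fm_predict_naive(x, b, w, V):
    """Evaluate y_hat(x) = b + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j.

    x: list of n feature values
    b: bias term
    w: list of n linear weights
    V: list of n factor vectors, each of length k
    """
    n = len(x)
    # Linear part: sum_i w_i x_i
    linear = sum(w[i] * x[i] for i in range(n))
    # Second-order part: explicit double loop over all pairs i < j
    pairwise = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(vi * vj for vi, vj in zip(V[i], V[j]))
            pairwise += dot * x[i] * x[j]
    return b + linear + pairwise


x = [1.0, 2.0, 0.0]
b = 0.5
w = [0.1, -0.2, 0.3]
V = [[1.0, 0.0], [0.5, 1.0], [2.0, 2.0]]
print(fm_predict_naive(x, b, w, V))
```

Only features with $x_i \neq 0$ contribute to the pairwise sum, which is why FMs work well on sparse inputs.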
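This implementation includes a custom `FactorizationMachineFunction` that computes both the forward pass and the gradients for backpropagation. The second-order interaction term is calculated using the reformulation:

$$
\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \boldsymbol{v}_i, \boldsymbol{v}_j \rangle \, x_i x_j = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{n} v_{i,f} \, x_i \right)^2 - \sum_{i=1}^{n} v_{i,f}^2 \, x_i^2 \right]
$$

where $v_{i,f}$ is the $f$-th component of $\boldsymbol{v}_i$.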
This reduces the computational complexity of the interaction term from $O(kn^2)$ to $O(kn)$ (Rendle, 2010).
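The identity above, together with the kind of hand-written backward pass a custom autograd function needs, can be sketched as follows. This is a minimal sketch under stated assumptions, not the repository's `FactorizationMachineFunction`: `FMInteraction` and `pairwise_naive` are hypothetical names, the inputs are assumed to be a `(batch, n)` tensor `x` and an `(n, k)` factor matrix `V`, and only the second-order term is covered. With $s_f = \sum_j v_{j,f} x_j$, the backward uses $\partial / \partial x_i = \sum_f (s_f v_{i,f} - v_{i,f}^2 x_i)$ and $\partial / \partial v_{i,f} = x_i s_f - v_{i,f} x_i^2$.

```python
import torch


class FMInteraction(torch.autograd.Function):
    """Hypothetical sketch (not the fm_torch implementation) of the O(kn)
    second-order FM term with a manually derived backward pass.

    x: (batch, n) inputs, V: (n, k) factor matrix -> returns a (batch,) tensor.
    """

    @staticmethod
    def forward(ctx, x, V):
        s = x @ V                    # (batch, k): sum_i v_{i,f} x_i
        s_sq = (x * x) @ (V * V)     # (batch, k): sum_i v_{i,f}^2 x_i^2
        ctx.save_for_backward(x, V, s)
        return 0.5 * (s * s - s_sq).sum(dim=1)

    @staticmethod
    def backward(ctx, grad_out):
        x, V, s = ctx.saved_tensors
        g = grad_out.unsqueeze(1)    # (batch, 1), broadcast over features
        # d out / d x_i = sum_f (s_f v_{i,f} - v_{i,f}^2 x_i)
        grad_x = g * (s @ V.t() - x * (V * V).sum(dim=1))
        # d out / d v_{i,f} = x_i s_f - v_{i,f} x_i^2, accumulated over the batch
        grad_V = x.t() @ (g * s) - V * ((x * x).t() @ grad_out).unsqueeze(1)
        return grad_x, grad_V


def pairwise_naive(x, V):
    """O(k n^2) reference: explicit sum over i < j of <v_i, v_j> x_i x_j."""
    gram = V @ V.t()                                   # (n, n) matrix of <v_i, v_j>
    pair = x.unsqueeze(2) * x.unsqueeze(1) * gram      # (batch, n, n)
    return torch.triu(pair, diagonal=1).sum(dim=(1, 2))


torch.manual_seed(0)
# float64 inputs so gradcheck's finite differences are accurate
x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
V = torch.randn(6, 3, dtype=torch.float64, requires_grad=True)

fast = FMInteraction.apply(x, V)
print(torch.allclose(fast, pairwise_naive(x, V)))
print(torch.autograd.gradcheck(FMInteraction.apply, (x, V)))
```

`torch.autograd.gradcheck` compares the hand-written gradients against finite differences, which is why the example uses `float64` tensors.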
Reference:

- S. Rendle, "Factorization Machines," in *2010 IEEE International Conference on Data Mining*, IEEE, 2010, pp. 995–1000.
Example usage:

```python
import torch
from fm_torch import SecondOrderFactorizationMachine

# Initialize model
dim_input = 10  # Input feature dimension
dim_factors = 8  # Latent factor dimension
model = SecondOrderFactorizationMachine(dim_input, dim_factors)

# Forward pass
batch_size = 32
x = torch.randn(batch_size, dim_input)
y_pred = model(x)  # Shape: (batch_size,)

# Training
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Training loop
y_true = torch.randn(batch_size)  # Replace with actual labels
epochs = 100
for epoch in range(epochs):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y_true)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')
```

To set up a development environment:

```bash
# Clone the repository
git clone https://github.com/yourusername/fm-torch.git
cd fm-torch

# If you are using mise, trust and install the dependencies
mise trust
mise install

# Set up development environment
uv sync
```

This project uses:
- `mise` for development environment management
- `task` for running common development tasks
- `uv` for Python package management
- `ruff` for linting and formatting
- `mypy` for type checking
Common development tasks:

```bash
# Format code
task format

# Check code style
task check

# Fix autofixable issues
task fix

# Prepare and commit
task commit:prepare:src
```

This project is licensed under the MIT License. See the LICENSE file for details.