Maya-Simhi/OLALa
# Federated Learning with Trainable Lattice Quantization

This project implements a Federated Learning (FL) framework that incorporates lattice quantization for compressing model updates. A key feature of this framework is the ability to learn the quantization parameters, such as the lattice generator matrix and scaling factors, as part of the training process. This allows the quantization scheme to adapt to the data and model, potentially improving communication efficiency while maintaining model performance.

## Table of Contents

- Project Overview
- Features
- File Structure
- Setup and Installation
- How to Run
- Configuration Options
- Core Concepts

## Project Overview

In Federated Learning, multiple clients collaboratively train a model without sharing their raw data. They send model updates (weights or gradients) to a central server, which aggregates them to produce a global model. The communication of these updates can be a significant bottleneck.

This project explores using lattice quantization to compress these model updates. Instead of transmitting high-precision floating-point numbers, weights are mapped to points on a multi-dimensional lattice, and only the index of the lattice point is sent. Furthermore, the project introduces a novel approach where the lattice structure itself is learned and optimized during training to minimize quantization error or maximize model accuracy.

## Features

- **Federated Learning:** Implements a standard federated averaging (FedAvg) workflow.
- **Lattice Quantization:** Uses multi-dimensional lattice quantization to compress model weights.
- **Trainable Quantization Matrix:** The generator matrix of the lattice can be trained to create a custom, data-aware quantization scheme.
- **Trainable Scaling Factor:** An MLP-based model (`AlphhMLP`) can be trained to learn an optimal scaling factor (overloading) for the quantization process.
- **Flexible Training Strategies:** Supports various strategies for training the models and the quantizer, including:
  - Standard FL with fixed quantization.
  - Offline training of quantization parameters followed by FL.
  - Online (joint) training of the main model and the quantization parameters.
- **Configurability:** A wide range of parameters can be controlled via command-line arguments, including dataset, model architecture, number of users, and quantization settings.
- **Datasets:** Supports MNIST and CIFAR-10.
- **Models:** Includes Linear, MLP, and CNN architectures.

## File Structure

Here is a brief description of the key files in the project:

- `main.py`: The main entry point for running all federated learning experiments. It parses arguments and orchestrates the different training flows.
- `configurations.py`: Defines all command-line arguments for configuring experiments.
- `federatedUtils.py`: Handles the federated learning logic, including data distribution among clients, model aggregation, and applying quantization during aggregation.
- `utils.py`: Contains general utility functions, including the client training loop (`train_one_epoch`), logging, and experiment initialization.
- `quantizer.py`: Implements the `LatticeQuantization` class, which handles the core logic of mapping vectors to lattice codewords.
- `DNNQuantize.py`: Defines an MLP (`PyTorchMLP`) used to represent the trainable generator matrix and includes the function (`train_model`) to train it.
- `alphModel.py`: Defines the `AlphhMLP` model used for learning the adaptive scaling/overloading factor.
- `models.py`: Contains the definitions for the neural network models (e.g., `CNN2Layer`, `FC2Layer`).
- `dataGetUtils.py`: Provides helper functions for data handling and for flattening/restoring model weights.

## Setup and Installation

1. Clone the repository.
2. Install the required dependencies: `pip install torch torchvision numpy matplotlib tensorboardX torchinfo sympy vit-pytorch`

Note: Ensure you install a version of `torch` that is compatible with your CUDA version if you plan to use a GPU.
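One quick way to confirm that the installed `torch` build can actually see a GPU is a small check like the sketch below. The helper name `describe_torch_device` is hypothetical and not part of this repository; it only relies on the standard `torch.cuda.is_available()` call.

```python
import importlib.util


def describe_torch_device() -> str:
    """Return 'cuda', 'cpu', or 'torch not installed' for this environment.

    Hypothetical helper for illustration; it is not part of the repository.
    """
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch  # imported lazily so a missing package is reported, not raised

    # True only when a CUDA-enabled build of torch finds a usable GPU.
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(describe_torch_device())
```

If this prints `cpu` on a machine with a GPU, the installed `torch` build most likely does not match the local CUDA setup.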
## How to Run

All experiments are run from `main.py`. You can specify the desired configuration using command-line arguments.

## Core Concepts

### Federated Learning Loop

The core FL process in `main.py` and `federatedUtils.py` follows these steps:

1. The server initializes a global model.
2. For each global epoch (communication round):
   1. The server sends the current global model to a set of clients.
   2. Each client trains the model on its local data for a number of iterations (`utils.train_one_epoch`).
   3. (Quantization Step): Each client's updated weights are quantized using the specified lattice scheme (`federatedUtils.aggregate_models`). This involves flattening the weights, dividing them into vectors of `lattice_dim`, and mapping them to the nearest codeword.
   4. The quantized updates are sent back to the server.
   5. The server aggregates the updates (e.g., by averaging) to produce the new global model.
3. The process repeats.
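The quantization and aggregation steps of this loop can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the code in `quantizer.py` or `federatedUtils.py`: the function names `lattice_quantize` and `fedavg_round`, the `scale` parameter, and the identity generator matrix are all hypothetical. Codewords are found by rounding coordinates in the lattice basis, which gives the true nearest codeword only when the generator's columns are orthogonal and is an approximation otherwise.

```python
import numpy as np


def lattice_quantize(flat_weights, generator, scale=1.0):
    """Quantize a flat vector block by block onto the lattice {scale * G @ k}.

    Hypothetical sketch, not the repository's LatticeQuantization class.
    Returns (indices, reconstruction). Rounding in the lattice basis is exact
    nearest-codeword search only for orthogonal generator columns.
    """
    flat_weights = np.asarray(flat_weights, dtype=np.float64)
    basis = scale * np.asarray(generator, dtype=np.float64)
    dim = basis.shape[0]
    pad = (-flat_weights.size) % dim            # zero-pad to a multiple of dim
    blocks = np.concatenate([flat_weights, np.zeros(pad)]).reshape(-1, dim)
    coords = np.linalg.solve(basis, blocks.T)   # real-valued lattice coordinates
    indices = np.rint(coords).astype(np.int64)  # integer indices a client would send
    recon = (basis @ indices).T.reshape(-1)[: flat_weights.size]
    return indices.T, recon


def fedavg_round(client_weights, generator, scale=1.0):
    """Quantize each client's flat weights, then average the reconstructions."""
    recon = [lattice_quantize(w, generator, scale)[1] for w in client_weights]
    return np.mean(recon, axis=0)


if __name__ == "__main__":
    G = np.eye(2)  # assumed generator for lattice_dim = 2 (illustration only)
    clients = [np.array([0.26, -0.74, 1.10]), np.array([0.70, 0.20, 0.90])]
    print(fedavg_round(clients, G, scale=0.5))
```

In this sketch the generator and scale are fixed inputs; in the project they are the quantities that can be learned, so a trained generator matrix would take the place of `G`.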