Bridging Lottery Ticket and Grokking: Is Weight Norm Sufficient to Explain Delayed Generalization?


by Gouki Minegishi, Yusuke Iwasawa, Yutaka Matsuo
arXiv link

Setup

  1. Create a virtual environment using Python 3.7.4. You can use either venv or conda for this.
    python -m venv env   # For venv
    conda create -n myenv python=3.7.4   # For conda
  2. Install the required dependencies.
    pip install -r requirements.txt

Configuration

  • configs/config.py : Modular Addition task
  • configs/config_mnist.py : MNIST classification task

Training Base Model (Dense)

Modular addition

python train.py --config configs/config.py

The training configuration is defined in configs/config.py.
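For readers unfamiliar with the task, the following is a minimal sketch of what a modular addition dataset typically looks like: every pair (a, b) is labeled with (a + b) mod p, and the pairs are split into train and test sets. The function name, the modulus p=97, and the 40% train fraction are illustrative assumptions, not values taken from configs/config.py.

```python
import random


def make_modular_addition_data(p=97, train_fraction=0.4, seed=0):
    """Return (train, test) lists of ((a, b), (a + b) % p) examples.

    p, train_fraction and seed are illustrative defaults, not the
    repository's settings.
    """
    # Enumerate every input pair together with its label.
    examples = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    # Shuffle deterministically so the split is reproducible.
    rng = random.Random(seed)
    rng.shuffle(examples)
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = make_modular_addition_data()
print(len(train), len(test))
```

Grokking experiments on this kind of task usually train on a fraction of all p * p pairs and evaluate on the held-out remainder, which is why the split fraction matters.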

MNIST

python train_mnist.py --config configs/config_mnist.py

The training configuration is defined in configs/config_mnist.py.

Grokking Tickets

Modular addition

python prune.py --config configs/config_pruning.py

The configuration is defined in configs/config_pruning.py.
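This README does not describe how prune.py selects a ticket. As a rough illustration only, the sketch below assumes magnitude pruning in the style of the lottery ticket literature: a mask is computed from the weights of a trained (generalizing) dense model and then applied to the initial weights before retraining. The function names and all weight values are hypothetical, and the actual script may differ.

```python
def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that prunes the smallest-magnitude weights.

    `weights` is a flat list of floats; `sparsity` is the fraction to prune.
    """
    n_prune = int(sparsity * len(weights))
    # Indices sorted by ascending magnitude; the first n_prune are pruned.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


def apply_mask(weights, mask):
    """Zero out the masked entries of a flat weight list."""
    return [w * m for w, m in zip(weights, mask)]


# Hypothetical values, for illustration only.
trained = [0.9, -0.05, 0.3, -1.2, 0.01, 0.4]   # weights after generalization
initial = [0.1, 0.2, -0.3, 0.4, -0.5, 0.6]     # weights at initialization

mask = magnitude_mask(trained, sparsity=0.5)
ticket_init = apply_mask(initial, mask)         # sparse starting point for retraining
print(mask)
```

The key design point in this sketch is that the mask comes from the trained model while the surviving values come from the initialization, so the retrained network starts from the same point as the dense model but with a fixed sparse structure.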

MNIST

python prune_mnist.py --config configs/config_pruning_mnist.py

The configuration is defined in configs/config_pruning_mnist.py.

Results

You can check the experimental results on wandb. The following figure compares the Base Model (Dense) and the Grokking Ticket. The Grokking Ticket almost eliminates delayed generalization.
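One simple way to quantify delayed generalization from logged curves is to compare the first step at which train accuracy and test accuracy each cross a threshold. The sketch below assumes accuracy histories exported as plain lists (for example from wandb); the function names, the 0.95 threshold, and the curves are made up for illustration and are not experimental results.

```python
def first_step_reaching(accuracies, threshold):
    """Index of the first logged step with accuracy >= threshold, or None."""
    for step, acc in enumerate(accuracies):
        if acc >= threshold:
            return step
    return None


def generalization_delay(train_acc, test_acc, threshold=0.95):
    """Logged steps between fitting the train set and generalizing to the test set."""
    t_train = first_step_reaching(train_acc, threshold)
    t_test = first_step_reaching(test_acc, threshold)
    if t_train is None or t_test is None:
        return None  # one of the curves never reaches the threshold
    return t_test - t_train


# Made-up curves, for illustration only.
dense_train = [0.2, 0.7, 0.99, 1.0, 1.0, 1.0, 1.0]
dense_test = [0.1, 0.1, 0.15, 0.2, 0.4, 0.8, 0.97]
ticket_train = [0.2, 0.8, 0.99, 1.0, 1.0, 1.0, 1.0]
ticket_test = [0.1, 0.6, 0.96, 1.0, 1.0, 1.0, 1.0]

dense_delay = generalization_delay(dense_train, dense_test)
ticket_delay = generalization_delay(ticket_train, ticket_test)
print(dense_delay, ticket_delay)
```

A delay near zero means train and test accuracy rise together, while a large delay corresponds to the grokking pattern of late generalization.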

Visualize

The following command visualizes the difference in the acquisition dynamics of representations between the Base Model (left) and the Grokking Ticket (right).

python visualize.py --grok_weight_path <path to grok weight> --weight_folder <path to base weight folder> --ticket_folder <path to ticket folder> --output_folder <path to output folder>