by Gouki Minegishi, Yusuke Iwasawa, Yutaka Matsuo
arXiv link
- Create a virtual environment using Python 3.7.4. You can use either pyvenv or conda for this.
    python -m venv env  # For pyvenv
    conda create -n myenv python=3.7.4  # For conda
- Install the required dependencies.
    pip install -r requirements.txt
- configs/config.py: Modular Addition task
- configs/config_mnist.py: MNIST Classification task
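For reference, the modular addition task asks a model to predict (a + b) mod p from the pair (a, b). The snippet below is a minimal sketch of building such a dataset; the modulus p = 97, the 30% training split, and the function name make_modular_addition are illustrative assumptions, not values read from configs/config.py.

```python
import random

# Hypothetical sketch: p=97 and train_frac=0.3 are illustrative assumptions,
# not the settings used in configs/config.py.
def make_modular_addition(p=97, train_frac=0.3, seed=0):
    """Return (train, test) lists of ((a, b), (a + b) % p) examples."""
    # Enumerate every input pair exactly once.
    pairs = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    # Shuffle deterministically so the split is reproducible.
    rng = random.Random(seed)
    rng.shuffle(pairs)
    n_train = int(train_frac * len(pairs))
    return pairs[:n_train], pairs[n_train:]

train, test = make_modular_addition()
print(len(train), len(test))  # 2822 6587
```

A small training fraction like this is the regime in which delayed generalization is typically observed on this task.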
python train.py --config configs/config.py
The training configuration is written in configs/config.py.
python train_mnist.py --config configs/config_mnist.py
The training configuration is written in configs/config_mnist.py.
python prune.py --config configs/config_pruning.py
The training configuration is written in configs/config_pruning.py.
python prune_mnist.py --config configs/config_pruning_mnist.py
The training configuration is written in configs/config_pruning_mnist.py.
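Assuming the pruning step follows the standard lottery-ticket recipe (prune the trained weights by magnitude, then rewind the surviving weights to their initial values), a minimal sketch looks like this. It is not the code in prune.py: weights are plain dicts of lists, and the layer names, values, sparsity level, and the helpers magnitude_mask and apply_mask are all hypothetical.

```python
# Hypothetical sketch of lottery-ticket-style global magnitude pruning.
# Layer names, weight values and sparsity are illustrative assumptions.

def magnitude_mask(trained, sparsity):
    """Return a 0/1 mask that prunes the `sparsity` fraction of weights
    with the smallest magnitude, ranked globally across all layers."""
    ranked = sorted(
        ((abs(w), name, i) for name, ws in trained.items() for i, w in enumerate(ws)),
        key=lambda t: t[0],
    )
    n_prune = int(sparsity * len(ranked))
    mask = {name: [1] * len(ws) for name, ws in trained.items()}
    for _, name, i in ranked[:n_prune]:
        mask[name][i] = 0
    return mask

def apply_mask(init, mask):
    """Rewind surviving weights to their initial values; zero out pruned ones."""
    return {
        name: [w if m else 0.0 for w, m in zip(init[name], mask[name])]
        for name in init
    }

init = {"layer1": [0.1, -0.2, 0.3], "layer2": [0.05, 0.4]}      # weights at initialization
trained = {"layer1": [0.9, -0.01, 0.5], "layer2": [-0.02, 1.2]}  # weights after training
mask = magnitude_mask(trained, sparsity=0.4)
ticket = apply_mask(init, mask)
print(ticket)  # {'layer1': [0.1, 0.0, 0.3], 'layer2': [0.0, 0.4]}
```

Because the mask is computed from the trained weights but applied to the initial ones, retraining starts from the original initialization restricted to the surviving connections.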
Experimental results are logged to wandb. The following figure compares the Base Model (Dense) and the Grokking Ticket: the Grokking Ticket almost eliminates delayed generalization, i.e. the long gap between fitting the training set and generalizing to the test set.
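One simple way to quantify delayed generalization from logged accuracy curves is the number of steps between training accuracy and test accuracy first crossing a threshold. The sketch below assumes that definition and a 0.95 threshold; the helper names are hypothetical, and the accuracy curves are toy numbers invented for illustration, not results from the paper or from wandb.

```python
# Hypothetical metric: steps between train and test accuracy crossing a threshold.

def first_step_reaching(steps, accs, threshold):
    """Return the first logged step whose accuracy is >= threshold, else None."""
    for step, acc in zip(steps, accs):
        if acc >= threshold:
            return step
    return None

def grokking_delay(steps, train_acc, test_acc, threshold=0.95):
    """Steps from train accuracy crossing `threshold` to test accuracy crossing it."""
    t_train = first_step_reaching(steps, train_acc, threshold)
    t_test = first_step_reaching(steps, test_acc, threshold)
    if t_train is None or t_test is None:
        return None  # at least one curve never reaches the threshold
    return t_test - t_train

# Toy curves, invented purely for illustration.
steps = [0, 1000, 2000, 3000, 4000]
base_train = [0.30, 0.99, 1.00, 1.00, 1.00]
base_test = [0.01, 0.05, 0.10, 0.60, 0.97]
ticket_train = [0.30, 0.99, 1.00, 1.00, 1.00]
ticket_test = [0.01, 0.96, 0.99, 1.00, 1.00]

print(grokking_delay(steps, base_train, base_test),
      grokking_delay(steps, ticket_train, ticket_test))  # 3000 0
```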
The following command visualizes how representations are acquired during training in the Base Model (left) versus the Grokking Ticket (right).
python visualize.py --grok_weight_path <path to grok weight> --weight_folder <path to base weight folder> --ticket_folder <path to ticket folder> --output_folder <path to output folder>


