Fine-Tune LoRA

A Python project for fine-tuning language models using LoRA (Low-Rank Adaptation) on GSM8K, a dataset of grade school math word problems. It demonstrates how to fine-tune a language model for mathematical reasoning with minimal computational resources.

🚀 Features

  • LoRA Fine-tuning: Parameter-efficient fine-tuning with LoRA adapters
  • Mathematical Reasoning: Trained on GSM8K dataset for math problem solving
  • TinyLlama Model: Uses TinyLlama-1.1B-Chat-v1.0 as the base model
  • Modular Design: Clean, organized code structure
  • Easy Inference: Simple script for model inference

📁 Project Structure

├── config/
│   └── lora_config.py      # LoRA configuration settings
├── data/
│   └── prepare_dataset.py  # Dataset preparation utilities
├── model/
│   └── load_model.py       # Model and tokenizer loading
├── outputs/                # Training outputs and checkpoints
├── train.py               # Main training script
├── infer.py               # Inference script
└── README.md              # This file

🛠️ Installation

  1. Clone the repository:

    git clone https://github.com/chanupadeshan/LoRA-Fine-Tune.git
    cd LoRA-Fine-Tune
  2. Optional: Create and activate a virtual environment:

    python -m venv venv
    source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install dependencies:

    pip install transformers datasets peft torch accelerate

🎯 Usage

Training

To start training the model:

python train.py

Training Configuration:

  • Base Model: TinyLlama-1.1B-Chat-v1.0
  • Dataset: GSM8K (mathematical reasoning)
  • LoRA Rank: 8
  • LoRA Alpha: 16
  • Target Modules: q_proj, v_proj
  • Batch Size: 1 (with gradient accumulation steps: 2)
  • Learning Rate: 0.0002
  • Epochs: 1

Inference

To run inference with the trained model:

python infer.py

The inference script includes an example question:

"If you have 2 apples and you buy 3 more apples, how many apples do you have in total?"

🔧 Configuration

LoRA Settings

The LoRA configuration is defined in config/lora_config.py:

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=8,                                  # LoRA rank
    lora_alpha=16,                        # LoRA scaling parameter
    target_modules=["q_proj", "v_proj"],  # Target attention modules (query and value projections)
    lora_dropout=0.05,                    # Dropout rate applied to LoRA layers
    bias="none",                          # Do not train bias terms
    task_type=TaskType.CAUSAL_LM,         # Causal language modeling task
)
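
The adapters are then attached with peft's get_peft_model. A minimal sketch, assuming base_model is the TinyLlama model returned by model/load_model.py and lora_config is the object above:

from peft import get_peft_model

# base_model: the TinyLlama-1.1B-Chat model loaded in model/load_model.py (assumption)
# lora_config: the LoraConfig shown above
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()   # reports trainable vs. total parameter counts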

Training Parameters

Key training parameters in train.py:

  • per_device_train_batch_size=1: Batch size per device
  • gradient_accumulation_steps=2: Effective batch size = 1 × 2 = 2
  • num_train_epochs=1: Number of training epochs
  • learning_rate=0.0002: Learning rate for optimization
  • output_dir="./outputs": Output directory for checkpoints
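
A rough sketch of how these parameters might be wired into the Hugging Face Trainer, assuming model, tokenizer, and train_dataset come from the earlier steps (the data collator choice is an assumption, not read from train.py):

from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,   # effective batch size = 1 x 2 = 2
    num_train_epochs=1,
    learning_rate=2e-4,
    save_strategy="epoch",           # checkpoints saved after each epoch
)

trainer = Trainer(
    model=model,                     # LoRA-wrapped model from get_peft_model
    args=training_args,
    train_dataset=train_dataset,     # tokenized GSM8K training split
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
trainer.save_model("./outputs/final_model")   # adapter directory described under Output Files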

📊 Dataset

This project uses the GSM8K dataset, which contains:

  • Mathematical reasoning problems
  • Step-by-step solutions
  • Training set: ~7,473 examples
  • Test set: ~1,319 examples

The dataset is automatically downloaded when running the training script.
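
For reference, the split can also be loaded directly with the Hugging Face datasets library (a self-contained sketch using the standard "main" configuration on the Hub):

from datasets import load_dataset

# Downloads GSM8K from the Hugging Face Hub on first use
dataset = load_dataset("gsm8k", "main")

print(dataset)                           # train: 7,473 rows, test: 1,319 rows
print(dataset["train"][0]["question"])   # a grade school word problem
print(dataset["train"][0]["answer"])     # step-by-step solution ending in "#### <number>"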

🎯 Model Architecture

  • Base Model: TinyLlama-1.1B-Chat-v1.0
  • Fine-tuning Method: LoRA (Low-Rank Adaptation)
  • Target Modules: Query and Value projections in attention layers
  • Parameter Efficiency: Only ~0.1% of parameters are trained
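
A minimal sketch of how model/load_model.py might load the base model and tokenizer (illustrative; the pad-token handling is an assumption, since LLaMA-family tokenizers ship without a dedicated pad token):

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

def load_model_and_tokenizer(model_id: str = MODEL_ID):
    """Load the TinyLlama base model and its tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return model, tokenizer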

📈 Training Process

  1. Data Preprocessing: Questions and answers are formatted as prompts
  2. Tokenization: Text is converted to token IDs with max length of 256
  3. LoRA Application: LoRA adapters are applied to the base model
  4. Training: Model is fine-tuned using the HuggingFace Trainer
  5. Checkpointing: Model checkpoints are saved after each epoch
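
An illustrative version of steps 1-2 (the exact prompt template lives in data/prepare_dataset.py and may differ; only the max length of 256 is taken from this README):

def format_example(example):
    # Step 1: turn a GSM8K record into a single training prompt
    return {"text": f"Question: {example['question']}\nAnswer: {example['answer']}"}

def tokenize_example(example, tokenizer):
    # Step 2: convert text to token IDs, truncating/padding to 256 tokens
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=256)

# Usage with the dataset loaded in the previous section:
# train_dataset = dataset["train"].map(format_example)
# train_dataset = train_dataset.map(lambda ex: tokenize_example(ex, tokenizer),
#                                   remove_columns=["question", "answer", "text"])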

🔍 Inference

The trained model can answer mathematical questions in the format:

Question: [Your math question]
Answer: [Model's response]

📁 Output Files

After training, the following files are generated in the outputs/ directory:

  • final_model/: Complete trained model with LoRA adapters
  • checkpoint-*/: Training checkpoints
  • trainer_state.json: Training state and metrics
  • training_args.bin: Training arguments
  • Various optimizer and scheduler states
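
For example, the adapter in final_model/ can be reloaded on top of the base model and, optionally, merged into a standalone checkpoint (a sketch; the merged_model output path is illustrative):

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = PeftModel.from_pretrained(base, "./outputs/final_model")

# Optional: fold the LoRA weights into the base model for adapter-free deployment
merged = model.merge_and_unload()
merged.save_pretrained("./outputs/merged_model")   # illustrative output path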

🚨 Important Notes

  • Memory Requirements: The full TinyLlama-1.1B base model must fit in GPU memory during training; only the LoRA adapter weights receive gradient updates
  • Training Time: Depends on your hardware; typically 1-2 hours for one epoch on a single modern GPU
  • Model Size: The LoRA adapters are very small (~1-2 MB) compared to the full model
  • Git Ignore: Large model files are excluded from version control

🤝 Contributing

  1. Fork the repository
  2. Create a feature branch
  3. Make your changes
  4. Add tests if applicable
  5. Submit a pull request
