Create a Python 3.10 environment, then use the following command to install the required libraries:

```shell
pip install -r requirements.txt
```
This repository supports fine-tuning an LLM to generate SQL queries from natural language using LoRA and TRL's SFTTrainer.
After fine-tuning, you can test the model's performance on natural language to SQL generation.
The `test.py` script loads:

- The base model defined in `config/config.py`
- The fine-tuned LoRA adapter from `data/sql-sft-lora/checkpoint-1250`
To test the model with a custom prompt:

```shell
python test.py
```

The script uses the following example:

Prompt: "Show all users with gender is female."

It returns the generated SQL query based on your fine-tuned model.
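The loading and generation flow likely resembles the sketch below. The helper names and the sampling values are assumptions for illustration, not the actual contents of `test.py`:

```python
def build_generation_kwargs(max_new_tokens: int = 64) -> dict:
    # Assumed decoding settings: top_p sampling combined with beam search
    return {
        "do_sample": True,
        "top_p": 0.9,        # illustrative value
        "num_beams": 2,      # illustrative value
        "max_new_tokens": max_new_tokens,
    }

def load_finetuned(model_id: str, adapter_dir: str):
    # Imports deferred so the helper above stays dependency-free
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    base = AutoModelForCausalLM.from_pretrained(model_id)
    model = PeftModel.from_pretrained(base, adapter_dir)  # attach the LoRA adapter
    return model, AutoTokenizer.from_pretrained(model_id)

# Usage (downloads the base model, so commented out here):
# model, tok = load_finetuned("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#                             "data/sql-sft-lora/checkpoint-1250")
# inputs = tok("Show all users with gender is female.", return_tensors="pt")
# out = model.generate(**inputs, **build_generation_kwargs())
# print(tok.decode(out[0], skip_special_tokens=True))
```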
- The script uses top_p sampling and beam search for diversity and quality.
- Results are printed to stdout.

## Dataset
We use the gretelai/synthetic_text_to_sql dataset. Each entry includes:

- `sql_prompt`: natural language question
- `sql`: corresponding SQL query
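Supervised fine-tuning needs each dataset row rendered as a single training string. A minimal sketch follows; the exact template is an assumption (check `fine_tune.py` for the real one), and the sample row below is hypothetical:

```python
def format_example(example: dict) -> str:
    # Render one dataset row as an instruction-style training string
    return (
        "### Question:\n" + example["sql_prompt"] + "\n"
        "### SQL:\n" + example["sql"]
    )

# Hypothetical row in the dataset's shape (sql_prompt + sql fields)
sample = {
    "sql_prompt": "Show all users with gender is female.",
    "sql": "SELECT * FROM users WHERE gender = 'female';",
}
print(format_example(sample))
```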
- Model: By default, `TinyLlama/TinyLlama-1.1B-Chat-v1.0` (can be changed in `config/config.py`)
- LoRA: Lightweight fine-tuning using PEFT
- Trainer: TRL's `SFTTrainer` with gradient checkpointing, cosine scheduler, and TensorBoard logging
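A sketch of the corresponding PEFT/TRL configuration; all hyperparameter values here are illustrative assumptions, not the repo's actual settings:

```python
from peft import LoraConfig
from trl import SFTConfig

# Assumed LoRA settings (rank, alpha, dropout are placeholders)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# Training arguments matching the features listed above
training_args = SFTConfig(
    output_dir="data/sql-sft-lora",
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
)
```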
Make sure your `config/config.py` defines:

```python
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
OUT_DIR = "data/sql-sft-lora"
```

Run:

```shell
python fine_tune.py
```

This will:
- Load and preprocess the dataset
- Create the tokenizer and trainer
- Start LoRA fine-tuning
- Save the model to the output directory
- Model checkpoints are saved to `data/sql-sft-lora`
- TensorBoard logs are available for monitoring training progress
To upload the app, use the `huggingFace_utils/upload2space.py` script. Note that you need a valid HF token and an existing repo on HF for uploading. The files necessary to run the app on the HF Space can be found inside `huggingFace_utils/app`.
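The upload step presumably goes through the `huggingface_hub` API. A hedged sketch follows; the function name and paths are assumptions, not the actual `upload2space.py`:

```python
def upload_app(local_dir: str, repo_id: str, token: str) -> None:
    # Deferred import: only needed when actually uploading
    from huggingface_hub import HfApi
    api = HfApi(token=token)
    # Push the app files to an existing Space repo
    api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="space")

# Usage (requires a valid HF token and an existing Space):
# upload_app("huggingFace_utils/app", "rat45/sql-sft-lora-model", token="hf_...")
```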
- Follow the link https://huggingface.co/spaces/rat45/sql-sft-lora-model
- Enter a prompt like "Select all users where gender is male."
- In the output section you'll see the SQL query.
Since the model was trained for only an hour on a subset of the dataset, some of the outputs will be imperfect.