Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:
- "docker/finetuning/rocm-silogen-finetuning-worker.Dockerfile"
- "requirements.txt"
- ".github/workflows/package.yml"
- "tests/"
- "tests/**"

jobs:
common:
Expand Down
18 changes: 12 additions & 6 deletions src/finetuning/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ class ConcatenationDataInput(DataInput):
The datasets themselves need to be in the finetuning supported JSONL formats.
For SFT this means lines:

{"messages": {"content": "string", "role": "string"}}
{"messages": [{"content": "string", "role": "string"}]}

For DPO this means lines of:

{"prompt_messages": {"content": "string", "role": "string"}, "chosen_messages": {"content": "string", "role": "string"}, "rejected_messages": {"content": "string", "role": "string"}}
{
"prompt_messages": [{"content": "string", "role": "string"}],
"chosen_messages": [{"content": "string", "role": "string"}],
"rejected_messages": [{"content": "string", "role": "string"}]
}
"""

type: Literal[DataInputType.CONCATENATION]
Expand All @@ -89,11 +92,14 @@ class WeightedMixDataInput(DataInput):
The datasets themselves need to be in the finetuning supported JSONL formats.
For SFT this means lines:

{"messages": {"content": "string", "role": "string"}}
{"messages": [{"content": "string", "role": "string"}]}

For DPO this means lines of:

{"prompt_messages": {"content": "string", "role": "string"}, "chosen_messages": {"content": "string", "role": "string"}, "rejected_messages": {"content": "string", "role": "string"}}
{
"prompt_messages": [{"content": "string", "role": "string"}],
"chosen_messages": [{"content": "string", "role": "string"}],
"rejected_messages": [{"content": "string", "role": "string"}]
}
"""

type: Literal[DataInputType.PRECOMPUTE_WEIGHTED_MIX]
Expand Down
68 changes: 35 additions & 33 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_sft_gpt(tmpdir):
assert trainer_state["log_history"][-1]["loss"] < 4.0


def test_sft_mistral_lora(tmpdir):
def test_sft_gemma2_lora(tmpdir):
# 0. Config definition:
config = """\
data_conf:
Expand Down Expand Up @@ -267,11 +267,11 @@ def test_sft_mistral_lora(tmpdir):
- v_proj
- o_proj
run_conf:
model: hf-internal-testing/tiny-random-MistralForCausalLM
model: hf-internal-testing/tiny-random-Gemma2ForCausalLM
model_args:
attn_implementation: "eager"
use_cache: False
revision: 7517176934d50d00707900b92ef6a95c782e51a2
revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440
resume_from_checkpoint: False
"""
num_steps = 15
Expand All @@ -295,10 +295,10 @@ def test_sft_mistral_lora(tmpdir):
# 3. Inspect results!
with open(ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5)
# In current setup, loss goes at least below 2.5 by 15 steps
assert trainer_state["log_history"][-1]["loss"] < 2.5
# At the start, the loss should still be around -log(1/256000) ~ 12.4529327234617 (even the tiny Gemma2 has 256000 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5)
# In current setup, loss goes at least below 5.0 by 15 steps
assert trainer_state["log_history"][-1]["loss"] < 5.0
# LoRA setup should lead to adapter checkpoint
assert (ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile()
# And since the basemodel vocab was modified, there is a new basemodel checkpoint too
Expand All @@ -310,7 +310,7 @@ def intercept_tokenizer_call(*args, overwrite_chat_template=False, **kwargs):


@patch("finetuning.dpo.subsetup_tokenizer", intercept_tokenizer_call)
def test_dpo_mistral(tmpdir):
def test_dpo_gemma2(tmpdir):
# 0. DPO Config definition:
dpo_config = """\
data_conf:
Expand Down Expand Up @@ -350,11 +350,11 @@ def test_dpo_mistral(tmpdir):
peft_conf:
peft_type: "NO_PEFT"
run_conf:
model: hf-internal-testing/tiny-random-MistralForCausalLM
model: hf-internal-testing/tiny-random-Gemma2ForCausalLM
model_args:
attn_implementation: "eager"
use_cache: False
revision: 7517176934d50d00707900b92ef6a95c782e51a2
revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440
resume_from_checkpoint: False
"""
num_steps = 15
Expand Down Expand Up @@ -383,12 +383,12 @@ def test_dpo_mistral(tmpdir):
with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# The margins start close to zero
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1)
# But they rise to >0.10
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.10
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03)
# But they rise to >0.07
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07


def test_sft_and_dpo_mistral(tmpdir):
def test_sft_and_dpo_gemma2(tmpdir):
# 0. Config definition:
sft_config = """\
data_conf:
Expand Down Expand Up @@ -427,11 +427,11 @@ def test_sft_and_dpo_mistral(tmpdir):
peft_conf:
peft_type: "NO_PEFT"
run_conf:
model: hf-internal-testing/tiny-random-MistralForCausalLM
model: hf-internal-testing/tiny-random-Gemma2ForCausalLM
model_args:
attn_implementation: "eager"
use_cache: False
revision: 7517176934d50d00707900b92ef6a95c782e51a2
revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440
resume_from_checkpoint: False
"""
num_steps = 15
Expand All @@ -455,8 +455,8 @@ def test_sft_and_dpo_mistral(tmpdir):
# 3. Inspect SFT results!
with open(sft_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5)
# At the start, the loss should still be around -log(1/256000) ~ 12.4529327234617 (even the tiny Gemma2 has 256000 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5)
# In current setup, loss goes at least below 4.5 by 15 steps
assert trainer_state["log_history"][-1]["loss"] < 4.5

Expand Down Expand Up @@ -534,12 +534,12 @@ def test_sft_and_dpo_mistral(tmpdir):
with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# The margins start close to zero
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1)
# But they rise to >0.15
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.15
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03)
# But they rise to >0.07
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07


def test_sft_and_dpo_lora_mistral(tmpdir):
def test_sft_and_dpo_lora_gemma2(tmpdir):
# 0. Config definition:
sft_config = """\
data_conf:
Expand All @@ -561,6 +561,7 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
gradient_checkpointing_kwargs:
use_reentrant: false
learning_rate: 0.5
logging_first_step: true
logging_steps: 1
logging_strategy: "steps"
lr_scheduler_type: "constant"
Expand All @@ -578,11 +579,11 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
peft_conf:
peft_type: "NO_PEFT"
run_conf:
model: hf-internal-testing/tiny-random-MistralForCausalLM
model: hf-internal-testing/tiny-random-Gemma2ForCausalLM
model_args:
attn_implementation: "eager"
use_cache: False
revision: 7517176934d50d00707900b92ef6a95c782e51a2
revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440
resume_from_checkpoint: False
"""
num_steps = 15
Expand All @@ -606,8 +607,8 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
# 3. Inspect SFT results!
with open(sft_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5)
# At the start, the loss should still be around -log(1/256000) ~ 12.4529327234617 (even the tiny Gemma2 has 256000 units after Chat-ML)
assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5)
# In current setup, loss goes at least below 4.5 by 15 steps
assert trainer_state["log_history"][-1]["loss"] < 4.5

Expand All @@ -633,6 +634,7 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
gradient_checkpointing_kwargs:
use_reentrant: false
learning_rate: 0.5
logging_first_step: true
logging_steps: 1
logging_strategy: "steps"
lr_scheduler_type: "constant"
Expand Down Expand Up @@ -668,7 +670,7 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
resume_from_checkpoint: False
"""
# note: potential to redefine
num_steps = 15
num_steps = 20
# 5. setup
# Set paths
dpotmpdir = tmpdir / "dpo"
Expand All @@ -695,9 +697,9 @@ def test_sft_and_dpo_lora_mistral(tmpdir):
with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi:
trainer_state = json.loads(fi.read())
# The margins start close to zero
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1)
# But they rise to >0.15
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.15
assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.02)
# But they rise to >0.03
assert trainer_state["log_history"][-1]["rewards/margins"] > 0.03

# LoRA setup should lead to adapter checkpoint
assert (dpo_ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile()
Expand All @@ -707,12 +709,12 @@ def test_generate(tmpdir):
# 0. Config definition:
generate_config = """\
model_conf:
model: hf-internal-testing/tiny-random-MistralForCausalLM
model: hf-internal-testing/tiny-random-Gemma2ForCausalLM
model_args:
attn_implementation: "eager"
use_cache: False
device_map: cpu
revision: 7517176934d50d00707900b92ef6a95c782e51a2
revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440
prompt_conf:
type: "open-input"
input: "Hi my name is "
Expand All @@ -725,4 +727,4 @@ def test_generate(tmpdir):
finetuning.cli.generation_cli([str(tmpdir / "gen_conf.yaml"), str(tmpdir / "generated_output.yaml")])
with open(tmpdir / "generated_output.yaml") as fin:
list_of_outputs = yaml.safe_load(fin)
assert list_of_outputs[0]["answer"] == "Bod OFF soon FT" # This is what tiny-random-MistralForCausalLM replies.
assert list_of_outputs[0]["answer"] == "amientosamientosérômeérôme" # This is what tiny-random-Gemma2 replies.