From 435433aca26e64febe80ff558d45682da3027514 Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Fri, 10 Oct 2025 16:55:32 +0300 Subject: [PATCH 1/7] Fix the data format documentation --- src/finetuning/config/data.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/finetuning/config/data.py b/src/finetuning/config/data.py index 8d9f1bd..210de09 100644 --- a/src/finetuning/config/data.py +++ b/src/finetuning/config/data.py @@ -68,11 +68,14 @@ class ConcatenationDataInput(DataInput): The datasets themselves need to be in the finetuning supported JSONL formats. For SFT this means lines: - {"messages": {"content": "string", "role": "string"}} + {"messages": [{"content": "string", "role": "string"}]} For DPO this means lines of: - - {"prompt_messages": {"content": "string", "role": "string"}, "chosen_messages": {"content": "string", "role": "string"}, "rejected_messages": {"content": "string", "role": "string"}} + { + "prompt_messages": [{"content": "string", "role": "string"}], + "chosen_messages": [{"content": "string", "role": "string"}], + "rejected_messages": [{"content": "string", "role": "string"}] + } """ type: Literal[DataInputType.CONCATENATION] @@ -89,11 +92,14 @@ class WeightedMixDataInput(DataInput): The datasets themselves need to be in the finetuning supported JSONL formats. For SFT this means lines: - {"messages": {"content": "string", "role": "string"}} + {"messages": [{"content": "string", "role": "string"}]} For DPO this means lines of: - - {"prompt_messages": {"content": "string", "role": "string"}, "chosen_messages": {"content": "string", "role": "string"}, "rejected_messages": {"content": "string", "role": "string"}} + { + "prompt_messages": [{"content": "string", "role": "string"}], + "chosen_messages": [{"content": "string", "role": "string"}], + "rejected_messages": [{"content": "string", "role": "string"}] + } """ type: Literal[DataInputType.PRECOMPUTE_WEIGHTED_MIX] From 9cdb0ad0b47ab6da951050cce19a16761e5eb5d1 Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 09:54:51 +0300 Subject: [PATCH 2/7] Try to avoid segmentation fault on GHA --- tests/test_e2e.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 1594439..1df1a01 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -218,6 +218,7 @@ def test_sft_gpt(tmpdir): assert trainer_state["log_history"][-1]["loss"] < 4.0 +@pytest.mark.skip(reason="segmentation fault on GHA") def test_sft_mistral_lora(tmpdir): # 0. 
Config definition: config = """\ From 5de22bee1dcaaf68e3eb491ab2c4dada17c94fde Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 10:12:03 +0300 Subject: [PATCH 3/7] Try to avoid segmentation fault on GHA - trigger fix --- .github/workflows/package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml index 3ff8c3f..f9244b5 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/package.yml @@ -10,7 +10,7 @@ on: - "docker/finetuning/rocm-silogen-finetuning-worker.Dockerfile" - "requirements.txt" - ".github/workflows/package.yml" - - "tests/" + - "tests/**" jobs: common: From 8ea7b816d13205a4185e8b263b33a79f4fbef2af Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 12:00:59 +0300 Subject: [PATCH 4/7] Use gemma2 instead of mistral --- tests/test_e2e.py | 66 +++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 1df1a01..c02bad1 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -218,8 +218,7 @@ def test_sft_gpt(tmpdir): assert trainer_state["log_history"][-1]["loss"] < 4.0 -@pytest.mark.skip(reason="segmentation fault on GHA") -def test_sft_mistral_lora(tmpdir): +def test_sft_gemma2_lora(tmpdir): # 0. Config definition: config = """\ data_conf: @@ -268,11 +267,11 @@ def test_sft_mistral_lora(tmpdir): - v_proj - o_proj run_conf: - model: hf-internal-testing/tiny-random-MistralForCausalLM + model: hf-internal-testing/tiny-random-Gemma2ForCausalLM model_args: attn_implementation: "eager" use_cache: False - revision: 7517176934d50d00707900b92ef6a95c782e51a2 + revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440 resume_from_checkpoint: False """ num_steps = 15 @@ -296,10 +295,10 @@ def test_sft_mistral_lora(tmpdir): # 3. Inspect results! with open(ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) - # At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML) - assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5) - # In current setup, loss goes at least below 2.5 by 15 steps - assert trainer_state["log_history"][-1]["loss"] < 2.5 + # At the start, loss should still be around -log(1/256000)~12.4529327234617, (Even the tiny Gemma2 has 256000 units after Chat-ML) + assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5) + # In current setup, loss goes at least below 5.0 by 15 steps + assert trainer_state["log_history"][-1]["loss"] < 5.0 # LoRA setup should lead to adapter checkpoint assert (ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile() # And since the basemodel vocab was modified, there is a new basemodel checkpoint too @@ -311,7 +310,7 @@ def intercept_tokenizer_call(*args, overwrite_chat_template=False, **kwargs): @patch("finetuning.dpo.subsetup_tokenizer", intercept_tokenizer_call) -def test_dpo_mistral(tmpdir): +def test_dpo_gemma2(tmpdir): # 0. 
DPO Config definition: dpo_config = """\ data_conf: @@ -351,11 +350,11 @@ def test_dpo_mistral(tmpdir): peft_conf: peft_type: "NO_PEFT" run_conf: - model: hf-internal-testing/tiny-random-MistralForCausalLM + model: hf-internal-testing/tiny-random-Gemma2ForCausalLM model_args: attn_implementation: "eager" use_cache: False - revision: 7517176934d50d00707900b92ef6a95c782e51a2 + revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440 resume_from_checkpoint: False """ num_steps = 15 @@ -384,12 +383,12 @@ def test_dpo_mistral(tmpdir): with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) # The margins start close to zero - assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1) - # But they rise to >0.10 - assert trainer_state["log_history"][-1]["rewards/margins"] > 0.10 + assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03) + # But they rise to >0.07 + assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07 -def test_sft_and_dpo_mistral(tmpdir): +def test_sft_and_dpo_gemma2(tmpdir): # 0. Config definition: sft_config = """\ data_conf: @@ -428,11 +427,11 @@ def test_sft_and_dpo_mistral(tmpdir): peft_conf: peft_type: "NO_PEFT" run_conf: - model: hf-internal-testing/tiny-random-MistralForCausalLM + model: hf-internal-testing/tiny-random-Gemma2ForCausalLM model_args: attn_implementation: "eager" use_cache: False - revision: 7517176934d50d00707900b92ef6a95c782e51a2 + revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440 resume_from_checkpoint: False """ num_steps = 15 @@ -456,8 +455,8 @@ def test_sft_and_dpo_mistral(tmpdir): # # 3. Inspect SFT results! with open(sft_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) - # At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML) - assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5) + # At the start, loss should still be around -log(1/256000)~12.4529327234617, (Even the tiny Gemma2 has 256000 units after Chat-ML) + assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5) # In current setup, loss goes at least below 4.5 by 15 steps assert trainer_state["log_history"][-1]["loss"] < 4.5 @@ -535,12 +534,12 @@ def test_sft_and_dpo_mistral(tmpdir): with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) # The margins start close to zero - assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1) - # But they rise to >0.15 - assert trainer_state["log_history"][-1]["rewards/margins"] > 0.15 + assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03) + # But they rise to >0.07 + assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07 -def test_sft_and_dpo_lora_mistral(tmpdir): +def test_sft_and_dpo_lora_gemma2(tmpdir): # 0. 
Config definition: sft_config = """\ data_conf: @@ -562,6 +561,7 @@ def test_sft_and_dpo_lora_mistral(tmpdir): gradient_checkpointing_kwargs: use_reentrant: false learning_rate: 0.5 + logging_first_step: True logging_steps: 1 logging_strategy: "steps" lr_scheduler_type: "constant" @@ -579,11 +579,11 @@ def test_sft_and_dpo_lora_mistral(tmpdir): peft_conf: peft_type: "NO_PEFT" run_conf: - model: hf-internal-testing/tiny-random-MistralForCausalLM + model: hf-internal-testing/tiny-random-Gemma2ForCausalLM model_args: attn_implementation: "eager" use_cache: False - revision: 7517176934d50d00707900b92ef6a95c782e51a2 + revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440 resume_from_checkpoint: False """ num_steps = 15 @@ -607,8 +607,8 @@ def test_sft_and_dpo_lora_mistral(tmpdir): # 3. Inspect SFT results! with open(sft_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) - # At the start, loss should still be around -log(1/32002)~10.3, (Even the tiny Mistral has 32002 units after Chat-ML) - assert np.isclose(trainer_state["log_history"][0]["loss"], 10.3, atol=0.5) + # At the start, loss should still be around -log(1/256000)~12.4529327234617, (Even the tiny Gemma2 has 256000 units after Chat-ML) + assert np.isclose(trainer_state["log_history"][0]["loss"], 12.45, atol=0.5) # In current setup, loss goes at least below 4.5 by 15 steps assert trainer_state["log_history"][-1]["loss"] < 4.5 @@ -696,9 +696,9 @@ def test_sft_and_dpo_lora_mistral(tmpdir): with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) # The margins start close to zero - assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.1) - # But they rise to >0.15 - assert trainer_state["log_history"][-1]["rewards/margins"] > 0.15 + assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03) + # But they rise to >0.07 + assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07 # LoRA setup should lead to adapter checkpoint assert (dpo_ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile() @@ -708,12 +708,12 @@ def test_generate(tmpdir): # 0. Config definition: generate_config = """\ model_conf: - model: hf-internal-testing/tiny-random-MistralForCausalLM + model: hf-internal-testing/tiny-random-Gemma2ForCausalLM model_args: attn_implementation: "eager" use_cache: False device_map: cpu - revision: 7517176934d50d00707900b92ef6a95c782e51a2 + revision: de7c11b6c25d26ddd1bf4324fcf479b61d18e440 prompt_conf: type: "open-input" input: "Hi my name is " @@ -726,4 +726,4 @@ def test_generate(tmpdir): finetuning.cli.generation_cli([str(tmpdir / "gen_conf.yaml"), str(tmpdir / "generated_output.yaml")]) with open(tmpdir / "generated_output.yaml") as fin: list_of_outputs = yaml.safe_load(fin) - assert list_of_outputs[0]["answer"] == "Bod OFF soon FT" # This is what tiny-random-MistralForCausalLM replies. + assert list_of_outputs[0]["answer"] == "amientosamientosérômeérôme" # This is what tiny-random-Gemma2 replies. 
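A note on the loss assertions in the tests above: for an untrained, randomly initialized causal LM, the first logged cross-entropy loss is expected to sit near -log(1/V) = log(V), where V is the vocabulary size. A minimal sketch of that arithmetic (the helper name is illustrative and not part of the repository; NumPy is already used by the tests):

    import numpy as np

    def expected_initial_loss(vocab_size: int) -> float:
        # Cross-entropy of a uniform prediction over the vocabulary: -log(1/V) = log(V).
        return float(np.log(vocab_size))

    # Tiny Mistral vocabulary after the Chat-ML additions: 32002 tokens -> ~10.37
    assert np.isclose(expected_initial_loss(32002), 10.37, atol=0.01)
    # Tiny Gemma2 vocabulary: 256000 tokens -> ~12.45
    assert np.isclose(expected_initial_loss(256000), 12.45, atol=0.01)

This is the reasoning behind moving the asserted starting loss from roughly 10.3 to roughly 12.45 when the test model switches from tiny-random-MistralForCausalLM to tiny-random-Gemma2ForCausalLM.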
From d2ae49935a7d54dca2b2cf5dbbc8282db67f4bea Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 13:43:43 +0300 Subject: [PATCH 5/7] gemma2 loss margins lower --- tests/test_e2e.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index c02bad1..2d405f4 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -561,7 +561,7 @@ def test_sft_and_dpo_lora_gemma2(tmpdir): gradient_checkpointing_kwargs: use_reentrant: false learning_rate: 0.5 - logging_first_step: True + logging_first_step: true logging_steps: 1 logging_strategy: "steps" lr_scheduler_type: "constant" @@ -634,6 +634,7 @@ def test_sft_and_dpo_lora_gemma2(tmpdir): gradient_checkpointing_kwargs: use_reentrant: false learning_rate: 0.5 + logging_first_step: true logging_steps: 1 logging_strategy: "steps" lr_scheduler_type: "constant" @@ -669,7 +670,7 @@ def test_sft_and_dpo_lora_gemma2(tmpdir): resume_from_checkpoint: False """ # note: potential to redefine - num_steps = 15 + num_steps = 20 # 5. setup # Set paths dpotmpdir = tmpdir / "dpo" @@ -696,9 +697,9 @@ def test_sft_and_dpo_lora_gemma2(tmpdir): with open(dpo_ckpt_dir / f"checkpoint-{num_steps}" / "trainer_state.json") as fi: trainer_state = json.loads(fi.read()) # The margins start close to zero - assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.03) - # But they rise to >0.07 - assert trainer_state["log_history"][-1]["rewards/margins"] > 0.07 + assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.02) + # But they rise to >0.04 + assert trainer_state["log_history"][-1]["rewards/margins"] > 0.04 # LoRA setup should lead to adapter checkpoint assert (dpo_ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile() From 5c576f5330924bc52176a71ff6a8f887bca06411 Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 13:59:02 +0300 Subject: [PATCH 6/7] gemma2 loss margins even lower --- tests/test_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 2d405f4..37ea0eb 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -698,8 +698,8 @@ def test_sft_and_dpo_lora_gemma2(tmpdir): trainer_state = json.loads(fi.read()) # The margins start close to zero assert np.isclose(trainer_state["log_history"][0]["rewards/margins"], 0.0, atol=0.02) - # But they rise to >0.04 - assert trainer_state["log_history"][-1]["rewards/margins"] > 0.04 + # But they rise to >0.03 + assert trainer_state["log_history"][-1]["rewards/margins"] > 0.03 # LoRA setup should lead to adapter checkpoint assert (dpo_ckpt_dir / "checkpoint-final" / "adapter_model.safetensors").isfile() From 6fa1485d6a3814324969c1a175930c2eaf513b51 Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Wed, 15 Oct 2025 14:23:44 +0300 Subject: [PATCH 7/7] Add the indentation back --- src/finetuning/config/data.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/finetuning/config/data.py b/src/finetuning/config/data.py index 210de09..dfa3b96 100644 --- a/src/finetuning/config/data.py +++ b/src/finetuning/config/data.py @@ -68,14 +68,14 @@ class ConcatenationDataInput(DataInput): The datasets themselves need to be in the finetuning supported JSONL formats. 
For SFT this means lines: - {"messages": [{"content": "string", "role": "string"}]} + {"messages": [{"content": "string", "role": "string"}]} For DPO this means lines of: - { - "prompt_messages": [{"content": "string", "role": "string"}], - "chosen_messages": [{"content": "string", "role": "string"}], - "rejected_messages": [{"content": "string", "role": "string"}] - } + { + "prompt_messages": [{"content": "string", "role": "string"}], + "chosen_messages": [{"content": "string", "role": "string"}], + "rejected_messages": [{"content": "string", "role": "string"}] + } """ type: Literal[DataInputType.CONCATENATION] @@ -92,14 +92,14 @@ class WeightedMixDataInput(DataInput): The datasets themselves need to be in the finetuning supported JSONL formats. For SFT this means lines: - {"messages": [{"content": "string", "role": "string"}]} + {"messages": [{"content": "string", "role": "string"}]} For DPO this means lines of: - { - "prompt_messages": [{"content": "string", "role": "string"}], - "chosen_messages": [{"content": "string", "role": "string"}], - "rejected_messages": [{"content": "string", "role": "string"}] - } + { + "prompt_messages": [{"content": "string", "role": "string"}], + "chosen_messages": [{"content": "string", "role": "string"}], + "rejected_messages": [{"content": "string", "role": "string"}] + } """ type: Literal[DataInputType.PRECOMPUTE_WEIGHTED_MIX]
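For reference, a minimal sketch of producing one line in each of the JSONL formats documented above (the message contents and file names are illustrative assumptions, not taken from the repository):

    import json

    # SFT: each line is a JSON object holding a list of chat messages.
    sft_record = {
        "messages": [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi, how can I help?"},
        ]
    }

    # DPO: each line carries a prompt plus a chosen and a rejected continuation.
    dpo_record = {
        "prompt_messages": [{"role": "user", "content": "Hello!"}],
        "chosen_messages": [{"role": "assistant", "content": "Hi, how can I help?"}],
        "rejected_messages": [{"role": "assistant", "content": "Go away."}],
    }

    # JSONL means one JSON object per line.
    with open("sft.jsonl", "w") as fout:
        fout.write(json.dumps(sft_record) + "\n")
    with open("dpo.jsonl", "w") as fout:
        fout.write(json.dumps(dpo_record) + "\n")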