Skip to content

Unable to reproduce MedQA, MMLU-Med results on technical report #37

@HuangChiEn

Description

@HuangChiEn

Thanks for releasing this amazing work.
However, I have encountered a reproduction issue. The following results are from the medgemma-27b-it text-only version; MedQA & MMLU-Med are quite low, as the title says.
MedQA is very low — even lower than the 4B version (without test-time scaling!).

Since I'm using lm_eval, a well-known evaluation tool, it is unlikely that the issue is on their side.

"results": {
    "medqa_4options": {
      "alias": "medqa_4options",
      "acc,none": 0.5648075412411626,     # Technical Report :  87.7,   4B : 64.4
      "acc_stderr,none": 0.013901042408038702,
      "acc_norm,none": 0.5648075412411626,
      "acc_norm_stderr,none": 0.013901042408038702
    },
    "mmlu_anatomy": {      # Technical Report : 83.7,   4B : 59.3
      "alias": "anatomy",
      "acc,none": 0.562962962962963,
      "acc_stderr,none": 0.042849586397534056
    },
    "mmlu_clinical_knowledge": {     # Technical Report : 86,   4B : 71.3
      "alias": "clinical_knowledge",
      "acc,none": 0.5622641509433962,
      "acc_stderr,none": 0.030533338430467478
    },
    "mmlu_college_biology": {   # & so on.....
      "alias": "college_biology",
      "acc,none": 0.6319444444444444,
      "acc_stderr,none": 0.04032999053960717
    },
    "mmlu_college_medicine": {
      "alias": "college_medicine",
      "acc,none": 0.49710982658959535,
      "acc_stderr,none": 0.03812400565974833
    },
    "mmlu_medical_genetics": {
      "alias": "medical_genetics",
      "acc,none": 0.63,
      "acc_stderr,none": 0.048523658709390974
    },
    "mmlu_professional_medicine": {
      "alias": "professional_medicine",
      "acc,none": 0.6911764705882353,
      "acc_stderr,none": 0.028064998167040053
    }
  }

The following code snippet is a minimal reproduction script.
my exec cmd : CUDA_VISIBLE_DEVICES="0" python llm_eval.py \ --model google/medgemma-27b-text-it \ --output res \ --tasks medqa_4options mmlu_anatomy mmlu_clinical_knowledge mmlu_college_medicine mmlu_medical_genetics mmlu_professional_medicine mmlu_college_biology

import argparse
import json
from lm_eval import evaluator
from model import *

def build_args(argv=None):
    """Parse command-line arguments for the medical LLM evaluation run.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case :mod:`argparse` reads ``sys.argv[1:]`` as before —
            the parameter only adds programmatic/testing access and is
            fully backward-compatible.

    Returns:
        argparse.Namespace with model, generation, task, and output options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a (Med)Gemma model on medical benchmarks via lm-eval."
    )

    # Model arguments
    parser.add_argument(
        "--model",
        type=str,
        default="google/medgemma-3-27b",
        help="HF model name or local path",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default="",
        help="LoRA adapter path (empty string means no adapter)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Torch dtype used to load the model weights",
    )
    parser.add_argument(
        "--device_map",
        type=str,
        default="auto",
        help="HF accelerate device map (e.g. 'auto')",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=4096,
        help="Maximum context length in tokens",
    )
    parser.add_argument(
        "--max_gen_toks",
        type=int,
        default=512,
        help="Maximum number of tokens to generate per request",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Evaluation batch size",
    )

    # Evaluation tasks
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["pubmedqa"],
        help="List of evaluation tasks",
    )

    # Output file
    parser.add_argument(
        "--output",
        type=str,
        default="gemma3_medical_eval_results.json",
        help="where to store results",
    )

    return parser.parse_args(argv)

def main():
    """Parse CLI args, run lm-eval on the requested tasks, and save results."""
    args = build_args()

    print("[INFO] Running tasks:", args.tasks)
    print("[INFO] Model:", args.model)

    results = evaluator.simple_evaluate(
        # NOTE(review): "gemma3_hf" is a custom model wrapper presumably
        # registered by `from model import *` — confirm it applies the
        # model's chat template; skipping it is a common cause of large
        # accuracy drops vs. reported numbers for instruction-tuned models.
        model="gemma3_hf",
        model_args=_build_model_args(args),
        tasks=args.tasks,
        batch_size=args.batch_size,
        log_samples=True,
    )

    _print_summary(results["results"])

    with open(args.output, "w", encoding="utf-8") as f:
        # log_samples=True can put non-JSON-serializable objects into the
        # results dict; serialize those via str() instead of crashing after
        # a long (and expensive) evaluation run.
        json.dump(results, f, indent=2, default=str)

    print(f"[INFO] Saved results to {args.output}")


def _build_model_args(args):
    """Build the lm-eval ``model_args`` string (``key=value,key=value,...``).

    Omits ``lora_path`` entirely when it is empty: always embedding
    ``lora_path=`` with no value can make the model wrapper attempt to load
    an adapter from "".
    """
    parts = [
        f"pretrained={args.model}",
        f"dtype={args.dtype}",
        f"device_map={args.device_map}",
        f"max_length={args.max_length}",
        f"max_gen_toks={args.max_gen_toks}",
    ]
    if args.lora_path:
        parts.insert(1, f"lora_path={args.lora_path}")
    return ",".join(parts)


def _print_summary(task_results):
    """Print per-task accuracy metrics.

    Bug fix: lm-eval >= 0.4 reports metrics under filter-qualified keys such
    as "acc,none" (visible in the saved JSON output), so the original
    ``r.get("acc")`` was always None and the summary printed nothing. We try
    the qualified key first and fall back to the bare name for older versions.
    """
    print("\n======= SUMMARY =======")
    for task, r in task_results.items():
        print(task)
        acc = r.get("acc,none", r.get("acc"))
        acc_norm = r.get("acc_norm,none", r.get("acc_norm"))
        if acc is not None:
            print(f"  acc      = {acc:.4f}")
        if acc_norm is not None:
            print(f"  acc_norm = {acc_norm:.4f}")
        print("------------------")


if __name__ == "__main__":
    main()

@dgolden1
Any suggestion will be appreciated!!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions