
Cannot reproduce the FID=5.44 in Table 4 with EMA=0.9999 #42

@FutureXiang


Hi, thanks for the impressive paper and code release!

However, I am not able to reproduce the FID=5.44 score in Table 4 (i.e., JiT-B/16 trained for 200 epochs, without in-context tokens, a pretty standard baseline).

I use the official codebase but change x = scaled_dot_product_attention to x = F.scaled_dot_product_attention for efficiency, and train on 6 NVIDIA 4090 GPUs (local batch size=170, global=1020, which should be almost equivalent to your 1024). However, I only got FID=7.41, which is far from the reported 5.44.
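
For reference, this is roughly what my attention change looks like (a minimal sketch; the function and tensor names are mine, not the exact ones in the repo):

import torch.nn.functional as F

def attention_forward(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    # Before (the repo's own scaled_dot_product_attention helper, as I understand it):
    # x = scaled_dot_product_attention(q, k, v)
    # After: PyTorch's fused kernel, which should compute the same thing
    # while avoiding materializing the full attention matrix.
    x = F.scaled_dot_product_attention(q, k, v)
    return x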

Training script:

torchrun --nproc_per_node=6 --nnodes=1 --node_rank=0 \
main_jit.py \
--epochs 200 --warmup_epochs 5 \
--batch_size 170 --blr 5e-5 \
(......)
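
To make my "almost equivalent" claim explicit, here is the effective-batch / learning-rate arithmetic I am assuming. The lr = blr * effective_batch / 256 scaling rule and the 8 x 128 split on your side are my assumptions, not something I confirmed from the code:

blr = 5e-5

yours = dict(gpus=8, local_bsz=128)   # assumed split; only the global 1024 is stated
mine  = dict(gpus=6, local_bsz=170)

for name, cfg in [("paper", yours), ("mine", mine)]:
    eff = cfg["gpus"] * cfg["local_bsz"]
    lr = blr * eff / 256  # assumed MAE-style lr scaling
    print(f"{name}: effective batch = {eff}, lr = {lr:.2e}")
# paper: effective batch = 1024, lr = 2.00e-04
# mine:  effective batch = 1020, lr = 1.99e-04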

FID evaluation:

torchrun --nproc_per_node=6 --nnodes=1 --node_rank=0 \
main_jit.py \
--evaluate_gen \
--num_images 50000 --gen_bsz 256 \
--cfg 3.0 --interval_min 0.1 \
(......)

Could you kindly confirm: (1) could these minor differences (8 GPUs with global batch 1024 vs. 6 GPUs with global batch 1020, the repo's scaled_dot_product_attention vs. PyTorch's F.scaled_dot_product_attention) lead to such a significant gap? (2) was the 5.44 FID obtained with EMA rate=0.9999 or 0.9996?
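
For context on why I ask about the EMA rate, this is the standard parameter EMA update I am assuming the codebase uses (names are mine):

import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # ema_p <- decay * ema_p + (1 - decay) * p
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

# Rough effective averaging horizon is 1 / (1 - decay):
#   decay = 0.9999 -> ~10,000 steps
#   decay = 0.9996 -> ~2,500 steps
# so the choice could plausibly matter for FID at 200 epochs.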

Thank you!
