Hi, thanks for the impressive paper and code release!
However, I am unable to reproduce the FID = 5.44 score in Table 4 (i.e., 200-epoch JiT-B/16 without in-context tokens, a fairly standard baseline).
I use the official codebase with a single change: `x = scaled_dot_product_attention(...)` is replaced by `x = F.scaled_dot_product_attention(...)` for efficiency. I train on 6 × NVIDIA 4090 GPUs (local batch size 170, global batch size 1020, which should be nearly equivalent to your 1024). However, I only reach FID = 7.41, far from the reported 5.44.
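For reference, the attention change looks roughly like this (a minimal sketch; the hand-written `scaled_dot_product_attention` body here is my reconstruction, not the repo's exact code):

```python
import torch
import torch.nn.functional as F

# Sketch of a naive hand-written attention, as I understand the original:
# attention weights are materialized explicitly (O(N^2) memory).
def scaled_dot_product_attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.softmax(dim=-1)
    return attn @ v

q = k = v = torch.randn(2, 8, 197, 64)  # (batch, heads, tokens, head_dim)

# My change: PyTorch's fused kernel (flash / memory-efficient attention).
# It should match the naive version up to floating-point error.
x_naive = scaled_dot_product_attention(q, k, v)
x_fused = F.scaled_dot_product_attention(q, k, v)
print((x_naive - x_fused).abs().max())  # tiny, e.g. ~1e-6 in fp32
```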
Training script:
```bash
torchrun --nproc_per_node=6 --nnodes=1 --node_rank=0 \
    main_jit.py \
    --epochs 200 --warmup_epochs 5 \
    --batch_size 170 --blr 5e-5 \
    (......)
```
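Regarding the batch-size difference, my understanding (assuming `main_jit.py` follows the common MAE-style convention of deriving the learning rate linearly from `--blr`, which I have not confirmed) is that 1020 vs. 1024 shifts the effective learning rate only marginally:

```python
# Assumed MAE-style linear lr scaling: lr = blr * eff_batch_size / 256.
# (This convention is an assumption about main_jit.py, not confirmed.)
blr = 5e-5

lr_yours = blr * 1024 / 256  # 8 GPUs, global batch 1024 -> 2.0000e-4
lr_mine  = blr * 1020 / 256  # 6 GPUs, global batch 1020 -> 1.9922e-4
print(lr_yours, lr_mine)     # ~0.4% difference in effective lr
```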
FID evaluation:
```bash
torchrun --nproc_per_node=6 --nnodes=1 --node_rank=0 \
    main_jit.py \
    --evaluate_gen \
    --num_images 50000 --gen_bsz 256 \
    --cfg 3.0 --interval_min 0.1 \
    (......)
```
Could you kindly confirm: (1) could these minor differences (8 GPUs / global batch 1024 vs. 6 GPUs / global batch 1020; the custom `scaled_dot_product_attention` vs. PyTorch's `F.scaled_dot_product_attention`) really account for such a significant gap? (2) was the 5.44 FID obtained with EMA rate 0.9999 or 0.9996?
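On question (2), the reason I ask is that the two decay rates imply very different effective averaging windows. A minimal sketch of the standard EMA update (my own illustration, not the repo's code):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay):
    # Standard EMA: theta_ema <- decay * theta_ema + (1 - decay) * theta
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)

# The effective averaging window is roughly 1 / (1 - decay) updates:
for decay in (0.9999, 0.9996):
    print(decay, "->", round(1 / (1 - decay)), "updates")
# 0.9999 -> 10000 updates;  0.9996 -> 2500 updates
```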
Thank you!