Question 1. Did anyone get Orb running in `float16` or `bfloat16`?
Question 2. If I take `float32` weights and train a bit in `float16`, what would convince us the model didn't break? (e.g., would low validation MAE on MPTraj be sufficient?)
Apologies if I'm missing anything.
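For Question 2, one check beyond validation MAE would be comparing the low-precision model's predictions directly against the original `float32` model on held-out structures. A minimal sketch of that comparison, with random tensors standing in for per-atom force predictions (all names and numbers here are illustrative assumptions, not the Orb API):

```python
import torch

# Stand-ins for per-atom force predictions (N atoms x 3 components)
# from the original fp32 model and the fp16/bf16-finetuned model.
torch.manual_seed(0)
preds_fp32 = torch.randn(100, 3)
# Simulate a low-precision model: small perturbation + bf16 round-trip.
preds_lowp = (preds_fp32 + 0.01 * torch.randn(100, 3)).bfloat16().float()

# Mean absolute deviation between the two models' forces. If this sits
# well below the model's own validation force MAE, the precision change
# probably didn't break anything; if it's comparable or larger, it did.
mad = (preds_fp32 - preds_lowp).abs().mean()
print(f"force MAD between models: {mad.item():.4f}")
```

The point is that validation MAE alone can hide regressions (two quite different models can have similar aggregate MAE), whereas a per-prediction deviation against the trusted fp32 model is a much tighter test.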
(update) Sorry for the wall of text below; I added details so others don't have to reproduce this. If useful, I'm happy to write a PR (with a non-hacky fix for `brute_force_kNN`).
TLDR: `bf16` is 2-3x faster on an RTX 4090. To get the `xs` model working I had to hack `brute_force_kNN`. This gave 6.5 steps/s = 0.28 ns/day for a 36k-atom system. I expect around 0.6-0.8 ns/day on an H100. The big thing left is to evaluate the performance drop, e.g., reproduce Figure 2 in bf16.
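For context, the pattern I used to run the forward pass in bf16 is standard PyTorch autocast. A minimal self-contained sketch with a toy MLP standing in for the Orb network (the real model and its kNN internals are not shown; `device_type="cpu"` just keeps this runnable anywhere, on GPU it would be `"cuda"`):

```python
import torch
import torch.nn as nn

# Toy stand-in for the forcefield network; the real Orb model is not
# reproduced here, only the autocast pattern used to run it in bf16.
model = nn.Sequential(nn.Linear(8, 32), nn.SiLU(), nn.Linear(32, 1))

x = torch.randn(4, 8)  # fp32 inputs, as in normal inference

# Under autocast, matmul-heavy ops (Linear, etc.) run in bfloat16
# while numerically sensitive ops stay in fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)  # torch.bfloat16
```

This is also roughly where `brute_force_kNN` bites: ops inside the autocast region see bf16 tensors, so any code that assumes fp32 inputs (or does exact distance comparisons) needs either an explicit cast back to fp32 or a rewrite, hence the hack mentioned above.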