Skip to content

Conversation

@josephdviviano
Copy link
Collaborator

@josephdviviano josephdviviano commented Dec 13, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  • Many tweaks to loss computations / score computations to enable some speedups (some significant, many are minor).
  • Tried to balance typical use-cases (compile vs eager, cpu vs gpu).
  • Adds vectorized ops for all losses (including flow matching).
  • Added a benchmark (>1000 loc) isolating the nature of the improvement:
python tutorials/misc/bench_get_scores_all.py --device cpu --size-scale 1.0 --repeat 10
Benchmarking TB/LPV/SubTB/DB/ModDB in sequence
torch version: 2.8.0
device: cpu
dtype (TB/LPV): torch.float32
num threads: 8
size-scale: 1.0
compile: True
repeat: 10
forward-looking (DB): False

Columns: original, original+compile, current, current+compile, speedup vs original (eager and compiled).

=== TB loss ===
         N   chk    orig(ms)  orig+c(ms)    curr(ms)  curr+c(ms)     spd   spd_c
     10240  PASS     123.256     226.142     128.075     227.104    0.96    0.54
     40960  PASS     426.625     465.959     432.483     467.667    0.99    0.91
    163840  PASS    1502.938    3019.063    1511.459    3036.792    0.99    0.49

=== LPV loss ===
         N   chk    orig(ms)  orig+c(ms)    curr(ms)  curr+c(ms)     spd   spd_c
     10240  PASS     122.998     265.660     122.117     266.198    1.01    0.46
     40960  PASS     458.521     527.375     458.521     528.479    1.00    0.87
    163840  PASS    1529.146    2474.459    1535.209    2476.625    1.00    0.62

=== SubTB get_scores ===
        size   chk    orig(ms)  orig+c(ms)    curr(ms)  curr+c(ms)     spd   spd_c
80x  640  PASS   22106.916   22212.291    3425.333    5188.938    6.45    4.26
160x 1280  PASS   96741.416   98876.834   18239.584   16141.709    5.30    5.99
320x 2560  PASS  357786.179  359309.822   60888.395   57334.043    5.88    6.24

=== DB get_scores ===
   n_trans   chk    orig(ms)  orig+c(ms)    curr(ms)  curr+c(ms)     spd   spd_c
     65536  PASS    1754.916    2128.500    1561.667     943.876    1.12    1.86
    131072  PASS    2214.625    2377.625    1762.334    1003.208    1.26    2.21
    262144  PASS    2820.520    3131.541    2327.792    1145.459    1.21    2.46

=== Modified DB get_scores ===
   n_trans   chk    orig(ms)  orig+c(ms)    curr(ms)  curr+c(ms)     spd   spd_c
     65536  PASS    1986.667    2120.395    1992.729    2668.625    1.00    0.74
    131072  PASS    3331.375    3657.042    3322.604    3593.625    1.00    0.93
    262144  PASS    4761.896    4865.604    4835.500    5147.833    0.98    0.93
    ```

@josephdviviano josephdviviano changed the title Gflownet optimize Gflownet optimize loss computation Dec 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants