Skip to content

Commit 8cf93b4

Browse files
Ahmed Khaled and copybara-github
authored and committed
Add training scripts for the diffusion language model workload
PiperOrigin-RevId: 867908182
1 parent 888d8f8 commit 8cf93b4

File tree

4 files changed

+292
-8
lines changed

4 files changed

+292
-8
lines changed

init2winit/trainer_lib/base_trainer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def _check_early_stopping(self, report):
424424
self._early_stopping_target_value)
425425
return early_stopping_condition
426426

427-
def _eval(self, start_step, start_time, save=True):
427+
def _eval(self, start_step, start_time, eval_rng, save=True):
428428
"""Evaluate.
429429
430430
Has the side-effects of:
@@ -437,12 +437,14 @@ def _eval(self, start_step, start_time, save=True):
437437
Args:
438438
start_step: the training start step.
439439
start_time: the training start time.
440+
eval_rng: rng seed used in eval (chiefly for the MDLM workload).
440441
save: flag to save a checkpoint to disk. defaults to True.
441442
442443
Returns:
443444
A Dict[str, Any] eval report, originally created in
444445
trainer_utils.eval_metrics.
445446
"""
447+
446448
time_since_last_eval = time.time() - self._time_at_prev_eval_end
447449

448450
if self._eval_use_ema:
@@ -452,6 +454,8 @@ def _eval(self, start_step, start_time, save=True):
452454
else:
453455
eval_params = self._params
454456

457+
eval_rng = jax.random.fold_in(eval_rng, self._global_step)
458+
455459
report, eval_time = trainer_utils.eval_metrics(
456460
eval_params,
457461
self._batch_stats,
@@ -461,6 +465,7 @@ def _eval(self, start_step, start_time, save=True):
461465
self._eval_train_num_batches,
462466
self._evaluate_batch_jitted,
463467
self.finalize_batch_fn,
468+
eval_rng=eval_rng,
464469
)
465470
self._run_eval_callbacks(report)
466471
if save:
@@ -618,8 +623,7 @@ def train(self):
618623
# across hosts.
619624
rng, init_rng = jax.random.split(self._rng)
620625
rng = jax.random.fold_in(rng, jax.process_index())
621-
rng, data_rng = jax.random.split(rng)
622-
rng, callback_rng = jax.random.split(rng)
626+
rng, data_rng, callback_rng, eval_rng = jax.random.split(rng, 4)
623627

624628
if jax.process_index() == 0:
625629
logging.info('Let the training begin!')
@@ -705,7 +709,7 @@ def train(self):
705709
self._global_step, self._eval_frequency, self._eval_steps
706710
):
707711
try:
708-
report = self._eval(start_step, start_time)
712+
report = self._eval(start_step, start_time, eval_rng)
709713
except utils.TrainingDivergedError as e:
710714
self.wait_until_orbax_checkpointer_finished()
711715
raise utils.TrainingDivergedError(
@@ -720,7 +724,7 @@ def train(self):
720724
# If we moved where in the loop body evals happen then we would not need
721725
# this test.
722726
if self._prev_eval_step != self._num_train_steps:
723-
report = self._eval(start_step, start_time)
727+
report = self._eval(start_step, start_time, eval_rng)
724728
yield report
725729
# To make sure the last checkpoint was correctly saved.
726730
self.wait_until_orbax_checkpointer_finished()
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
# coding=utf-8
2+
# Copyright 2026 The init2winit Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Integration test for MDLM training with a patterned fake dataset.
17+
18+
Verifies that the full training loop (model init -> training -> eval) works
19+
end-to-end and that loss decreases on a simple repeating pattern.
20+
21+
"""
22+
23+
import os
24+
import shutil
25+
import tempfile
26+
27+
from absl import logging
28+
from absl.testing import absltest
29+
from init2winit import utils
30+
from init2winit.dataset_lib import data_utils
31+
from init2winit.init_lib import initializers
32+
from init2winit.model_lib import models
33+
from init2winit.trainer_lib import trainer
34+
import jax
35+
import jax.numpy as jnp
36+
from ml_collections.config_dict import config_dict
37+
import numpy as np
38+
import pandas
39+
import tensorflow.compat.v1 as tf
40+
41+
Dataset = data_utils.Dataset
42+
43+
# Small vocab and sequence length so the test runs quickly on CPU.
44+
_VOCAB_SIZE = 16
45+
_SEQ_LEN = 32
46+
_BATCH_SIZE = 16
47+
_EVAL_NUM_BATCHES = 10
48+
49+
50+
def _make_patterned_batch(batch_size, vocab_size, seq_len):
51+
"""Creates a batch where each row is a cyclic shift of [0, 1, ..., V-1].
52+
53+
Row i = [(i % V), (i+1 % V), ..., (i+seq_len-1 % V)].
54+
This gives the model a simple and learnable pattern.
55+
56+
Args:
57+
batch_size: Number of sequences in the batch.
58+
vocab_size: Size of the vocabulary.
59+
seq_len: Length of each sequence.
60+
61+
Returns:
62+
A dict with 'inputs', 'targets', and 'weights'.
63+
"""
64+
rows = []
65+
for i in range(batch_size):
66+
row = [(i + j) % vocab_size for j in range(seq_len)]
67+
rows.append(row)
68+
tokens = jnp.array(rows, dtype=jnp.int32)
69+
return {
70+
'inputs': tokens,
71+
'targets': tokens, # MDLM: inputs == targets.
72+
'weights': jnp.ones(tokens.shape),
73+
}
74+
75+
76+
def _get_patterned_mdlm_dataset(batch_size, eval_num_batches):
  """Returns a fake MDLM dataset with a cyclic-shift pattern."""

  def train_iterator_fn():
    # Endless stream of patterned batches for the training loop.
    while True:
      yield _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)

  def eval_train_epoch(num_batches=None):
    # Finite epoch; defaults to the configured number of eval batches.
    count = eval_num_batches if num_batches is None else num_batches
    for _ in range(count):
      yield _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)

  meta_data = {
      'apply_one_hot_in_loss': False,
      'shift_inputs': False,
      'causal': False,
  }
  # The same finite iterator serves the eval-train, validation, and test
  # splits; the data is identical everywhere by construction.
  dataset = Dataset(
      train_iterator_fn,
      eval_train_epoch,
      eval_train_epoch,
      eval_train_epoch,
  )
  return dataset, meta_data
105+
106+
107+
class MDLMIntegrationTest(absltest.TestCase):
  """Integration test: train MDLM and verify loss decreases."""

  def setUp(self):
    super().setUp()
    # Scratch directory for checkpoints and metric CSVs; removed in tearDown.
    self.test_dir = tempfile.mkdtemp()
    self.trainer = None

  def tearDown(self):
    if self.trainer is not None:
      # Block until async (orbax) checkpoint writes finish before deleting
      # the directory they write into.
      self.trainer.wait_until_orbax_checkpointer_finished()
    shutil.rmtree(self.test_dir)
    super().tearDown()

  def test_loss_decreases_on_pattern(self):
    """MDLM should learn a trivial cyclic pattern and decrease loss."""
    rng = jax.random.PRNGKey(0)

    model_str = 'mdlm_rope_nanodo'
    model_cls = models.get_model(model_str)
    # 'passthrough' loss: the model computes its own (diffusion) loss;
    # metrics come from the MDLM-specific metrics bundle.
    loss_name = 'passthrough'
    metrics_name = 'mdlm_metrics'

    # Deliberately tiny transformer so the test runs quickly on CPU.
    hps = config_dict.ConfigDict({
        'batch_size': _BATCH_SIZE,
        'emb_dim': 32,
        'num_heads': 2,
        'num_layers': 2,
        'mlp_dim': 64,
        'vocab_size': _VOCAB_SIZE,
        'input_shape': (_SEQ_LEN,),
        'output_shape': (_SEQ_LEN, _VOCAB_SIZE),
        'computation_dtype': 'float32',
        'model_dtype': 'float32',
        'normalization': 'rmsnorm',
        'mlp_activation': 'glu',
        'qk_norm': True,
        'tie_embeddings': True,
        'noise_schedule': 'log_linear',
        'optimizer': 'adam',
        'opt_hparams': {
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-8,
            'weight_decay': 0.0,
        },
        'lr_hparams': {
            'base_lr': 0.003,
            'schedule': 'constant',
        },
        'l2_decay_factor': 0.0,
        'l2_decay_rank_threshold': 2,
        'grad_clip': None,
        'label_smoothing': 0.0,
        'use_shallue_label_smoothing': False,
        'rng_seed': 0,
        'train_size': _BATCH_SIZE * 100,
        'num_device_prefetches': 0,
        'epsilon': 1e-9,
    })

    dataset, dataset_meta_data = _get_patterned_mdlm_dataset(
        _BATCH_SIZE, _EVAL_NUM_BATCHES
    )
    model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
    initializer = initializers.get_initializer('noop')

    num_train_steps = 1200
    eval_frequency = 200

    metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
    self.trainer = trainer.Trainer(
        train_dir=self.test_dir,
        model=model,
        # The builder signature is (rng, batch_size, ...); this test ignores
        # all of it and always returns the fixed fake dataset.
        dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
        initializer=initializer,
        num_train_steps=num_train_steps,
        hps=hps,
        rng=rng,
        eval_batch_size=_BATCH_SIZE,
        eval_use_ema=False,
        eval_num_batches=_EVAL_NUM_BATCHES,
        test_num_batches=0,
        eval_train_num_batches=_EVAL_NUM_BATCHES,
        eval_frequency=eval_frequency,
        checkpoint_steps=[],
        metrics_logger=metrics_logger,
        init_logger=init_logger,
    )
    # train() is a generator of eval reports; draining it runs the full loop.
    _ = list(self.trainer.train())

    # ---- Check loss trajectory ----
    # The metrics logger writes one row per eval to measurements.csv.
    with tf.io.gfile.GFile(
        os.path.join(self.test_dir, 'measurements.csv')
    ) as f:
      df = pandas.read_csv(f)
    train_cost = df['train_cost'].values
    self.assertGreater(
        train_cost[0],
        train_cost[-1],
        msg=(
            'Expected loss to decrease. '
            f'Initial: {train_cost[0]:.4f}, Final: {train_cost[-1]:.4f}'
        ),
    )
    # 0.5 is far below the random-guess cross entropy (log(16) ~ 2.77),
    # so this checks the pattern was actually learned, not just improved on.
    self.assertLess(
        train_cost[-1],
        0.5,
        msg=(
            'Expected final loss well below random baseline. '
            f'Final: {train_cost[-1]:.4f}'
        ),
    )

    valid_ce = df['valid/ce_loss'].values
    valid_ppl = df['valid/perplexity'].values
    self.assertTrue(
        all(np.isfinite(valid_ce)),
        msg=f'valid/ce_loss contains non-finite: {valid_ce}',
    )
    self.assertTrue(
        all(np.isfinite(valid_ppl)),
        msg=f'valid/perplexity contains non-finite: {valid_ppl}',
    )
    self.assertLess(
        valid_ce[-1],
        valid_ce[0],
        msg=(
            'Expected valid/ce_loss to decrease. '
            f'Initial: {valid_ce[0]:.4f}, Final: {valid_ce[-1]:.4f}'
        ),
    )
    self.assertGreater(
        valid_ppl[0],
        valid_ppl[-1],
        msg=(
            'Expected valid/perplexity to decrease. '
            f'Initial: {valid_ppl[0]:.4f}, Final: {valid_ppl[-1]:.4f}'
        ),
    )

    # ---- Verify evaluate_batch ----
    # Call the model's eval path directly with the trained parameters,
    # outside the trainer, to check it stands alone.
    params = self.trainer.get_params()
    batch = _make_patterned_batch(_BATCH_SIZE, _VOCAB_SIZE, _SEQ_LEN)
    # MDLM eval draws its own randomness (masking) from the batch's rng.
    batch['eval_rng'] = jax.random.PRNGKey(42)
    eval_metrics = model.evaluate_batch(params, batch_stats=None, batch=batch)
    eval_results = eval_metrics.compute()
    self.assertTrue(np.isfinite(eval_results['ce_loss']))
    self.assertTrue(np.isfinite(eval_results['perplexity']))
    logging.info(
        'Direct evaluate_batch: ce_loss=%.4f, perplexity=%.4f',
        eval_results['ce_loss'],
        eval_results['perplexity'],
    )
261+
262+
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()

init2winit/trainer_lib/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,7 @@ def finalize_batch_fn(self, batch):
138138
"""Finalize the batch by making a global array out of the shards."""
139139

140140
return trainer_utils.make_finalize_batch_fn(self._mesh)(batch)
141+
142+
def get_params(self):
143+
"""Returns the model parameters."""
144+
return self._params

0 commit comments

Comments
 (0)