# coding=utf-8
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration test for MDLM training with a patterned fake dataset.

Verifies that the full training loop (model init -> training -> eval) works
end-to-end and that loss decreases on a simple repeating pattern.
"""

import os
import shutil
import tempfile

from absl import logging
from absl.testing import absltest
from init2winit import utils
from init2winit.dataset_lib import data_utils
from init2winit.init_lib import initializers
from init2winit.model_lib import models
from init2winit.trainer_lib import trainer
import jax
import jax.numpy as jnp
from ml_collections.config_dict import config_dict
import numpy as np
import pandas
import tensorflow.compat.v1 as tf

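# data_utils.Dataset bundles a train iterator with eval-train, validation,
# and test epoch functions; the fake dataset below reuses one epoch function
# for all three.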
Dataset = data_utils.Dataset

# Small vocab and sequence length so the test runs quickly on CPU.
_VOCAB_SIZE = 16
_SEQ_LEN = 32
_BATCH_SIZE = 16
_EVAL_NUM_BATCHES = 10


def _make_patterned_batch(batch_size, vocab_size, seq_len):
  """Creates a batch where each row is a cyclic shift of [0, 1, ..., V-1].

  Row i = [i % V, (i + 1) % V, ..., (i + seq_len - 1) % V], so every row
  follows the same simple, learnable pattern.
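  For example, with vocab_size=4 and seq_len=6, row 2 is [2, 3, 0, 1, 2, 3].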

  Args:
    batch_size: Number of sequences in the batch.
    vocab_size: Size of the vocabulary.
    seq_len: Length of each sequence.

  Returns:
    A dict with 'inputs', 'targets', and 'weights'.
  """
  rows = []
  for i in range(batch_size):
    row = [(i + j) % vocab_size for j in range(seq_len)]
    rows.append(row)
  tokens = jnp.array(rows, dtype=jnp.int32)
  return {
      'inputs': tokens,
      'targets': tokens,  # MDLM: inputs == targets.
      'weights': jnp.ones(tokens.shape),
  }


def _get_patterned_mdlm_dataset(batch_size, eval_num_batches):
  """Returns a fake MDLM dataset with a cyclic-shift pattern."""

  def train_iterator_fn():
    while True:
      batch = _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)
      yield batch

  def eval_train_epoch(num_batches=None):
    if num_batches is None:
      num_batches = eval_num_batches
    for _ in range(num_batches):
      batch = _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)
      yield batch

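  # MDLM is a masked diffusion LM: targets equal the inputs (no next-token
  # shift) and attention is bidirectional, hence 'causal' is False.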
  meta_data = {
      'apply_one_hot_in_loss': False,
      'shift_inputs': False,
      'causal': False,
  }
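  # The same patterned epoch function serves as the eval-train, validation,
  # and test splits.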
  return (
      Dataset(
          train_iterator_fn,
          eval_train_epoch,
          eval_train_epoch,
          eval_train_epoch,
      ),
      meta_data,
  )


class MDLMIntegrationTest(absltest.TestCase):
  """Integration test: train MDLM and verify loss decreases."""

  def setUp(self):
    super().setUp()
    self.test_dir = tempfile.mkdtemp()
    self.trainer = None

  def tearDown(self):
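    # Block until any asynchronous Orbax checkpoint writes finish before
    # deleting the test directory.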
    if self.trainer is not None:
      self.trainer.wait_until_orbax_checkpointer_finished()
    shutil.rmtree(self.test_dir)
    super().tearDown()

  def test_loss_decreases_on_pattern(self):
    """MDLM should learn a trivial cyclic pattern and decrease loss."""
    rng = jax.random.PRNGKey(0)

    model_str = 'mdlm_rope_nanodo'
    model_cls = models.get_model(model_str)
    loss_name = 'passthrough'
    metrics_name = 'mdlm_metrics'

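    # Deliberately tiny transformer so the 1200-step run stays fast on CPU.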
    hps = config_dict.ConfigDict({
        'batch_size': _BATCH_SIZE,
        'emb_dim': 32,
        'num_heads': 2,
        'num_layers': 2,
        'mlp_dim': 64,
        'vocab_size': _VOCAB_SIZE,
        'input_shape': (_SEQ_LEN,),
        'output_shape': (_SEQ_LEN, _VOCAB_SIZE),
        'computation_dtype': 'float32',
        'model_dtype': 'float32',
        'normalization': 'rmsnorm',
        'mlp_activation': 'glu',
        'qk_norm': True,
        'tie_embeddings': True,
        'noise_schedule': 'log_linear',
        'optimizer': 'adam',
        'opt_hparams': {
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-8,
            'weight_decay': 0.0,
        },
        'lr_hparams': {
            'base_lr': 0.003,
            'schedule': 'constant',
        },
        'l2_decay_factor': 0.0,
        'l2_decay_rank_threshold': 2,
        'grad_clip': None,
        'label_smoothing': 0.0,
        'use_shallue_label_smoothing': False,
        'rng_seed': 0,
        'train_size': _BATCH_SIZE * 100,
        'num_device_prefetches': 0,
        'epsilon': 1e-9,
    })

    dataset, dataset_meta_data = _get_patterned_mdlm_dataset(
        _BATCH_SIZE, _EVAL_NUM_BATCHES
    )
    model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
    initializer = initializers.get_initializer('noop')

    num_train_steps = 1200
    eval_frequency = 200

    metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
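    # The trainer expects a dataset builder callable, so wrap the prebuilt
    # fake dataset in a no-op lambda.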
    self.trainer = trainer.Trainer(
        train_dir=self.test_dir,
        model=model,
        dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
        initializer=initializer,
        num_train_steps=num_train_steps,
        hps=hps,
        rng=rng,
        eval_batch_size=_BATCH_SIZE,
        eval_use_ema=False,
        eval_num_batches=_EVAL_NUM_BATCHES,
        test_num_batches=0,
        eval_train_num_batches=_EVAL_NUM_BATCHES,
        eval_frequency=eval_frequency,
        checkpoint_steps=[],
        metrics_logger=metrics_logger,
        init_logger=init_logger,
    )
    _ = list(self.trainer.train())

    # ---- Check loss trajectory ----
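    # The metrics logger set up above writes one row of measurements per
    # evaluation to measurements.csv in the train directory.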
    with tf.io.gfile.GFile(
        os.path.join(self.test_dir, 'measurements.csv')
    ) as f:
      df = pandas.read_csv(f)
    train_cost = df['train_cost'].values
    self.assertGreater(
        train_cost[0],
        train_cost[-1],
        msg=(
            'Expected loss to decrease. '
            f'Initial: {train_cost[0]:.4f}, Final: {train_cost[-1]:.4f}'
        ),
    )
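    # Chance level for a 16-token vocabulary is a cross-entropy of
    # ln(16) ~= 2.77 nats, so a final loss under 0.5 implies real learning.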
    self.assertLess(
        train_cost[-1],
        0.5,
        msg=(
            'Expected final loss well below random baseline. '
            f'Final: {train_cost[-1]:.4f}'
        ),
    )

    valid_ce = df['valid/ce_loss'].values
    valid_ppl = df['valid/perplexity'].values
    self.assertTrue(
        all(np.isfinite(valid_ce)),
        msg=f'valid/ce_loss contains non-finite: {valid_ce}',
    )
    self.assertTrue(
        all(np.isfinite(valid_ppl)),
        msg=f'valid/perplexity contains non-finite: {valid_ppl}',
    )
    self.assertLess(
        valid_ce[-1],
        valid_ce[0],
        msg=(
            'Expected valid/ce_loss to decrease. '
            f'Initial: {valid_ce[0]:.4f}, Final: {valid_ce[-1]:.4f}'
        ),
    )
    self.assertGreater(
        valid_ppl[0],
        valid_ppl[-1],
        msg=(
            'Expected valid/perplexity to decrease. '
            f'Initial: {valid_ppl[0]:.4f}, Final: {valid_ppl[-1]:.4f}'
        ),
    )

    # ---- Verify evaluate_batch ----
    params = self.trainer.get_params()
    batch = _make_patterned_batch(_BATCH_SIZE, _VOCAB_SIZE, _SEQ_LEN)
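    # MDLM evaluation is stochastic (tokens are masked at a sampled noise
    # level), so the batch carries an explicit eval RNG.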
    batch['eval_rng'] = jax.random.PRNGKey(42)
    eval_metrics = model.evaluate_batch(params, batch_stats=None, batch=batch)
    eval_results = eval_metrics.compute()
    self.assertTrue(np.isfinite(eval_results['ce_loss']))
    self.assertTrue(np.isfinite(eval_results['perplexity']))
    logging.info(
        'Direct evaluate_batch: ce_loss=%.4f, perplexity=%.4f',
        eval_results['ce_loss'],
        eval_results['perplexity'],
    )


if __name__ == '__main__':
  absltest.main()