From 97b47b29c6d134e3570db6d561b6b382a2ebf270 Mon Sep 17 00:00:00 2001 From: init2winit Team Date: Wed, 25 Mar 2026 10:06:40 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 889305053 --- init2winit/model_lib/deepspeech.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 2d489e8e..3c4fba49 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -343,7 +343,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -963,6 +963,17 @@ def evaluate_batch(self, params, batch_stats, batch): targets=labels, target_paddings=label_paddings) + def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): + """Wrapper around flax_module.apply.""" + return self.flax_module.apply( + { + 'params': params, + 'batch_stats': batch_stats + }, + batch['inputs'], + batch['input_paddings'], + **apply_kwargs) + def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss."""