Skip to content

Commit 4b5f94f

Browse files
init2winit Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 868681655
1 parent 176a018 commit 4b5f94f

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

init2winit/model_lib/deepspeech.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def __call__(self, inputs):
343343
var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
344344

345345
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
346-
normed_inputs *= (1 + self.scale)
346+
normed_inputs *= 1 + self.scale
347347
normed_inputs += self.bias
348348

349349
return normed_inputs
@@ -963,6 +963,17 @@ def evaluate_batch(self, params, batch_stats, batch):
963963
targets=labels,
964964
target_paddings=label_paddings)
965965

966+
def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs):
967+
"""Wrapper around flax_module.apply."""
968+
return self.flax_module.apply(
969+
{
970+
'params': params,
971+
'batch_stats': batch_stats
972+
},
973+
batch['inputs'],
974+
batch['input_paddings'],
975+
**apply_kwargs)
976+
966977
def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
967978
"""Return CTC loss."""
968979

0 commit comments

Comments
 (0)