diff --git a/qwix/_src/providers/qt.py b/qwix/_src/providers/qt.py index 40e0790..1f0d0cd 100644 --- a/qwix/_src/providers/qt.py +++ b/qwix/_src/providers/qt.py @@ -207,7 +207,7 @@ def ragged_dot( group_offset: jax.Array | None = None, ) -> jax.Array: """QT ragged_dot.""" - rule, op_id = self._get_current_rule_and_op_id('ragged_dot') + rule, _ = self._get_current_rule_and_op_id('ragged_dot') if rule is None or rule.weight_qtype is None: return jax.lax.ragged_dot( lhs, @@ -217,7 +217,7 @@ def ragged_dot( preferred_element_type=preferred_element_type, group_offset=group_offset, ) - config = self._create_dot_general_qt_config(rule, op_id, lhs, rhs) + config = self._create_ragged_dot_qt_config(rule) return ragged_dot_qt.ragged_dot_qt( lhs, rhs,