From 4b97c32d156a8fac53afd284fbbef64d398b8c06 Mon Sep 17 00:00:00 2001 From: Qwix Developers Date: Tue, 4 Nov 2025 08:06:55 -0800 Subject: [PATCH] Remove the ptq change, move to another cl or use model surgery PiperOrigin-RevId: 827977832 --- qwix/_src/providers/ptq.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/qwix/_src/providers/ptq.py b/qwix/_src/providers/ptq.py index e56ad31..97ccc52 100644 --- a/qwix/_src/providers/ptq.py +++ b/qwix/_src/providers/ptq.py @@ -15,7 +15,7 @@ import functools from typing import Any, Callable, Generic, Sequence, TypeVar - +from absl import logging from flax import linen as nn from flax import nnx import flax.linen.dtypes @@ -489,7 +489,11 @@ def quantize_params( def get_value_from_path(obj, path: tuple[str, ...]): for key in path: - obj = obj[key] if isinstance(obj, dict) else getattr(obj, key) + if isinstance(obj, dict): + if key not in obj: + logging.info('[debugsa]: key: %s in path: %s not in obj', key, path) + return None + obj = obj.get(key) if isinstance(obj, dict) else getattr(obj, key) return obj quantized_params = {} @@ -497,7 +501,10 @@ def get_value_from_path(obj, path: tuple[str, ...]): if not isinstance(param, jax.Array): raise TypeError(f'params is not a pure dict of jax.Array: {type(param)}') abs_param = get_value_from_path(abstract_quantized_params, path) - if isinstance(abs_param, WithAux): + if abs_param is None: + logging.info('[debugsa]: abs_param is None for path: %s', path) + continue + if isinstance(abs_param, WithAux) and abs_param is not None: # The param might not be in the shape needed for compute, in case the # module reshapes before compute. Abstract param has the compute shape. param = param.reshape(abs_param.shape)