Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions qwix/_src/providers/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import functools
from typing import Any, Callable, Generic, Sequence, TypeVar

from absl import logging
from flax import linen as nn
from flax import nnx
import flax.linen.dtypes
Expand Down Expand Up @@ -489,15 +489,22 @@ def quantize_params(

def get_value_from_path(obj, path: tuple[str, ...]):
for key in path:
obj = obj[key] if isinstance(obj, dict) else getattr(obj, key)
if isinstance(obj, dict):
if key not in obj:
logging.info('[debugsa]: key: %s in path: %s not in obj', key, path)
return None
obj = obj.get(key) if isinstance(obj, dict) else getattr(obj, key)
return obj

quantized_params = {}
for path, param in flax.traverse_util.flatten_dict(params).items():
if not isinstance(param, jax.Array):
raise TypeError(f'params is not a pure dict of jax.Array: {type(param)}')
abs_param = get_value_from_path(abstract_quantized_params, path)
if isinstance(abs_param, WithAux):
if abs_param is None:
logging.info('[debugsa]: abs_param is None for path: %s', path)
continue
if isinstance(abs_param, WithAux) and abs_param is not None:
# The param might not be in the shape needed for compute, in case the
# module reshapes before compute. Abstract param has the compute shape.
param = param.reshape(abs_param.shape)
Expand Down