diff --git a/extension/scripts/inf_qwen_32b_fp8.py b/extension/scripts/inf_qwen_32b_fp8.py
new file mode 100644
index 0000000000..44c0231130
--- /dev/null
+++ b/extension/scripts/inf_qwen_32b_fp8.py
@@ -0,0 +1,46 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# model_name = "jerryzh168/Qwen3-8B-FP8"
+model_name = "/mnt/raid0/pretrained_model/pytorch/Qwen3-32B-FP8/"
+
+# load the tokenizer and the model
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map="auto"
+)
+
+# prepare the model input
+prompt = "Give me a short introduction to large language model."
+messages = [
+    {"role": "user", "content": prompt}
+]
+text = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True,
+    enable_thinking=True  # Switches between thinking and non-thinking modes. Default is True.
+)
+model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+# conduct text completion
+generated_ids = model.generate(
+    **model_inputs,
+    max_new_tokens=32768
+)
+output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+# parsing thinking content
+try:
+    # rindex finding 151668 (</think>)
+    index = len(output_ids) - output_ids[::-1].index(151668)
+except ValueError:
+    index = 0
+
+thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
+content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+
+print("thinking content:", thinking_content)
+print("content:", content)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 1b7584cc77..8a81a490da 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -45,6 +45,7 @@
     fill_defaults,
     is_sm_at_least_90,
     is_sm_at_least_100,
+    is_MI300,
 )
 
 if _is_mslk_available():
@@ -78,6 +79,14 @@ class QuantizeTensorToFloat8Kwargs(QuantizeTensorKwargs):
     hp_value_ub: Optional[float] = None
     kernel_preference: KernelPreference = KernelPreference.AUTO
 
+def e4m3fn_to_e4m3fnuz(t: torch.Tensor, t_scale: torch.Tensor):
+    # Remap e4m3fn NaN (0x80) to zero in-place via an int8 view, then
+    # reinterpret the bits as ROCm's e4m3fnuz (which has no negative zero/NaN).
+    ROCM_FP8_NAN_AS_INT = -128
+    t_as_int8 = t.view(torch.int8)
+    t_as_int8[t_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    t = t_as_int8.view(torch.float8_e4m3fnuz)
+    # e4m3fnuz bias is shifted by one, so the same bits decode to half the
+    # e4m3fn value; double the scale to compensate.
+    t_scale = t_scale * 2.0
+    return t, t_scale
 
 class Float8Tensor(TorchAOBaseTensor):
     """
@@ -145,6 +156,10 @@ def __init__(
         self.act_quant_kwargs = act_quant_kwargs
         self.kernel_preference = kernel_preference
 
+        if torch.version.hip and is_MI300():
+            if self.qdata.dtype == torch.float8_e4m3fn:
+                self.qdata, self.scale = e4m3fn_to_e4m3fnuz(self.qdata, self.scale)
+
     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, "