Skip to content

Commit f660a93

Browse files
support Qwen3.5 quantization
1 parent 0357cb9 commit f660a93

7 files changed

Lines changed: 106 additions & 1 deletion

File tree

examples/llm_ptq/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
109109
| Gemma 3 | ✅<sup>2</sup> | - | ✅ | - | - |
110110
| QWen 2, 2.5 <sup>4</sup> | ✅ | ✅ | ✅ | ✅ | ✅ |
111111
| QWen3, 3.5 MOE, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
112+
| QWen3.5 <sup>6</sup> | ✅ | - | ✅ | - | - |
112113
| QwQ | ✅ | - | - | - | ✅ |
113114
| DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
114115
| GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |

examples/llm_ptq/example_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ def build_quant_cfg(
252252
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
253253
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
254254

255+
if model_type == "qwen3_5":
256+
# GatedDeltaNet's in_proj_b and in_proj_a have very narrow output dimensions
257+
# (hidden_size -> num_v_heads, e.g. 1024 -> 16), quantizing them causes accuracy loss.
258+
quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_b*", "enable": False})
259+
quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_a*", "enable": False})
260+
255261
return quant_cfg
256262

257263

examples/vlm_ptq/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo
3838
| VILA | ✅ | ✅ | ✅ | ✅ | - |
3939
| Phi-3-vision, Phi-4-multimodal | ✅ | ✅ | ✅ | ✅ | ✅ |
4040
| Qwen2, 2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ |
41+
| Qwen3.5 | ✅ | - | ✅ | - | - |
4142
| Gemma3 | ✅ | - | - | - | - |
4243

4344
> *<sup>1.</sup>Only TensorRT-LLM checkpoint export is supported. Not compatible with the TensorRT-LLM torch backend* \

modelopt/torch/export/model_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
"MPT": "mpt",
3030
"Bloom": "bloom",
3131
"ChatGLM": "chatglm",
32+
"Qwen3_5Moe": "qwen3_5moe",
33+
"Qwen3_5": "qwen3_5",
3234
"Qwen3Moe": "qwen3moe",
3335
"Qwen3Next": "qwen3next",
3436
"QWen": "qwen",

modelopt/torch/export/quant_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ def _update_svdquant(modules, new_pre_quant_scale):
12211221
# Mathematical equivalence:
12221222
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
12231223
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
1224-
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
1224+
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP", "Qwen3_5MLP"], ("up_proj", "down_proj")),
12251225
]
12261226

12271227

tests/_test_utils/torch/transformers_models.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
T5ForConditionalGeneration,
3636
)
3737

38+
try:
39+
from transformers import Qwen3_5TextConfig
40+
except ImportError:
41+
Qwen3_5TextConfig = None
42+
3843
import modelopt.torch.opt as mto
3944

4045
SEED = 1234
@@ -117,6 +122,37 @@ def create_tiny_qwen3_moe_dir(
117122
get_tiny_qwen3_moe(**config_kwargs).save_pretrained(qwen3_moe_dir)
118123
return qwen3_moe_dir
119124

125+
##### Qwen3.5 (hybrid linear attention + full attention) #####
126+
def get_tiny_qwen3_5(**config_kwargs) -> PreTrainedModel:
127+
if Qwen3_5TextConfig is None:
128+
pytest.skip("Qwen3_5TextConfig not available (requires transformers >= 4.57)")
129+
130+
set_seed(SEED)
131+
132+
kwargs = {
133+
"dtype": torch.bfloat16,
134+
"hidden_size": 32,
135+
"intermediate_size": 32,
136+
"num_hidden_layers": 4,
137+
"num_attention_heads": 2,
138+
"num_key_value_heads": 1,
139+
"head_dim": 16,
140+
"linear_num_key_heads": 4,
141+
"linear_num_value_heads": 4,
142+
"linear_key_head_dim": 8,
143+
"linear_value_head_dim": 8,
144+
"linear_conv_kernel_dim": 4,
145+
"full_attention_interval": 4,
146+
"attn_output_gate": True,
147+
"max_position_embeddings": 32,
148+
"vocab_size": 32,
149+
"rms_norm_eps": 1e-6,
150+
}
151+
kwargs.update(**config_kwargs)
152+
tiny_qwen3_5 = AutoModelForCausalLM.from_config(Qwen3_5TextConfig(**kwargs))
153+
154+
return tiny_qwen3_5
155+
120156

121157
##### GPT-OSS #####
122158
def get_tiny_gpt_oss(**config_kwargs) -> PreTrainedModel:

tests/unit/torch/quantization/plugins/test_huggingface.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
create_tiny_llama_dir,
2525
get_tiny_gpt_oss,
2626
get_tiny_llama,
27+
get_tiny_qwen3_5,
2728
get_tiny_qwen3_moe,
2829
tf_modelopt_state_and_output_tester,
2930
)
@@ -243,3 +244,61 @@ def test_hf_decoder_discoverer_registration_path():
243244
assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers(
244245
model
245246
)
247+
248+
249+
@pytest.mark.parametrize(
250+
"quant_config",
251+
[mtq.FP8_DEFAULT_CFG, mtq.INT4_AWQ_CFG],
252+
ids=["fp8", "int4_awq"],
253+
)
254+
def test_qwen3_5_hybrid_attention_quantize(quant_config):
255+
"""Verify FP8 and AWQ quantization works for Qwen3.5 hybrid (GatedDeltaNet + Attention)."""
256+
import copy
257+
258+
model = get_tiny_qwen3_5()
259+
260+
quant_cfg = copy.deepcopy(quant_config)
261+
if quant_config is mtq.INT4_AWQ_CFG:
262+
for entry in quant_cfg["quant_cfg"]:
263+
if entry["quantizer_name"] == "*weight_quantizer":
264+
entry.setdefault("cfg", {})["block_sizes"] = {-1: 16}
265+
break
266+
267+
# Disable narrow GatedDeltaNet projections (same as example_utils does for qwen3_5)
268+
quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_b*", "enable": False})
269+
quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_a*", "enable": False})
270+
271+
def calib_fn(model):
272+
x = model.dummy_inputs["input_ids"]
273+
for _ in range(2):
274+
model(x)
275+
276+
mtq.quantize(model, quant_cfg, calib_fn)
277+
278+
# Verify the model still produces output
279+
with torch.no_grad():
280+
out = model(model.dummy_inputs["input_ids"])
281+
assert out.logits is not None
282+
283+
# Verify both GatedDeltaNet and Attention linear layers got quantized
284+
has_gdn_quantized = False
285+
has_attn_quantized = False
286+
for name, module in model.named_modules():
287+
if hasattr(module, "weight_quantizer") and hasattr(module, "weight"):
288+
if "linear_attn.in_proj_qkv" in name:
289+
has_gdn_quantized = True
290+
if "self_attn.q_proj" in name:
291+
has_attn_quantized = True
292+
assert has_gdn_quantized, "GatedDeltaNet linear layers should be quantized"
293+
assert has_attn_quantized, "Attention linear layers should be quantized"
294+
295+
# Verify narrow projections are NOT quantized
296+
for name, module in model.named_modules():
297+
if "in_proj_b" in name and hasattr(module, "weight_quantizer"):
298+
assert not module.weight_quantizer.is_enabled, (
299+
f"in_proj_b should have quantization disabled: {name}"
300+
)
301+
if "in_proj_a" in name and hasattr(module, "weight_quantizer"):
302+
assert not module.weight_quantizer.is_enabled, (
303+
f"in_proj_a should have quantization disabled: {name}"
304+
)

0 commit comments

Comments
 (0)