Skip to content

Commit 3935e1e

Browse files
authored
precommit check modelopt recipes (#1218)
### What does this PR do? Add a pre-commit hook to validate modelopt recipes - Adds a pre-commit hook (check-modelopt-recipes) that validates recipe YAML files under modelopt_recipes/ on commit. - Checks that quant_cfg uses the list-of-dicts format (rejects legacy dict format), PTQ recipes use quantize as the top-level key (not ptq_cfg), and recipes load successfully via load_recipe(). - Scopes validation to PTQ recipes only for now — non-PTQ YAML files (e.g. speculative decoding training configs like eagle3.yaml) are silently skipped. Validation of other recipe types is deferred to future work. ### Testing - Modify a PTQ recipe to use the legacy quant_cfg dict format → hook should reject - Modify a PTQ recipe to use ptq_cfg instead of quantize → hook should reject - Commit a valid PTQ recipe change → hook should pass - Commit a non-PTQ YAML file (e.g. eagle3.yaml) → hook should skip it without errors <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Added a pre-commit validation that checks ModelOpt PTQ recipe files for correct structure and schema, rejects legacy/incorrect quantization config formats, enforces proper top-level keys, and surfaces clear error messages to block commits with invalid recipes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 14e9bea commit 3935e1e

2 files changed

Lines changed: 218 additions & 0 deletions

File tree

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ repos:
5555
args: [--mapping=2, --sequence=4, --offset=2, --implicit_start, --implicit_end, --preserve-quotes]
5656
exclude: ^.github/workflows/
5757

58+
- repo: local
59+
hooks:
60+
- id: check-modelopt-recipes
61+
name: validate modelopt recipes
62+
entry: python tools/precommit/check_modelopt_recipes.py
63+
language: system
64+
files: ^modelopt_recipes/
65+
5866
# Instructions to change license file if ever needed:
5967
# https://github.com/Lucas-C/pre-commit-hooks#removing-old-license-and-replacing-it-with-a-new-one
6068
- repo: https://github.com/Lucas-C/pre-commit-hooks
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Pre-commit hook: validate ModelOpt recipes.
17+
18+
Pre-commit passes changed file paths as arguments. This script resolves each
19+
file to its parent recipe (single-file or directory format), deduplicates, and
20+
validates each recipe exactly once.
21+
22+
Checks performed:
23+
24+
1. ``quant_cfg`` must use the list-of-dicts format with explicit
25+
``quantizer_name`` keys (legacy dict format is rejected).
26+
2. PTQ recipes must use ``quantize`` as the top-level key
27+
(not ``ptq_cfg`` or other variants).
28+
3. Each recipe is loaded via ``load_recipe()`` to catch structural and
29+
validation errors (skipped if modelopt is not installed).
30+
"""
31+
32+
from __future__ import annotations
33+
34+
import sys
35+
from pathlib import Path
36+
37+
import yaml
38+
39+
40+
def _check_quant_cfg(quant_cfg, label: str) -> list[str]:
    """Return format problems found in *quant_cfg*.

    *label* prefixes every message so the caller can tell which file the
    problem belongs to.

    * A ``dict`` is reported once as the legacy format.
    * For a ``list``, each entry that is not a dict, or is a dict without a
      ``quantizer_name`` key, produces its own message.
    * Any other value produces no messages from this function.
    """
    # Legacy format: a single message, the contents are not inspected.
    if isinstance(quant_cfg, dict):
        return [
            f"{label}: quant_cfg uses the legacy dict format. "
            "Use the list-of-dicts format with explicit 'quantizer_name' keys instead. "
            "See https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html for the format specification."
        ]

    if not isinstance(quant_cfg, list):
        return []

    problems: list[str] = []
    for index, item in enumerate(quant_cfg):
        if not isinstance(item, dict):
            problems.append(
                f"{label}: quant_cfg[{index}] must be a dict with "
                f"'quantizer_name', got {type(item).__name__}. "
                "See https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html"
            )
        elif "quantizer_name" not in item:
            problems.append(
                f"{label}: quant_cfg[{index}] is missing 'quantizer_name'. "
                "Each entry must have an explicit 'quantizer_name' key. "
                "See https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html"
            )
    return problems
65+
66+
67+
def _load_yaml(path: Path) -> dict | None:
    """Read *path* as UTF-8 YAML and return its top-level mapping.

    Returns ``None`` when reading or parsing raises any ``Exception``, and
    also when the parsed document is not a ``dict``.
    """
    try:
        text = path.read_text(encoding="utf-8")
        parsed = yaml.safe_load(text)
    except Exception:
        # Best-effort loader: callers treat None as "could not load".
        return None
    if not isinstance(parsed, dict):
        return None
    return parsed
74+
75+
76+
def _check_single_file_recipe(path: Path) -> list[str]:
    """Validate a recipe stored in one YAML file and return error messages.

    The file is reported as unparseable when ``_load_yaml`` returns ``None``.
    Files whose ``metadata`` is not a dict containing ``recipe_type`` are
    treated as non-recipes and produce no errors. Otherwise a top-level
    ``ptq_cfg`` key is reported, and the ``quant_cfg`` found under
    ``quantize`` (or, failing that, under ``ptq_cfg``) is format-checked.
    """
    label = str(path)
    doc = _load_yaml(path)
    if doc is None:
        return [f"{label}: failed to parse YAML"]

    meta = doc.get("metadata")
    if not (isinstance(meta, dict) and "recipe_type" in meta):
        return []  # not a recipe file

    problems: list[str] = []
    has_legacy_key = "ptq_cfg" in doc
    if has_legacy_key:
        problems.append(
            f"{label}: uses 'ptq_cfg' as the top-level key. "
            "PTQ recipes must use 'quantize' instead."
        )

    # 'quantize' takes precedence; the legacy section is still inspected so
    # that its quant_cfg format problems are reported in the same run.
    if "quantize" in doc:
        section = doc["quantize"]
    elif has_legacy_key:
        section = doc["ptq_cfg"]
    else:
        return problems

    quant_cfg = section.get("quant_cfg") if isinstance(section, dict) else None
    if quant_cfg is not None:
        problems.extend(_check_quant_cfg(quant_cfg, label))
    return problems
106+
107+
108+
def _check_dir_recipe(dir_path: Path) -> list[str]:
    """Validate the quantize file of a directory-format recipe.

    Only the first existing file among ``quantize.yml`` and ``quantize.yaml``
    is examined. If it loads as a mapping and has a non-``None``
    ``quant_cfg``, that value is format-checked; in every other case an
    empty list is returned.
    """
    candidates = (dir_path / name for name in ("quantize.yml", "quantize.yaml"))
    quantize_file = next((c for c in candidates if c.is_file()), None)
    if quantize_file is None:
        return []

    doc = _load_yaml(quantize_file)
    if doc is None:
        return []

    quant_cfg = doc.get("quant_cfg")
    if quant_cfg is None:
        return []
    return list(_check_quant_cfg(quant_cfg, str(quantize_file)))
123+
124+
125+
def _try_load_recipe(path: str) -> list[str]:
    """Load the recipe at *path* with modelopt and return any error message.

    Returns an empty list when ``modelopt.recipe.loader`` cannot be imported
    (the check is skipped) or when loading succeeds. Any exception raised by
    ``load_recipe`` is converted into a single error string.
    """
    try:
        from modelopt.recipe.loader import load_recipe as load
    except ImportError:
        return []  # modelopt not installed, skip

    failures: list[str] = []
    try:
        load(path)
    except Exception as err:
        failures.append(f"{path}: recipe failed to load: {err}")
    return failures
137+
138+
139+
def _is_dir_recipe(dir_path: Path) -> bool:
    """Tell whether *dir_path* holds a ``recipe.yml`` or ``recipe.yaml`` file."""
    for marker in ("recipe.yml", "recipe.yaml"):
        if dir_path.joinpath(marker).is_file():
            return True
    return False
142+
143+
144+
def _is_recipe_file(path: Path) -> bool:
    """Decide whether *path* is a recipe file this hook should validate.

    Only files whose ``metadata.recipe_type`` equals ``"ptq"`` qualify;
    other recipe types are not checked here. Files without a ``metadata``
    dict containing ``recipe_type`` are not considered recipes.

    A file that ``_load_yaml`` cannot return as a mapping yields ``True`` so
    that it is not silently skipped and a later validation step can report
    the problem.
    """
    doc = _load_yaml(path)
    if doc is None:
        return True  # keep it in the pipeline so the failure is reported

    meta = doc.get("metadata")
    if isinstance(meta, dict) and "recipe_type" in meta:
        return meta["recipe_type"] == "ptq"
    return False  # not a recipe file at all
160+
161+
162+
def _resolve_recipes(changed_files: list[str]) -> dict[Path, str]:
    """Map changed file paths to the recipes they belong to.

    Returns a ``{recipe_path: kind}`` dict in first-seen order, where *kind*
    is ``"dir"`` (key is the recipe directory) or ``"file"`` (key is the
    YAML file itself). Each recipe appears once even if several of its files
    changed. Files that do not resolve to a recipe accepted by
    ``_is_recipe_file`` are skipped without error.
    """
    found: dict[Path, str] = {}
    for changed in changed_files:
        file_path = Path(changed)
        parent = file_path.parent

        if _is_dir_recipe(parent):
            # The directory's recipe.yml/.yaml carries the metadata that
            # decides whether the whole directory is validated.
            manifests = (parent / name for name in ("recipe.yml", "recipe.yaml"))
            if any(m.is_file() and _is_recipe_file(m) for m in manifests):
                found.setdefault(parent, "dir")
            continue

        if file_path.suffix not in (".yml", ".yaml") or not file_path.is_file():
            continue
        if _is_recipe_file(file_path):
            found.setdefault(file_path, "file")

    return found
185+
186+
187+
def main() -> int:
    """Validate the recipes touched by the paths in ``sys.argv[1:]``.

    Each resolved recipe first gets the static format check for its kind;
    ``_try_load_recipe`` runs only when that check found nothing. All
    messages are printed to stderr with an ``ERROR:`` prefix.

    Returns 1 if any error was collected, otherwise 0.
    """
    failures: list[str] = []
    for recipe_path, kind in _resolve_recipes(sys.argv[1:]).items():
        checker = _check_dir_recipe if kind == "dir" else _check_single_file_recipe
        found = checker(recipe_path)
        if not found:
            # Static checks passed; fall through to the full loader.
            found = _try_load_recipe(str(recipe_path))
        failures.extend(found)

    if not failures:
        return 0
    for message in failures:
        print(f"ERROR: {message}", file=sys.stderr)
    return 1
207+
208+
209+
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)