|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
"""Pre-commit hook: validate ModelOpt recipes.

Pre-commit passes changed file paths as arguments. This script resolves each
file to its parent recipe (single-file or directory format), deduplicates, and
validates each recipe exactly once.

Checks performed:

1. ``quant_cfg`` must use the list-of-dicts format with explicit
   ``quantizer_name`` keys (legacy dict format is rejected).
2. PTQ recipes must use ``quantize`` as the top-level key
   (the legacy ``ptq_cfg`` key is rejected).
3. Each recipe is loaded via ``load_recipe()`` to catch structural and
   validation errors (skipped if modelopt is not installed, or if the recipe
   already failed one of the checks above).
"""
| 31 | + |
| 32 | +from __future__ import annotations |
| 33 | + |
| 34 | +import sys |
| 35 | +from pathlib import Path |
| 36 | + |
| 37 | +import yaml |
| 38 | + |
| 39 | + |
def _check_quant_cfg(quant_cfg, label: str) -> list[str]:
    """Validate the format of a ``quant_cfg`` value.

    The only accepted shape is a list of dicts in which every entry carries an
    explicit ``quantizer_name`` key.

    Args:
        quant_cfg: The parsed YAML value found under the ``quant_cfg`` key.
            ``None`` is treated as "not present" and yields no errors.
        label: Prefix identifying the source (typically the file path) that is
            prepended to every error message.

    Returns:
        A list of human-readable error messages; empty when the value is valid.
    """
    doc_url = "https://nvidia.github.io/Model-Optimizer/guides/_quant_cfg.html"

    if quant_cfg is None:
        return []

    if isinstance(quant_cfg, dict):
        return [
            f"{label}: quant_cfg uses the legacy dict format. "
            "Use the list-of-dicts format with explicit 'quantizer_name' keys instead. "
            f"See {doc_url} for the format specification."
        ]

    if not isinstance(quant_cfg, list):
        # Scalars and other YAML types were previously accepted silently, which
        # let a malformed quant_cfg through whenever modelopt was unavailable.
        return [
            f"{label}: quant_cfg must be a list of dicts with "
            f"'quantizer_name' keys, got {type(quant_cfg).__name__}. "
            f"See {doc_url}"
        ]

    errors: list[str] = []
    for i, entry in enumerate(quant_cfg):
        if not isinstance(entry, dict):
            errors.append(
                f"{label}: quant_cfg[{i}] must be a dict with "
                f"'quantizer_name', got {type(entry).__name__}. "
                f"See {doc_url}"
            )
        elif "quantizer_name" not in entry:
            errors.append(
                f"{label}: quant_cfg[{i}] is missing 'quantizer_name'. "
                "Each entry must have an explicit 'quantizer_name' key. "
                f"See {doc_url}"
            )
    return errors
| 65 | + |
| 66 | + |
def _load_yaml(path: Path) -> dict | None:
    """Read *path* as UTF-8 YAML and return its top-level mapping.

    Returns ``None`` when reading or parsing raises any exception, and also
    when the document's top-level value is not a dict.
    """
    try:
        text = path.read_text(encoding="utf-8")
        parsed = yaml.safe_load(text)
    except Exception:
        # Best-effort loader: callers only need "usable mapping or not".
        return None
    if not isinstance(parsed, dict):
        return None
    return parsed
| 74 | + |
| 75 | + |
def _check_single_file_recipe(path: Path) -> list[str]:
    """Run the static checks on a single-file recipe.

    A file counts as a recipe only when it has a ``metadata`` mapping with a
    ``recipe_type`` key; anything else yields no errors. A file that cannot be
    loaded as a YAML mapping yields a single parse-failure error.

    Returns:
        A list of error messages, each prefixed with the file path.
    """
    label = str(path)
    doc = _load_yaml(path)
    if doc is None:
        return [f"{label}: failed to parse YAML"]

    meta = doc.get("metadata")
    if not (isinstance(meta, dict) and "recipe_type" in meta):
        return []  # not a recipe file

    problems: list[str] = []
    has_legacy_key = "ptq_cfg" in doc
    if has_legacy_key:
        problems.append(
            f"{label}: uses 'ptq_cfg' as the top-level key. "
            "PTQ recipes must use 'quantize' instead."
        )

    # Prefer the canonical section; fall back to the legacy one so its
    # quant_cfg is still validated alongside the key-name error.
    if "quantize" in doc:
        section = doc["quantize"]
    elif has_legacy_key:
        section = doc["ptq_cfg"]
    else:
        return problems

    quant_cfg = section.get("quant_cfg") if isinstance(section, dict) else None
    if quant_cfg is not None:
        problems.extend(_check_quant_cfg(quant_cfg, label))
    return problems
| 106 | + |
| 107 | + |
def _check_dir_recipe(dir_path: Path) -> list[str]:
    """Run the static checks on a directory-format recipe.

    Only the first existing file among ``quantize.yml`` and ``quantize.yaml``
    is inspected. If neither exists, the file cannot be loaded as a YAML
    mapping, or it has no ``quant_cfg``, no errors are reported here.

    Returns:
        A list of error messages, each prefixed with the quantize file path.
    """
    candidates = (dir_path / name for name in ("quantize.yml", "quantize.yaml"))
    quantize_file = next((c for c in candidates if c.is_file()), None)
    if quantize_file is None:
        return []

    data = _load_yaml(quantize_file)
    if data is None:
        return []

    quant_cfg = data.get("quant_cfg")
    if quant_cfg is None:
        return []

    return list(_check_quant_cfg(quant_cfg, str(quantize_file)))
| 123 | + |
| 124 | + |
def _try_load_recipe(path: str) -> list[str]:
    """Load the recipe at *path* with modelopt and report any failure.

    Returns an empty list when the recipe loads, and also when modelopt cannot
    be imported (the check is then skipped). Any exception raised by
    ``load_recipe`` is turned into a single error message.
    """
    failures: list[str] = []
    try:
        from modelopt.recipe.loader import load_recipe
    except ImportError:
        return failures  # modelopt not installed, skip

    try:
        load_recipe(path)
    except Exception as err:
        failures.append(f"{path}: recipe failed to load: {err}")
    return failures
| 137 | + |
| 138 | + |
def _is_dir_recipe(dir_path: Path) -> bool:
    """Tell whether *dir_path* contains a ``recipe.yml`` or ``recipe.yaml`` file."""
    for filename in ("recipe.yml", "recipe.yaml"):
        if (dir_path / filename).is_file():
            return True
    return False
| 142 | + |
| 143 | + |
def _is_recipe_file(path: Path) -> bool:
    """Decide whether *path* is a recipe file that this hook should validate.

    Only recipes whose ``metadata.recipe_type`` equals ``"ptq"`` qualify;
    other recipe types (e.g. QAT) could be added here later.

    A file that cannot be loaded as a YAML mapping also returns True, so that
    ``load_recipe()`` gets the chance to report the actual error.
    """
    doc = _load_yaml(path)
    if doc is None:
        return True  # let load_recipe report the parse error

    meta = doc.get("metadata")
    if isinstance(meta, dict) and "recipe_type" in meta:
        return meta["recipe_type"] == "ptq"
    return False  # not a recipe file at all
| 160 | + |
| 161 | + |
def _resolve_recipes(changed_files: list[str]) -> dict[Path, str]:
    """Map changed files to the recipes they belong to.

    Returns a ``{recipe_path: kind}`` dict, where kind is ``"dir"`` for a
    directory-format recipe (keyed by the directory) and ``"file"`` for a
    single-file recipe (keyed by the file). Each recipe appears once, however
    many of its files changed. Files that do not belong to a recipe this hook
    validates are silently skipped.
    """
    resolved: dict[Path, str] = {}
    for changed in changed_files:
        file_path = Path(changed)
        parent = file_path.parent

        if _is_dir_recipe(parent):
            # Any file inside a directory recipe maps to that directory, but
            # only if its recipe.yml/recipe.yaml is a recipe we validate.
            manifests = (parent / n for n in ("recipe.yml", "recipe.yaml"))
            if any(m.is_file() and _is_recipe_file(m) for m in manifests):
                resolved.setdefault(parent, "dir")
            continue

        if not (file_path.is_file() and file_path.suffix in (".yml", ".yaml")):
            continue
        if _is_recipe_file(file_path):
            resolved.setdefault(file_path, "file")

    return resolved
| 185 | + |
| 186 | + |
def main() -> int:
    """Validate the recipes touched by the paths in ``sys.argv[1:]``.

    Each resolved recipe first gets the static format checks; only when those
    pass is it additionally loaded via ``_try_load_recipe``. All collected
    errors are printed to stderr.

    Returns:
        1 if any error was found, otherwise 0.
    """
    failures: list[str] = []

    for recipe_path, kind in _resolve_recipes(sys.argv[1:]).items():
        checker = _check_dir_recipe if kind == "dir" else _check_single_file_recipe
        found = checker(recipe_path)
        if not found:
            # Static checks passed; let the real loader have a go.
            found = _try_load_recipe(str(recipe_path))
        failures.extend(found)

    if not failures:
        return 0

    for message in failures:
        print(f"ERROR: {message}", file=sys.stderr)
    return 1


if __name__ == "__main__":
    raise SystemExit(main())
0 commit comments