|
15 | 15 |
|
16 | 16 | """Utilities for inspecting and working with function entrypoints.""" |
17 | 17 |
|
| 18 | +import ast |
18 | 19 | import importlib.util |
19 | 20 | import inspect |
20 | 21 | import json |
@@ -107,6 +108,93 @@ def get_function_signature_types( |
107 | 108 | return request_type, response_type, request_type_name, response_type_name |
108 | 109 |
|
109 | 110 |
|
| 111 | +def inspect_function_types_static(entrypoint_path: str) -> Tuple[Optional[str], Optional[str]]: |
| 112 | + """Inspect function types using static AST parsing (no imports). |
| 113 | +
|
| 114 | + This parses the Python file without executing it, so it doesn't |
| 115 | + require dependencies to be installed. |
| 116 | +
|
| 117 | + Args: |
| 118 | + entrypoint_path: Path to the entrypoint.py file |
| 119 | +
|
| 120 | + Returns: |
| 121 | + Tuple of (request_type_name, response_type_name) |
| 122 | + """ |
| 123 | + try: |
| 124 | + with open(entrypoint_path, 'r') as f: |
| 125 | + tree = ast.parse(f.read(), filename=entrypoint_path) |
| 126 | + |
| 127 | + # Find the 'function' definition |
| 128 | + for node in ast.walk(tree): |
| 129 | + if isinstance(node, ast.FunctionDef) and node.name == "function": |
| 130 | + # Get request type (first parameter annotation) |
| 131 | + request_type_name = None |
| 132 | + if node.args.args and len(node.args.args) > 0: |
| 133 | + first_param = node.args.args[0] |
| 134 | + if first_param.annotation: |
| 135 | + request_type_name = _get_type_name_from_ast(first_param.annotation) |
| 136 | + |
| 137 | + # Get response type (return annotation) |
| 138 | + response_type_name = None |
| 139 | + if node.returns: |
| 140 | + response_type_name = _get_type_name_from_ast(node.returns) |
| 141 | + |
| 142 | + return request_type_name, response_type_name |
| 143 | + |
| 144 | + return None, None |
| 145 | + except Exception: |
| 146 | + return None, None |
| 147 | + |
| 148 | + |
| 149 | +def _get_type_name_from_ast(annotation) -> Optional[str]: |
| 150 | + """Extract type name from an AST annotation node.""" |
| 151 | + if isinstance(annotation, ast.Name): |
| 152 | + # Simple type: MyType |
| 153 | + return annotation.id |
| 154 | + elif isinstance(annotation, ast.Attribute): |
| 155 | + # Module.Type - just return the type name |
| 156 | + return annotation.attr |
| 157 | + elif isinstance(annotation, ast.Subscript): |
| 158 | + # Generic type: List[MyType], Optional[MyType] |
| 159 | + # Return the base type name |
| 160 | + return _get_type_name_from_ast(annotation.value) |
| 161 | + return None |
| 162 | + |
| 163 | + |
| 164 | +def _import_pydantic_model(entrypoint_path: str, type_name: str) -> Optional[Any]: |
| 165 | + """Import a Pydantic model by finding its import statement. |
| 166 | +
|
| 167 | + Parses the entrypoint to find where the type is imported from, |
| 168 | + then imports just that module (not the entrypoint itself). |
| 169 | +
|
| 170 | + Args: |
| 171 | + entrypoint_path: Path to entrypoint.py |
| 172 | + type_name: Name of the type to import (e.g., "SearchIndexChunkingV1Request") |
| 173 | +
|
| 174 | + Returns: |
| 175 | + The Pydantic model class, or None if not found |
| 176 | + """ |
| 177 | + try: |
| 178 | + with open(entrypoint_path, 'r') as f: |
| 179 | + tree = ast.parse(f.read(), filename=entrypoint_path) |
| 180 | + |
| 181 | + # Find where this type is imported from |
| 182 | + for node in ast.walk(tree): |
| 183 | + if isinstance(node, ast.ImportFrom): |
| 184 | + # from module import Type1, Type2 |
| 185 | + for alias in node.names: |
| 186 | + if alias.name == type_name: |
| 187 | + # Found it! Import from the module |
| 188 | + module_name = node.module |
| 189 | + if module_name: |
| 190 | + module = importlib.import_module(module_name) |
| 191 | + return getattr(module, type_name, None) |
| 192 | + |
| 193 | + return None |
| 194 | + except Exception: |
| 195 | + return None |
| 196 | + |
| 197 | + |
110 | 198 | def inspect_function_types( |
111 | 199 | entrypoint_path: str, |
112 | 200 | ) -> Tuple[Optional[str], Optional[str]]: |
@@ -230,17 +318,29 @@ def generate_sample_value(field_type, field_name: str): |
230 | 318 | def generate_test_json(entrypoint_path: str, output_path: str) -> None: |
231 | 319 | """Generate a sample test.json file for a function. |
232 | 320 |
|
| 321 | + First tries static AST parsing to get type names, then uses those |
| 322 | + to import only the Pydantic model classes (not the entrypoint). |
| 323 | +
|
233 | 324 | Args: |
234 | 325 | entrypoint_path: Path to the function entrypoint.py |
235 | 326 | output_path: Output path for test.json |
236 | 327 |
|
237 | 328 | Raises: |
238 | | - ImportError: If the module cannot be loaded |
239 | | - AttributeError: If the function is not found |
240 | | - ValueError: If the request type is not a Pydantic model |
| 329 | + ImportError: If the Pydantic model cannot be loaded |
| 330 | + ValueError: If the request type is not found or not a Pydantic model |
241 | 331 | """ |
242 | | - # Get the request type |
243 | | - request_type = get_request_type(entrypoint_path) |
| 332 | + # First, get the type name using static parsing (no imports) |
| 333 | + request_type_name, _ = inspect_function_types_static(entrypoint_path) |
| 334 | + |
| 335 | + if not request_type_name: |
| 336 | + raise ValueError("Could not determine request type from function signature") |
| 337 | + |
| 338 | + # Now try to import the Pydantic model class |
| 339 | + # Look for it in the entrypoint's imports |
| 340 | + request_type = _import_pydantic_model(entrypoint_path, request_type_name) |
| 341 | + |
| 342 | + if not request_type: |
| 343 | + raise ValueError(f"Could not import Pydantic model: {request_type_name}") |
244 | 344 |
|
245 | 345 | # Check if it's a Pydantic model |
246 | 346 | if not hasattr(request_type, "model_fields"): |
|
0 commit comments