|
1 | 1 | from lsprotocol.types import Range, Position |
2 | 2 | import typing as t |
| 3 | +from pathlib import Path |
3 | 4 |
|
| 5 | +from sqlmesh.core.audit import StandaloneAudit |
4 | 6 | from sqlmesh.core.dialect import normalize_model_name |
5 | 7 | from sqlmesh.core.model.definition import SqlModel |
6 | 8 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget |
|
10 | 12 | from sqlmesh.lsp.uri import URI |
11 | 13 | from sqlmesh.utils.pydantic import PydanticModel |
12 | 14 | from sqlglot.optimizer.normalize_identifiers import normalize_identifiers |
| 15 | +import ast |
| 16 | +from sqlmesh.core.model import Model |
13 | 17 |
|
14 | 18 |
|
15 | 19 | class Reference(PydanticModel): |
@@ -72,6 +76,11 @@ def get_references( |
72 | 76 | A list of references at the given position |
73 | 77 | """ |
74 | 78 | references = get_model_definitions_for_a_path(lint_context, document_uri) |
| 79 | + |
| 80 | + # Get macro references before filtering by position |
| 81 | + macro_references = get_macro_definitions_for_a_path(lint_context, document_uri) |
| 82 | + references.extend(macro_references) |
| 83 | + |
75 | 84 | filtered_references = list(filter(by_position(position), references)) |
76 | 85 | return filtered_references |
77 | 86 |
|
@@ -287,3 +296,118 @@ def _range_from_token_position_details( |
287 | 296 | start=Position(line=start_line_0, character=start_col_0), |
288 | 297 | end=Position(line=end_line_0, character=end_col_0), |
289 | 298 | ) |
| 299 | + |
| 300 | + |
| 301 | +def get_macro_definitions_for_a_path( |
| 302 | + lsp_context: LSPContext, document_uri: URI |
| 303 | +) -> t.List[Reference]: |
| 304 | + """ |
| 305 | + Get macro references for a given path. |
| 306 | +
|
| 307 | + This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file |
| 308 | + and creates references to their definitions in the Python macro files. |
| 309 | +
|
| 310 | + Args: |
| 311 | + lsp_context: The LSP context containing macro definitions |
| 312 | + document_uri: The URI of the document to search for macro invocations |
| 313 | +
|
| 314 | + Returns: |
| 315 | + A list of Reference objects for each macro invocation found |
| 316 | + """ |
| 317 | + path = document_uri.to_path() |
| 318 | + if path.suffix != ".sql": |
| 319 | + return [] |
| 320 | + |
| 321 | + # Get the file info from the context map |
| 322 | + if path not in lsp_context.map: |
| 323 | + return [] |
| 324 | + |
| 325 | + file_info = lsp_context.map[path] |
| 326 | + # Process based on whether it's a model or standalone audit |
| 327 | + if isinstance(file_info, ModelTarget): |
| 328 | + # It's a model |
| 329 | + target: t.Optional[t.Union[Model, StandaloneAudit]] = lsp_context.context.get_model( |
| 330 | + model_or_snapshot=file_info.names[0], raise_if_missing=False |
| 331 | + ) |
| 332 | + if target is None or not isinstance(target, SqlModel): |
| 333 | + return [] |
| 334 | + query = target.query |
| 335 | + file_path = target._path |
| 336 | + elif isinstance(file_info, AuditTarget): |
| 337 | + # It's a standalone audit |
| 338 | + target = lsp_context.context.standalone_audits.get(file_info.name) |
| 339 | + if target is None: |
| 340 | + return [] |
| 341 | + query = target.query |
| 342 | + file_path = target._path |
| 343 | + else: |
| 344 | + return [] |
| 345 | + |
| 346 | + references = [] |
| 347 | + config_for_model, config_path = lsp_context.context.config_for_path( |
| 348 | + file_path, |
| 349 | + ) |
| 350 | + |
| 351 | + with open(file_path, "r", encoding="utf-8") as file: |
| 352 | + read_file = file.readlines() |
| 353 | + |
| 354 | + for node in query.find_all(exp.Anonymous): |
| 355 | + macro_name = node.name.lower() |
| 356 | + |
| 357 | + # Find the macro definition information |
| 358 | + macro_def = target.python_env.get(macro_name) |
| 359 | + if macro_def is None: |
| 360 | + continue |
| 361 | + |
| 362 | + # Get the file path where the macro is defined |
| 363 | + try: |
| 364 | + function_name = macro_def.name |
| 365 | + if not function_name: |
| 366 | + continue |
| 367 | + if not macro_def.path: |
| 368 | + continue |
| 369 | + path = Path(config_path).joinpath(macro_def.path) |
| 370 | + |
| 371 | + # Parse the Python file to find the function definition |
| 372 | + with open(path, "r") as f: |
| 373 | + tree = ast.parse(f.read()) |
| 374 | + |
| 375 | + # Find the function definition by name |
| 376 | + start_line = None |
| 377 | + end_line = None |
| 378 | + docstring = None |
| 379 | + for ast_node in ast.walk(tree): |
| 380 | + if isinstance(ast_node, ast.FunctionDef) and ast_node.name == function_name: |
| 381 | + start_line = ast_node.lineno |
| 382 | + end_line = ast_node.end_lineno |
| 383 | + # Extract docstring if present |
| 384 | + docstring = ast.get_docstring(ast_node) |
| 385 | + break |
| 386 | + |
| 387 | + if start_line is None or end_line is None: |
| 388 | + continue |
| 389 | + |
| 390 | + # Create a reference to the macro definition |
| 391 | + macro_uri = URI.from_path(path) |
| 392 | + |
| 393 | + # Get the position of the macro invocation in the source file |
| 394 | + if hasattr(node, "meta") and node.meta: |
| 395 | + token_details = TokenPositionDetails.from_meta(node.meta) |
| 396 | + macro_range = _range_from_token_position_details(token_details, read_file) |
| 397 | + |
| 398 | + references.append( |
| 399 | + Reference( |
| 400 | + uri=macro_uri.value, |
| 401 | + range=macro_range, |
| 402 | + target_range=Range( |
| 403 | + start=Position(line=start_line - 1, character=0), |
| 404 | + end=Position(line=end_line - 1, character=0), |
| 405 | + ), |
| 406 | + markdown_description=docstring, |
| 407 | + ) |
| 408 | + ) |
| 409 | + except (OSError, TypeError): |
| 410 | + # If we can't get the source file, skip this macro |
| 411 | + continue |
| 412 | + |
| 413 | + return references |
0 commit comments