|
1 | 1 | from lsprotocol.types import Range, Position |
2 | 2 | import typing as t |
| 3 | +from pathlib import Path |
3 | 4 |
|
| 5 | +from sqlmesh.core.audit import StandaloneAudit |
4 | 6 | from sqlmesh.core.dialect import normalize_model_name |
5 | 7 | from sqlmesh.core.model.definition import SqlModel |
6 | 8 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget |
|
10 | 12 | from sqlmesh.lsp.uri import URI |
11 | 13 | from sqlmesh.utils.pydantic import PydanticModel |
12 | 14 | from sqlglot.optimizer.normalize_identifiers import normalize_identifiers |
| 15 | +import ast |
| 16 | +from sqlmesh.core.model import Model |
| 17 | +from sqlmesh import macro |
| 18 | +import inspect |
13 | 19 |
|
14 | 20 |
|
15 | 21 | class Reference(PydanticModel): |
@@ -72,6 +78,11 @@ def get_references( |
72 | 78 | A list of references at the given position |
73 | 79 | """ |
74 | 80 | references = get_model_definitions_for_a_path(lint_context, document_uri) |
| 81 | + |
| 82 | + # Get macro references before filtering by position |
| 83 | + macro_references = get_macro_definitions_for_a_path(lint_context, document_uri) |
| 84 | + references.extend(macro_references) |
| 85 | + |
75 | 86 | filtered_references = list(filter(by_position(position), references)) |
76 | 87 | return filtered_references |
77 | 88 |
|
@@ -290,3 +301,180 @@ def _range_from_token_position_details( |
290 | 301 | start=Position(line=start_line_0, character=start_col_0), |
291 | 302 | end=Position(line=end_line_0, character=end_col_0), |
292 | 303 | ) |
| 304 | + |
| 305 | + |
| 306 | +def get_macro_definitions_for_a_path( |
| 307 | + lsp_context: LSPContext, document_uri: URI |
| 308 | +) -> t.List[Reference]: |
| 309 | + """ |
| 310 | + Get macro references for a given path. |
| 311 | +
|
| 312 | + This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file |
| 313 | + and creates references to their definitions in the Python macro files. |
| 314 | +
|
| 315 | + Args: |
| 316 | + lsp_context: The LSP context containing macro definitions |
| 317 | + document_uri: The URI of the document to search for macro invocations |
| 318 | +
|
| 319 | + Returns: |
| 320 | + A list of Reference objects for each macro invocation found |
| 321 | + """ |
| 322 | + path = document_uri.to_path() |
| 323 | + if path.suffix != ".sql": |
| 324 | + return [] |
| 325 | + |
| 326 | + # Get the file info from the context map |
| 327 | + if path not in lsp_context.map: |
| 328 | + return [] |
| 329 | + |
| 330 | + file_info = lsp_context.map[path] |
| 331 | + # Process based on whether it's a model or standalone audit |
| 332 | + if isinstance(file_info, ModelTarget): |
| 333 | + # It's a model |
| 334 | + target: t.Optional[t.Union[Model, StandaloneAudit]] = lsp_context.context.get_model( |
| 335 | + model_or_snapshot=file_info.names[0], raise_if_missing=False |
| 336 | + ) |
| 337 | + if target is None or not isinstance(target, SqlModel): |
| 338 | + return [] |
| 339 | + query = target.query |
| 340 | + file_path = target._path |
| 341 | + elif isinstance(file_info, AuditTarget): |
| 342 | + # It's a standalone audit |
| 343 | + target = lsp_context.context.standalone_audits.get(file_info.name) |
| 344 | + if target is None: |
| 345 | + return [] |
| 346 | + query = target.query |
| 347 | + file_path = target._path |
| 348 | + else: |
| 349 | + return [] |
| 350 | + |
| 351 | + references = [] |
| 352 | + config_for_model, config_path = lsp_context.context.config_for_path( |
| 353 | + file_path, |
| 354 | + ) |
| 355 | + |
| 356 | + with open(file_path, "r", encoding="utf-8") as file: |
| 357 | + read_file = file.readlines() |
| 358 | + |
| 359 | + for node in query.find_all(exp.Anonymous): |
| 360 | + macro_name = node.name.lower() |
| 361 | + reference = get_macro_reference( |
| 362 | + node=node, |
| 363 | + target=target, |
| 364 | + read_file=read_file, |
| 365 | + config_path=config_path, |
| 366 | + macro_name=macro_name, |
| 367 | + ) |
| 368 | + if reference is not None: |
| 369 | + references.append(reference) |
| 370 | + |
| 371 | + return references |
| 372 | + |
| 373 | + |
| 374 | +def get_macro_reference( |
| 375 | + target: t.Union[Model, StandaloneAudit], |
| 376 | + read_file: t.List[str], |
| 377 | + config_path: t.Optional[Path], |
| 378 | + node: exp.Expression, |
| 379 | + macro_name: str, |
| 380 | +) -> t.Optional[Reference]: |
| 381 | + # Get the file path where the macro is defined |
| 382 | + try: |
| 383 | + # Get the position of the macro invocation in the source file first |
| 384 | + if hasattr(node, "meta") and node.meta: |
| 385 | + token_details = TokenPositionDetails.from_meta(node.meta) |
| 386 | + macro_range = _range_from_token_position_details(token_details, read_file) |
| 387 | + |
| 388 | + # Check if it's a built-in method |
| 389 | + if builtin := get_built_in_macro_reference(macro_name, macro_range): |
| 390 | + return builtin |
| 391 | + else: |
| 392 | + # Skip if we can't get the position |
| 393 | + return None |
| 394 | + |
| 395 | + # Find the macro definition information |
| 396 | + macro_def = target.python_env.get(macro_name) |
| 397 | + if macro_def is None: |
| 398 | + return None |
| 399 | + |
| 400 | + function_name = macro_def.name |
| 401 | + if not function_name: |
| 402 | + return None |
| 403 | + if not macro_def.path: |
| 404 | + return None |
| 405 | + if not config_path: |
| 406 | + return None |
| 407 | + path = Path(config_path).joinpath(macro_def.path) |
| 408 | + |
| 409 | + # Parse the Python file to find the function definition |
| 410 | + with open(path, "r") as f: |
| 411 | + tree = ast.parse(f.read()) |
| 412 | + with open(path, "r") as f: |
| 413 | + output_read_line = f.readlines() |
| 414 | + |
| 415 | + # Find the function definition by name |
| 416 | + start_line = None |
| 417 | + end_line = None |
| 418 | + get_length_of_end_line = None |
| 419 | + docstring = None |
| 420 | + for ast_node in ast.walk(tree): |
| 421 | + if isinstance(ast_node, ast.FunctionDef) and ast_node.name == function_name: |
| 422 | + start_line = ast_node.lineno |
| 423 | + end_line = ast_node.end_lineno |
| 424 | + get_length_of_end_line = ( |
| 425 | + len(output_read_line[end_line - 1]) |
| 426 | + if end_line is not None and end_line - 1 < len(read_file) |
| 427 | + else 0 |
| 428 | + ) |
| 429 | + # Extract docstring if present |
| 430 | + docstring = ast.get_docstring(ast_node) |
| 431 | + break |
| 432 | + |
| 433 | + if start_line is None or end_line is None or get_length_of_end_line is None: |
| 434 | + return None |
| 435 | + |
| 436 | + # Create a reference to the macro definition |
| 437 | + macro_uri = URI.from_path(path) |
| 438 | + |
| 439 | + return Reference( |
| 440 | + uri=macro_uri.value, |
| 441 | + range=macro_range, |
| 442 | + target_range=Range( |
| 443 | + start=Position(line=start_line - 1, character=0), |
| 444 | + end=Position(line=end_line - 1, character=get_length_of_end_line), |
| 445 | + ), |
| 446 | + markdown_description=docstring, |
| 447 | + ) |
| 448 | + except Exception: |
| 449 | + return None |
| 450 | + |
| 451 | + |
| 452 | +def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optional[Reference]: |
| 453 | + """ |
| 454 | + Get a reference to a built-in macro by its name. |
| 455 | +
|
| 456 | + Args: |
| 457 | + macro_name: The name of the built-in macro (e.g., 'each', 'sql_literal') |
| 458 | + macro_range: The range of the macro invocation in the source file |
| 459 | + """ |
| 460 | + built_in_macros = macro.get_registry() |
| 461 | + built_in_macro = built_in_macros.get(macro_name) |
| 462 | + if built_in_macro is None: |
| 463 | + return None |
| 464 | + |
| 465 | + func = built_in_macro.func |
| 466 | + filename = inspect.getfile(func) |
| 467 | + source_lines, line_number = inspect.getsourcelines(func) |
| 468 | + |
| 469 | + # Calculate the end line number by counting the number of source lines |
| 470 | + end_line_number = line_number + len(source_lines) - 1 |
| 471 | + |
| 472 | + return Reference( |
| 473 | + uri=URI.from_path(Path(filename)).value, |
| 474 | + range=macro_range, |
| 475 | + target_range=Range( |
| 476 | + start=Position(line=line_number - 1, character=0), |
| 477 | + end=Position(line=end_line_number - 1, character=0), |
| 478 | + ), |
| 479 | + markdown_description=func.__doc__ if func.__doc__ else None, |
| 480 | + ) |
0 commit comments