Skip to content

Commit 0217019

Browse files
committed
feat(vscode): gtd for macros
[ci skip]
1 parent f4fa53f commit 0217019

2 files changed

Lines changed: 227 additions & 0 deletions

File tree

sqlmesh/lsp/reference.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from lsprotocol.types import Range, Position
22
import typing as t
3+
from pathlib import Path
34

45
from sqlmesh.core.dialect import normalize_model_name
56
from sqlmesh.core.model.definition import SqlModel
@@ -72,6 +73,11 @@ def get_references(
7273
A list of references at the given position
7374
"""
7475
references = get_model_definitions_for_a_path(lint_context, document_uri)
76+
77+
# Get macro references before filtering by position
78+
macro_references = get_macro_definitions_for_a_path(lint_context, document_uri)
79+
references.extend(macro_references)
80+
7581
filtered_references = list(filter(by_position(position), references))
7682
return filtered_references
7783

@@ -287,3 +293,100 @@ def _range_from_token_position_details(
287293
start=Position(line=start_line_0, character=start_col_0),
288294
end=Position(line=end_line_0, character=end_col_0),
289295
)
296+
297+
298+
def get_macro_definitions_for_a_path(
299+
lint_context: LSPContext, document_uri: URI
300+
) -> t.List[Reference]:
301+
"""
302+
Get macro references for a given path.
303+
304+
This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file
305+
and creates references to their definitions in the Python macro files.
306+
307+
Args:
308+
lint_context: The LSP context containing macro definitions
309+
document_uri: The URI of the document to search for macro invocations
310+
311+
Returns:
312+
A list of Reference objects for each macro invocation found
313+
"""
314+
path = document_uri.to_path()
315+
if path.suffix != ".sql":
316+
return []
317+
318+
# Get the file info from the context map
319+
if path not in lint_context.map:
320+
return []
321+
322+
file_info = lint_context.map[path]
323+
# Process based on whether it's a model or standalone audit
324+
if isinstance(file_info, ModelTarget):
325+
# It's a model
326+
model = lint_context.context.get_model(
327+
model_or_snapshot=file_info.names[0], raise_if_missing=False
328+
)
329+
if model is None or not isinstance(model, SqlModel):
330+
return []
331+
query = model.query
332+
file_path = model._path
333+
elif isinstance(file_info, AuditTarget):
334+
# It's a standalone audit
335+
audit = lint_context.context.standalone_audits.get(file_info.name)
336+
if audit is None:
337+
return []
338+
query = audit.query
339+
file_path = audit._path
340+
else:
341+
return []
342+
343+
references = []
344+
345+
with open(file_path, "r", encoding="utf-8") as file:
346+
read_file = file.readlines()
347+
348+
# Find all macro function calls in the query
349+
from sqlmesh.core.macros import MacroFunc
350+
351+
for node in query.find_all(MacroFunc):
352+
macro_name = node.name.lower()
353+
354+
# Find the macro definition in the context
355+
macro_def = lint_context.context._macros.get(macro_name)
356+
if macro_def is None:
357+
continue
358+
359+
# Get the file path where the macro is defined
360+
import inspect
361+
try:
362+
source_file = inspect.getsourcefile(macro_def)
363+
if source_file is None:
364+
continue
365+
366+
# Get the line number where the macro is defined
367+
source_lines, start_line = inspect.getsourcelines(macro_def)
368+
369+
# Create a reference to the macro definition
370+
macro_uri = URI.from_path(Path(source_file))
371+
372+
# Get the position of the macro invocation in the source file
373+
if hasattr(node, 'meta') and node.meta:
374+
token_details = TokenPositionDetails.from_meta(node.meta)
375+
macro_range = _range_from_token_position_details(token_details, read_file)
376+
377+
references.append(
378+
Reference(
379+
uri=macro_uri.value,
380+
range=macro_range,
381+
target_range=Range(
382+
start=Position(line=start_line - 1, character=0),
383+
end=Position(line=start_line - 1, character=0)
384+
),
385+
markdown_description=f"Macro: `@{macro_name.upper()}`"
386+
)
387+
)
388+
except (OSError, TypeError):
389+
# If we can't get the source file, skip this macro
390+
continue
391+
392+
return references

tests/lsp/test_macro_reference.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import pytest
2+
from lsprotocol.types import Position
3+
from pathlib import Path
4+
from sqlmesh.core.context import Context
5+
from sqlmesh.lsp.context import LSPContext, ModelTarget
6+
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path, get_references, by_position
7+
from sqlmesh.lsp.uri import URI
8+
9+
10+
@pytest.mark.fast
11+
def test_macro_references() -> None:
12+
"""Test that macro references (e.g., @ADD_ONE, @MULTIPLY) have proper go-to-definition support."""
13+
context = Context(paths=["examples/sushi"])
14+
lsp_context = LSPContext(context)
15+
16+
# Find the top_waiters model that uses macros
17+
top_waiters_path = next(
18+
path
19+
for path, info in lsp_context.map.items()
20+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
21+
)
22+
23+
top_waiters_uri = URI.from_path(top_waiters_path)
24+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
25+
26+
# We expect 3 macro references: @ADD_ONE, @MULTIPLY, @SQL_LITERAL
27+
assert len(macro_references) == 3
28+
29+
# Check that all references point to the utils.py file
30+
for ref in macro_references:
31+
assert ref.uri.endswith("macros/utils.py")
32+
assert ref.target_range is not None # Should have target range for go-to-definition
33+
34+
# Read the SQL file to verify macro positions
35+
with open(top_waiters_path, "r") as file:
36+
content = file.read()
37+
38+
# Verify that the references are at the expected macro invocations
39+
assert "@ADD_ONE" in content
40+
assert "@MULTIPLY" in content
41+
assert "@SQL_LITERAL" in content
42+
43+
44+
@pytest.mark.fast
45+
def test_macro_go_to_definition_with_position() -> None:
46+
"""Test go-to-definition for specific macro positions."""
47+
context = Context(paths=["examples/sushi"])
48+
lsp_context = LSPContext(context)
49+
50+
# Find the top_waiters model
51+
top_waiters_path = next(
52+
path
53+
for path, info in lsp_context.map.items()
54+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
55+
)
56+
57+
# Read the file to find exact positions
58+
with open(top_waiters_path, "r") as file:
59+
lines = file.readlines()
60+
61+
# Find line with @ADD_ONE
62+
add_one_line = None
63+
add_one_char = None
64+
for i, line in enumerate(lines):
65+
if "@ADD_ONE" in line:
66+
add_one_line = i
67+
add_one_char = line.index("@ADD_ONE") + 4 # Position inside the macro name
68+
break
69+
70+
assert add_one_line is not None
71+
72+
# Test go-to-definition at the @ADD_ONE position
73+
position = Position(line=add_one_line, character=add_one_char)
74+
references = get_references(lsp_context, URI.from_path(top_waiters_path), position)
75+
76+
# Should find exactly one reference
77+
assert len(references) == 1
78+
assert references[0].uri.endswith("macros/utils.py")
79+
assert references[0].target_range is not None
80+
81+
# The target should be at the line where add_one is defined
82+
# (approximately line 6 based on the utils.py file)
83+
assert references[0].target_range.start.line >= 5 # Zero-indexed, so line 6 is index 5
84+
85+
86+
@pytest.mark.fast
87+
def test_all_three_macros_positions() -> None:
88+
"""Test that all three macros (@ADD_ONE, @MULTIPLY, @SQL_LITERAL) work with position-based lookup."""
89+
context = Context(paths=["examples/sushi"])
90+
lsp_context = LSPContext(context)
91+
92+
# Find the top_waiters model
93+
top_waiters_path = next(
94+
path
95+
for path, info in lsp_context.map.items()
96+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
97+
)
98+
99+
top_waiters_uri = URI.from_path(top_waiters_path)
100+
101+
# Get all references (models and macros)
102+
with open(top_waiters_path, "r") as file:
103+
lines = file.readlines()
104+
105+
macro_names = ["@ADD_ONE", "@MULTIPLY", "@SQL_LITERAL"]
106+
found_macros = []
107+
108+
for macro_name in macro_names:
109+
for i, line in enumerate(lines):
110+
if macro_name in line:
111+
char_pos = line.index(macro_name) + len(macro_name) // 2
112+
position = Position(line=i, character=char_pos)
113+
refs = get_references(lsp_context, top_waiters_uri, position)
114+
115+
# Filter to only macro references
116+
macro_refs = [r for r in refs if r.uri.endswith("utils.py")]
117+
if macro_refs:
118+
found_macros.append(macro_name)
119+
assert len(macro_refs) == 1
120+
assert macro_refs[0].markdown_description == f"Macro: `{macro_name}`"
121+
122+
# Should find all three macros
123+
assert len(found_macros) == 3
124+
assert set(found_macros) == set(macro_names)

0 commit comments

Comments
 (0)