Skip to content

Commit 88d9a6f

Browse files
committed
feat(vscode): gtd for macros
1 parent 1fb40ee commit 88d9a6f

9 files changed

Lines changed: 185 additions & 8 deletions

File tree

examples/multi/repo_1/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ gateways:
1010
connection:
1111
type: duckdb
1212

13-
default_gateway: local
13+
default_gateway: memory
1414

1515

1616
before_all:

examples/multi/repo_2/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ gateways:
1010
connection:
1111
type: duckdb
1212

13-
default_gateway: local
13+
default_gateway: memory
1414

1515

1616
before_all:

examples/sushi/macros/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def sql_literal(
2626
column_str: str,
2727
column_quoted: str,
2828
):
29+
"""A macro that accepts various types of SQL literals and returns the column."""
2930
assert isinstance(column, str)
3031
assert isinstance(str_lit, str)
3132
assert str_lit == "'x'"

sqlmesh/core/context.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -932,14 +932,15 @@ def get_snapshot(
932932

933933
return snapshot
934934

935-
def config_for_path(self, path: Path) -> Config:
935+
def config_for_path(self, path: Path) -> t.Tuple[Config, Path]:
936+
"""Returns the config and path of the said project for a given file path."""
936937
for config_path, config in self.configs.items():
937938
try:
938939
path.relative_to(config_path)
939-
return config
940+
return config, config_path
940941
except ValueError:
941942
pass
942-
return self.config
943+
return self.config, self.path
943944

944945
def config_for_node(self, node: str | Model | Audit) -> Config:
945946
if isinstance(node, str):

sqlmesh/lsp/reference.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from lsprotocol.types import Range, Position
22
import typing as t
3+
from pathlib import Path
34

5+
from sqlmesh.core.audit import StandaloneAudit
46
from sqlmesh.core.dialect import normalize_model_name
57
from sqlmesh.core.model.definition import SqlModel
68
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
@@ -10,6 +12,8 @@
1012
from sqlmesh.lsp.uri import URI
1113
from sqlmesh.utils.pydantic import PydanticModel
1214
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
15+
import ast
16+
from sqlmesh.core.model import Model
1317

1418

1519
class Reference(PydanticModel):
@@ -72,6 +76,11 @@ def get_references(
7276
A list of references at the given position
7377
"""
7478
references = get_model_definitions_for_a_path(lint_context, document_uri)
79+
80+
# Get macro references before filtering by position
81+
macro_references = get_macro_definitions_for_a_path(lint_context, document_uri)
82+
references.extend(macro_references)
83+
7584
filtered_references = list(filter(by_position(position), references))
7685
return filtered_references
7786

@@ -287,3 +296,118 @@ def _range_from_token_position_details(
287296
start=Position(line=start_line_0, character=start_col_0),
288297
end=Position(line=end_line_0, character=end_col_0),
289298
)
299+
300+
301+
def get_macro_definitions_for_a_path(
302+
lsp_context: LSPContext, document_uri: URI
303+
) -> t.List[Reference]:
304+
"""
305+
Get macro references for a given path.
306+
307+
This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file
308+
and creates references to their definitions in the Python macro files.
309+
310+
Args:
311+
lsp_context: The LSP context containing macro definitions
312+
document_uri: The URI of the document to search for macro invocations
313+
314+
Returns:
315+
A list of Reference objects for each macro invocation found
316+
"""
317+
path = document_uri.to_path()
318+
if path.suffix != ".sql":
319+
return []
320+
321+
# Get the file info from the context map
322+
if path not in lsp_context.map:
323+
return []
324+
325+
file_info = lsp_context.map[path]
326+
# Process based on whether it's a model or standalone audit
327+
if isinstance(file_info, ModelTarget):
328+
# It's a model
329+
target: t.Optional[t.Union[Model, StandaloneAudit]] = lsp_context.context.get_model(
330+
model_or_snapshot=file_info.names[0], raise_if_missing=False
331+
)
332+
if target is None or not isinstance(target, SqlModel):
333+
return []
334+
query = target.query
335+
file_path = target._path
336+
elif isinstance(file_info, AuditTarget):
337+
# It's a standalone audit
338+
target = lsp_context.context.standalone_audits.get(file_info.name)
339+
if target is None:
340+
return []
341+
query = target.query
342+
file_path = target._path
343+
else:
344+
return []
345+
346+
references = []
347+
config_for_model, config_path = lsp_context.context.config_for_path(
348+
file_path,
349+
)
350+
351+
with open(file_path, "r", encoding="utf-8") as file:
352+
read_file = file.readlines()
353+
354+
for node in query.find_all(exp.Anonymous):
355+
macro_name = node.name.lower()
356+
357+
# Find the macro definition information
358+
macro_def = target.python_env.get(macro_name)
359+
if macro_def is None:
360+
continue
361+
362+
# Get the file path where the macro is defined
363+
try:
364+
function_name = macro_def.name
365+
if not function_name:
366+
continue
367+
if not macro_def.path:
368+
continue
369+
path = Path(config_path).joinpath(macro_def.path)
370+
371+
# Parse the Python file to find the function definition
372+
with open(path, "r") as f:
373+
tree = ast.parse(f.read())
374+
375+
# Find the function definition by name
376+
start_line = None
377+
end_line = None
378+
docstring = None
379+
for ast_node in ast.walk(tree):
380+
if isinstance(ast_node, ast.FunctionDef) and ast_node.name == function_name:
381+
start_line = ast_node.lineno
382+
end_line = ast_node.end_lineno
383+
# Extract docstring if present
384+
docstring = ast.get_docstring(ast_node)
385+
break
386+
387+
if start_line is None or end_line is None:
388+
continue
389+
390+
# Create a reference to the macro definition
391+
macro_uri = URI.from_path(path)
392+
393+
# Get the position of the macro invocation in the source file
394+
if hasattr(node, "meta") and node.meta:
395+
token_details = TokenPositionDetails.from_meta(node.meta)
396+
macro_range = _range_from_token_position_details(token_details, read_file)
397+
398+
references.append(
399+
Reference(
400+
uri=macro_uri.value,
401+
range=macro_range,
402+
target_range=Range(
403+
start=Position(line=start_line - 1, character=0),
404+
end=Position(line=end_line - 1, character=0),
405+
),
406+
markdown_description=docstring,
407+
)
408+
)
409+
except (OSError, TypeError):
410+
# If we can't get the source file, skip this macro
411+
continue
412+
413+
return references

tests/lsp/test_reference_macro.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
5+
from sqlmesh.lsp.uri import URI
6+
7+
8+
@pytest.mark.fast
9+
def test_macro_references() -> None:
10+
"""Test that macro references (e.g., @ADD_ONE, @MULTIPLY) have proper go-to-definition support."""
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find the top_waiters model that uses macros
15+
top_waiters_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
19+
)
20+
21+
top_waiters_uri = URI.from_path(top_waiters_path)
22+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
23+
24+
# We expect 3 macro references: @ADD_ONE, @MULTIPLY, @SQL_LITERAL
25+
assert len(macro_references) == 3
26+
27+
# Check that all references point to the utils.py file
28+
for ref in macro_references:
29+
assert ref.uri.endswith("macros/utils.py")
30+
assert ref.target_range is not None # Should have target range for go-to-definition
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
5+
from sqlmesh.lsp.uri import URI
6+
7+
8+
@pytest.mark.fast
9+
def test_macro_references_multirepo() -> None:
10+
context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"])
11+
lsp_context = LSPContext(context)
12+
13+
d_path = next(
14+
path
15+
for path, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "silver.d" in info.names
17+
)
18+
19+
d = URI.from_path(d_path)
20+
macro_references = get_macro_definitions_for_a_path(lsp_context, d)
21+
22+
assert len(macro_references) == 2

web/server/api/endpoints/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def write_file(
6060
replace_file(settings.project_path / path, settings.project_path / path_or_new_path)
6161
else:
6262
full_path = settings.project_path / path
63-
config = context.config_for_path(Path(path_or_new_path)) if context else None
63+
config, _ = context.config_for_path(Path(path_or_new_path)) if context else (None, None)
6464
if (
6565
config
6666
and config.ui.format_on_save

web/server/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@ def validate_path(path: str, settings: Settings = Depends(get_settings)) -> str:
6464
if any(
6565
full_path.match(pattern)
6666
for pattern in (
67-
context.config_for_path(Path(path)).ignore_patterns if context else c.IGNORE_PATTERNS
67+
context.config_for_path(Path(path))[0].ignore_patterns if context else c.IGNORE_PATTERNS
6868
)
6969
):
7070
raise HTTPException(status_code=HTTP_404_NOT_FOUND)
71-
7271
return path
7372

7473

0 commit comments

Comments
 (0)