Skip to content

Commit 71c842f

Browse files
committed
feat(vscode): gtd for macros
1 parent fa43912 commit 71c842f

9 files changed

Lines changed: 293 additions & 12 deletions

File tree

examples/sushi/macros/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def sql_literal(
2626
column_str: str,
2727
column_quoted: str,
2828
):
29+
"""A macro that accepts various types of SQL literals and returns the column."""
2930
assert isinstance(column, str)
3031
assert isinstance(str_lit, str)
3132
assert str_lit == "'x'"

sqlmesh/core/context.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -932,19 +932,22 @@ def get_snapshot(
932932

933933
return snapshot
934934

935-
def config_for_path(self, path: Path) -> Config:
935+
def config_for_path(self, path: Path) -> t.Tuple[Config, Path]:
936+
"""Returns the config and path of the said project for a given file path."""
936937
for config_path, config in self.configs.items():
937938
try:
938939
path.relative_to(config_path)
939-
return config
940+
return config, config_path
940941
except ValueError:
941942
pass
942-
return self.config
943+
return self.config, self.path
943944

944945
def config_for_node(self, node: str | Model | Audit) -> Config:
945946
if isinstance(node, str):
946-
return self.config_for_path(self.get_snapshot(node, raise_if_missing=True).node._path) # type: ignore
947-
return self.config_for_path(node._path) # type: ignore
947+
return self.config_for_path(self.get_snapshot(node, raise_if_missing=True).node._path)[
948+
0
949+
] # type: ignore
950+
return self.config_for_path(node._path)[0] # type: ignore
948951

949952
@property
950953
def models(self) -> MappingProxyType[str, Model]:

sqlmesh/lsp/reference.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from lsprotocol.types import Range, Position
22
import typing as t
3+
from pathlib import Path
34

5+
from sqlmesh.core.audit import StandaloneAudit
46
from sqlmesh.core.dialect import normalize_model_name
57
from sqlmesh.core.model.definition import SqlModel
68
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
@@ -10,6 +12,10 @@
1012
from sqlmesh.lsp.uri import URI
1113
from sqlmesh.utils.pydantic import PydanticModel
1214
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
15+
import ast
16+
from sqlmesh.core.model import Model
17+
from sqlmesh import macro
18+
import inspect
1319

1420

1521
class Reference(PydanticModel):
@@ -72,6 +78,11 @@ def get_references(
7278
A list of references at the given position
7379
"""
7480
references = get_model_definitions_for_a_path(lint_context, document_uri)
81+
82+
# Get macro references before filtering by position
83+
macro_references = get_macro_definitions_for_a_path(lint_context, document_uri)
84+
references.extend(macro_references)
85+
7586
filtered_references = list(filter(by_position(position), references))
7687
return filtered_references
7788

@@ -290,3 +301,180 @@ def _range_from_token_position_details(
290301
start=Position(line=start_line_0, character=start_col_0),
291302
end=Position(line=end_line_0, character=end_col_0),
292303
)
304+
305+
306+
def get_macro_definitions_for_a_path(
307+
lsp_context: LSPContext, document_uri: URI
308+
) -> t.List[Reference]:
309+
"""
310+
Get macro references for a given path.
311+
312+
This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file
313+
and creates references to their definitions in the Python macro files.
314+
315+
Args:
316+
lsp_context: The LSP context containing macro definitions
317+
document_uri: The URI of the document to search for macro invocations
318+
319+
Returns:
320+
A list of Reference objects for each macro invocation found
321+
"""
322+
path = document_uri.to_path()
323+
if path.suffix != ".sql":
324+
return []
325+
326+
# Get the file info from the context map
327+
if path not in lsp_context.map:
328+
return []
329+
330+
file_info = lsp_context.map[path]
331+
# Process based on whether it's a model or standalone audit
332+
if isinstance(file_info, ModelTarget):
333+
# It's a model
334+
target: t.Optional[t.Union[Model, StandaloneAudit]] = lsp_context.context.get_model(
335+
model_or_snapshot=file_info.names[0], raise_if_missing=False
336+
)
337+
if target is None or not isinstance(target, SqlModel):
338+
return []
339+
query = target.query
340+
file_path = target._path
341+
elif isinstance(file_info, AuditTarget):
342+
# It's a standalone audit
343+
target = lsp_context.context.standalone_audits.get(file_info.name)
344+
if target is None:
345+
return []
346+
query = target.query
347+
file_path = target._path
348+
else:
349+
return []
350+
351+
references = []
352+
config_for_model, config_path = lsp_context.context.config_for_path(
353+
file_path,
354+
)
355+
356+
with open(file_path, "r", encoding="utf-8") as file:
357+
read_file = file.readlines()
358+
359+
for node in query.find_all(exp.Anonymous):
360+
macro_name = node.name.lower()
361+
reference = get_macro_reference(
362+
node=node,
363+
target=target,
364+
read_file=read_file,
365+
config_path=config_path,
366+
macro_name=macro_name,
367+
)
368+
if reference is not None:
369+
references.append(reference)
370+
371+
return references
372+
373+
374+
def get_macro_reference(
375+
target: t.Union[Model, StandaloneAudit],
376+
read_file: t.List[str],
377+
config_path: t.Optional[Path],
378+
node: exp.Expression,
379+
macro_name: str,
380+
) -> t.Optional[Reference]:
381+
# Get the file path where the macro is defined
382+
try:
383+
# Get the position of the macro invocation in the source file first
384+
if hasattr(node, "meta") and node.meta:
385+
token_details = TokenPositionDetails.from_meta(node.meta)
386+
macro_range = _range_from_token_position_details(token_details, read_file)
387+
388+
# Check if it's a built-in method
389+
if builtin := get_built_in_macro_reference(macro_name, macro_range):
390+
return builtin
391+
else:
392+
# Skip if we can't get the position
393+
return None
394+
395+
# Find the macro definition information
396+
macro_def = target.python_env.get(macro_name)
397+
if macro_def is None:
398+
return None
399+
400+
function_name = macro_def.name
401+
if not function_name:
402+
return None
403+
if not macro_def.path:
404+
return None
405+
if not config_path:
406+
return None
407+
path = Path(config_path).joinpath(macro_def.path)
408+
409+
# Parse the Python file to find the function definition
410+
with open(path, "r") as f:
411+
tree = ast.parse(f.read())
412+
with open(path, "r") as f:
413+
output_read_line = f.readlines()
414+
415+
# Find the function definition by name
416+
start_line = None
417+
end_line = None
418+
get_length_of_end_line = None
419+
docstring = None
420+
for ast_node in ast.walk(tree):
421+
if isinstance(ast_node, ast.FunctionDef) and ast_node.name == function_name:
422+
start_line = ast_node.lineno
423+
end_line = ast_node.end_lineno
424+
get_length_of_end_line = (
425+
len(output_read_line[end_line - 1])
426+
if end_line is not None and end_line - 1 < len(read_file)
427+
else 0
428+
)
429+
# Extract docstring if present
430+
docstring = ast.get_docstring(ast_node)
431+
break
432+
433+
if start_line is None or end_line is None or get_length_of_end_line is None:
434+
return None
435+
436+
# Create a reference to the macro definition
437+
macro_uri = URI.from_path(path)
438+
439+
return Reference(
440+
uri=macro_uri.value,
441+
range=macro_range,
442+
target_range=Range(
443+
start=Position(line=start_line - 1, character=0),
444+
end=Position(line=end_line - 1, character=get_length_of_end_line),
445+
),
446+
markdown_description=docstring,
447+
)
448+
except Exception:
449+
return None
450+
451+
452+
def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optional[Reference]:
453+
"""
454+
Get a reference to a built-in macro by its name.
455+
456+
Args:
457+
macro_name: The name of the built-in macro (e.g., 'each', 'sql_literal')
458+
macro_range: The range of the macro invocation in the source file
459+
"""
460+
built_in_macros = macro.get_registry()
461+
built_in_macro = built_in_macros.get(macro_name)
462+
if built_in_macro is None:
463+
return None
464+
465+
func = built_in_macro.func
466+
filename = inspect.getfile(func)
467+
source_lines, line_number = inspect.getsourcelines(func)
468+
469+
# Calculate the end line number by counting the number of source lines
470+
end_line_number = line_number + len(source_lines) - 1
471+
472+
return Reference(
473+
uri=URI.from_path(Path(filename)).value,
474+
range=macro_range,
475+
target_range=Range(
476+
start=Position(line=line_number - 1, character=0),
477+
end=Position(line=end_line_number - 1, character=0),
478+
),
479+
markdown_description=func.__doc__ if func.__doc__ else None,
480+
)

tests/lsp/test_reference_macro.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
5+
from sqlmesh.lsp.uri import URI
6+
7+
8+
@pytest.mark.fast
9+
def test_macro_references() -> None:
10+
"""Test that macro references (e.g., @ADD_ONE, @MULTIPLY) have proper go-to-definition support."""
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find the top_waiters model that uses macros
15+
top_waiters_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names
19+
)
20+
21+
top_waiters_uri = URI.from_path(top_waiters_path)
22+
macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri)
23+
24+
# We expect 3 macro references: @ADD_ONE, @MULTIPLY, @SQL_LITERAL
25+
assert len(macro_references) == 3
26+
27+
# Check that all references point to the utils.py file
28+
for ref in macro_references:
29+
assert ref.uri.endswith("sushi/macros/utils.py")
30+
assert ref.target_range is not None
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_macro_definitions_for_a_path
5+
from sqlmesh.lsp.uri import URI
6+
7+
8+
@pytest.mark.fast
9+
def test_macro_references_multirepo() -> None:
10+
context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"])
11+
lsp_context = LSPContext(context)
12+
13+
d_path = next(
14+
path
15+
for path, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "silver.d" in info.names
17+
)
18+
19+
d = URI.from_path(d_path)
20+
macro_references = get_macro_definitions_for_a_path(lsp_context, d)
21+
22+
assert len(macro_references) == 2
23+
for ref in macro_references:
24+
assert ref.uri.endswith("multi/repo_2/macros/__init__.py")
25+
assert ref.target_range is not None
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { test, expect } from '@playwright/test';
2+
import path from 'path';
3+
import fs from 'fs-extra';
4+
import os from 'os';
5+
import { startVSCode, SUSHI_SOURCE_PATH } from './utils';
6+
7+
test('Go to definition for macro', async () => {
8+
const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-'));
9+
await fs.copy(SUSHI_SOURCE_PATH, tempDir);
10+
11+
try {
12+
const { window, close } = await startVSCode(tempDir);
13+
14+
// Wait for the models folder to be visible
15+
await window.waitForSelector('text=models');
16+
17+
// Click on the models folder, excluding external_models
18+
await window.getByRole('treeitem', { name: 'models', exact: true }).locator('a').click();
19+
20+
// Open the customer_revenue_lifetime model
21+
await window.getByRole('treeitem', { name: 'top_waiters.sql', exact: true }).locator('a').click();
22+
23+
await window.waitForSelector('text=grain');
24+
await window.waitForSelector('text=Loaded SQLMesh Context')
25+
26+
// Render the model
27+
window.locator("text=@MULTIPLY").click({
28+
modifiers: ["Meta"]
29+
})
30+
31+
// Check if the model is rendered by check if "`oi`.`order_id` AS `order_id`," is in the window
32+
await expect(window.locator('text=def multiply(')).toBeVisible();
33+
34+
await close();
35+
} finally {
36+
await fs.removeSync(tempDir);
37+
}
38+
});

vscode/extension/tests/utils.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,8 @@ export const startVSCode = async (workspaceDir: string): Promise<{
4040
args,
4141
});
4242
const window = await electronApp.firstWindow();
43-
await window.waitForLoadState('domcontentloaded');
44-
await window.waitForLoadState('networkidle');
45-
await window.waitForTimeout(2_000);
4643
return { window, close: async () => {
4744
await electronApp.close();
48-
await fs.remove(userDataDir);
45+
await fs.removeSync(userDataDir);
4946
} };
5047
}

web/server/api/endpoints/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def write_file(
6060
replace_file(settings.project_path / path, settings.project_path / path_or_new_path)
6161
else:
6262
full_path = settings.project_path / path
63-
config = context.config_for_path(Path(path_or_new_path)) if context else None
63+
config, _ = context.config_for_path(Path(path_or_new_path)) if context else (None, None)
6464
if (
6565
config
6666
and config.ui.format_on_save

web/server/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@ def validate_path(path: str, settings: Settings = Depends(get_settings)) -> str:
6464
if any(
6565
full_path.match(pattern)
6666
for pattern in (
67-
context.config_for_path(Path(path)).ignore_patterns if context else c.IGNORE_PATTERNS
67+
context.config_for_path(Path(path))[0].ignore_patterns if context else c.IGNORE_PATTERNS
6868
)
6969
):
7070
raise HTTPException(status_code=HTTP_404_NOT_FOUND)
71-
7271
return path
7372

7473

0 commit comments

Comments
 (0)