From fcfa8effd163166daa17cd55c9a999d382cd61b2 Mon Sep 17 00:00:00 2001 From: Ben <9087625+benfdking@users.noreply.github.com> Date: Wed, 11 Jun 2025 10:04:16 +0100 Subject: [PATCH] feat(lsp): include macro details in completions --- examples/sushi/macros/utils.py | 4 ++-- sqlmesh/lsp/completions.py | 20 ++++++++++++++------ sqlmesh/lsp/custom.py | 9 ++++++++- sqlmesh/lsp/main.py | 4 +++- tests/lsp/test_completions.py | 8 ++++++-- 5 files changed, 33 insertions(+), 12 deletions(-) diff --git a/examples/sushi/macros/utils.py b/examples/sushi/macros/utils.py index a76bc3bfe0..fb2ccc21b0 100644 --- a/examples/sushi/macros/utils.py +++ b/examples/sushi/macros/utils.py @@ -5,14 +5,14 @@ @macro() def add_one(evaluator, column: int): - # typed column will be cast to an int and return an integer back + """typed column will be cast to an int and return an integer back""" assert isinstance(column, int) return column + 1 @macro() def multiply(evaluator, column, num): - # untyped column will be a sqlglot column and return a sqlglot exp "column > 0" + """untyped column will be a sqlglot column and return a sqlglot exp "column > 0""" assert isinstance(column, exp.Column) return column * num diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 1b1483c271..7e3781a550 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -1,6 +1,6 @@ from functools import lru_cache from sqlglot import Dialect, Tokenizer -from sqlmesh.lsp.custom import AllModelsResponse +from sqlmesh.lsp.custom import AllModelsResponse, MacroCompletion from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget @@ -60,15 +60,23 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t. return all_models -def get_macros(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: - """Return a set of all macros with the ``@`` prefix.""" - names = set(macro.get_registry()) +def get_macros( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.List[MacroCompletion]: + """Return a list of macros with optional descriptions.""" + macros: t.Dict[str, t.Optional[str]] = {} + + for name, m in macro.get_registry().items(): + macros[name] = getattr(m.func, "__doc__", None) + try: if context is not None: - names.update(context.context._macros) + for name, m in context.context._macros.items(): + macros[name] = getattr(m.func, "__doc__", None) except Exception: pass - return names + + return [MacroCompletion(name=name, description=doc) for name, doc in macros.items()] def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index c432802950..5c8123b7a0 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -23,6 +23,13 @@ class AllModelsRequest(CustomMethodRequestBaseClass): textDocument: types.TextDocumentIdentifier +class MacroCompletion(PydanticModel): + """Information about a macro for autocompletion.""" + + name: str + description: t.Optional[str] = None + + class AllModelsResponse(CustomMethodResponseBaseClass): """ Response to get all the models that are in the current project. @@ -30,7 +37,7 @@ class AllModelsResponse(CustomMethodResponseBaseClass): models: t.List[str] keywords: t.List[str] - macros: t.List[str] + macros: t.List[MacroCompletion] RENDER_MODEL_FEATURE = "sqlmesh/render_model" diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 3247bd5b50..137ca7c07c 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -625,7 +625,8 @@ def completion( and getattr(params.context, "trigger_character", None) == "@" ) - for macro_name in completion_response.macros: + for macro in completion_response.macros: + macro_name = macro.name insert_text = macro_name if triggered_by_at else f"@{macro_name}" completion_items.append( @@ -636,6 +637,7 @@ def completion( filter_text=macro_name, kind=types.CompletionItemKind.Function, detail="SQLMesh Macro", + documentation=macro.description, ) ) diff --git a/tests/lsp/test_completions.py b/tests/lsp/test_completions.py index 8977d178ba..7e193d77d6 100644 --- a/tests/lsp/test_completions.py +++ b/tests/lsp/test_completions.py @@ -33,8 +33,12 @@ def test_get_macros(): file_uri = URI.from_path(file_path) completions = LSPContext.get_completions(lsp_context, file_uri, file_content) - assert "each" in completions.macros - assert "add_one" in completions.macros + each_macro = next((m for m in completions.macros if m.name == "each")) + assert each_macro.name == "each" + assert each_macro.description + add_one_macro = next((m for m in completions.macros if m.name == "add_one")) + assert add_one_macro.name == "add_one" + assert add_one_macro.description def test_get_sql_completions_with_context_no_file_uri():