Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/sushi/macros/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

@macro()
def add_one(evaluator, column: int):
# typed column will be cast to an int and return an integer back
"""typed column will be cast to an int and return an integer back"""
assert isinstance(column, int)
return column + 1


@macro()
def multiply(evaluator, column, num):
# untyped column will be a sqlglot column and return a sqlglot exp "column > 0"
"""untyped column will be a sqlglot column and return a sqlglot exp "column > 0"""
assert isinstance(column, exp.Column)
return column * num

Expand Down
20 changes: 14 additions & 6 deletions sqlmesh/lsp/completions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import lru_cache
from sqlglot import Dialect, Tokenizer
from sqlmesh.lsp.custom import AllModelsResponse
from sqlmesh.lsp.custom import AllModelsResponse, MacroCompletion
from sqlmesh import macro
import typing as t
from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget
Expand Down Expand Up @@ -60,15 +60,23 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.
return all_models


def get_macros(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
"""Return a set of all macros with the ``@`` prefix."""
names = set(macro.get_registry())
def get_macros(
context: t.Optional[LSPContext], file_uri: t.Optional[URI]
) -> t.List[MacroCompletion]:
"""Return a list of macros with optional descriptions."""
macros: t.Dict[str, t.Optional[str]] = {}

for name, m in macro.get_registry().items():
macros[name] = getattr(m.func, "__doc__", None)

try:
if context is not None:
names.update(context.context._macros)
for name, m in context.context._macros.items():
macros[name] = getattr(m.func, "__doc__", None)
except Exception:
pass
return names

return [MacroCompletion(name=name, description=doc) for name, doc in macros.items()]


def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
Expand Down
9 changes: 8 additions & 1 deletion sqlmesh/lsp/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,21 @@ class AllModelsRequest(CustomMethodRequestBaseClass):
textDocument: types.TextDocumentIdentifier


class MacroCompletion(PydanticModel):
"""Information about a macro for autocompletion."""

name: str
description: t.Optional[str] = None


class AllModelsResponse(CustomMethodResponseBaseClass):
"""
Response to get all the models that are in the current project.
"""

models: t.List[str]
keywords: t.List[str]
macros: t.List[str]
macros: t.List[MacroCompletion]


RENDER_MODEL_FEATURE = "sqlmesh/render_model"
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@ def completion(
and getattr(params.context, "trigger_character", None) == "@"
)

for macro_name in completion_response.macros:
for macro in completion_response.macros:
macro_name = macro.name
insert_text = macro_name if triggered_by_at else f"@{macro_name}"

completion_items.append(
Expand All @@ -636,6 +637,7 @@ def completion(
filter_text=macro_name,
kind=types.CompletionItemKind.Function,
detail="SQLMesh Macro",
documentation=macro.description,
)
)

Expand Down
8 changes: 6 additions & 2 deletions tests/lsp/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ def test_get_macros():
file_uri = URI.from_path(file_path)
completions = LSPContext.get_completions(lsp_context, file_uri, file_content)

assert "each" in completions.macros
assert "add_one" in completions.macros
each_macro = next((m for m in completions.macros if m.name == "each"))
assert each_macro.name == "each"
assert each_macro.description
add_one_macro = next((m for m in completions.macros if m.name == "add_one"))
assert add_one_macro.name == "add_one"
assert add_one_macro.description


def test_get_sql_completions_with_context_no_file_uri():
Expand Down