diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 6716f16e2f..1b1483c271 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -1,13 +1,16 @@ from functools import lru_cache from sqlglot import Dialect, Tokenizer from sqlmesh.lsp.custom import AllModelsResponse +from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget from sqlmesh.lsp.uri import URI def get_sql_completions( - context: t.Optional[LSPContext], file_uri: t.Optional[URI], content: t.Optional[str] = None + context: t.Optional[LSPContext] = None, + file_uri: t.Optional[URI] = None, + content: t.Optional[str] = None, ) -> AllModelsResponse: """ Return a list of completions for a given file. @@ -26,6 +29,7 @@ def get_sql_completions( return AllModelsResponse( models=list(get_models(context, file_uri)), keywords=all_keywords, + macros=list(get_macros(context, file_uri)), ) @@ -56,6 +60,17 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t. return all_models +def get_macros(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: + """Return a set of all macros with the ``@`` prefix.""" + names = set(macro.get_registry()) + try: + if context is not None: + names.update(context.context._macros) + except Exception: + pass + return names + + def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: """ Return a list of sql keywords for a given file. @@ -138,6 +153,7 @@ def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]: """ Extract identifiers from SQL content using the tokenizer. + Only extracts identifiers (variable names, table names, column names, etc.) that are not SQL keywords. """ @@ -155,7 +171,7 @@ def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) keywords.add(token.text) except Exception: - # If tokenization fails, return empty set + # If tokenization fails, return an empty set pass return keywords diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 4ac55f1a22..f3bdcc13e3 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -176,18 +176,13 @@ def list_of_models_for_rendering(self) -> t.List[ModelForRendering]: if audit._path is not None ] - def get_autocomplete( - self, uri: t.Optional[URI], content: t.Optional[str] = None + @staticmethod + def get_completions( + self: t.Optional["LSPContext"] = None, + uri: t.Optional[URI] = None, + file_content: t.Optional[str] = None, ) -> AllModelsResponse: - """Get autocomplete suggestions for a file. - - Args: - uri: The URI of the file to get autocomplete suggestions for. - content: The content of the file (optional). - - Returns: - AllModelsResponse containing models and keywords. - """ + """Get completion suggestions for a file""" from sqlmesh.lsp.completions import get_sql_completions - return get_sql_completions(self, uri, content) + return get_sql_completions(self, uri, file_content) diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index 9e0bc07cd4..72c1ec7917 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -20,6 +20,7 @@ class AllModelsResponse(PydanticModel): models: t.List[str] keywords: t.List[str] + macros: t.List[str] RENDER_MODEL_FEATURE = "sqlmesh/render_model" diff --git a/sqlmesh/lsp/helpers.py b/sqlmesh/lsp/helpers.py new file mode 100644 index 0000000000..7aa06ccb4c --- /dev/null +++ b/sqlmesh/lsp/helpers.py @@ -0,0 +1,27 @@ +from lsprotocol.types import Range, Position + +from sqlmesh.core.linter.helpers import ( + Range as SQLMeshRange, + Position as SQLMeshPosition, +) + + +def to_lsp_range( + range: SQLMeshRange, +) -> Range: + """ + Converts a SQLMesh Range to an LSP Range. + """ + return Range( + start=Position(line=range.start.line, character=range.start.character), + end=Position(line=range.end.line, character=range.end.character), + ) + + +def to_lsp_position( + position: SQLMeshPosition, +) -> Position: + """ + Converts a SQLMesh Position to an LSP Position. + """ + return Position(line=position.line, character=position.character) diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 3b42805920..2295a4b95f 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -145,7 +145,7 @@ def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsRespons try: context = self._context_get_or_load(uri) - return context.get_autocomplete(uri, content) + return LSPContext.get_completions(context, uri, content) except Exception as e: from sqlmesh.lsp.completions import get_sql_completions @@ -565,7 +565,10 @@ def workspace_diagnostic( ) return types.WorkspaceDiagnosticReport(items=[]) - @self.server.feature(types.TEXT_DOCUMENT_COMPLETION) + @self.server.feature( + types.TEXT_DOCUMENT_COMPLETION, + types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros + ) def completion( ls: LanguageServer, params: types.CompletionParams ) -> t.Optional[types.CompletionList]: @@ -583,7 +586,7 @@ def completion( pass # Get completions using the existing completions module - completion_response = context.get_autocomplete(uri, content) + completion_response = LSPContext.get_completions(context, uri, content) completion_items = [] # Add model completions @@ -595,7 +598,26 @@ def completion( detail="SQLMesh Model", ) ) - # Add keyword completions + # Add macro completions + triggered_by_at = ( + params.context is not None + and getattr(params.context, "trigger_character", None) == "@" + ) + + for macro_name in completion_response.macros: + insert_text = macro_name if triggered_by_at else f"@{macro_name}" + + completion_items.append( + types.CompletionItem( + label=f"@{macro_name}", + insert_text=insert_text, + insert_text_format=types.InsertTextFormat.PlainText, + filter_text=macro_name, + kind=types.CompletionItemKind.Function, + detail="SQLMesh Macro", + ) + ) + for keyword in completion_response.keywords: completion_items.append( types.CompletionItem( diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 816449090f..533ef75332 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -7,14 +7,14 @@ from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.linter.helpers import ( TokenPositionDetails, - Range as SQLMeshRange, - Position as SQLMeshPosition, ) from sqlmesh.core.model.definition import SqlModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp from sqlmesh.lsp.description import generate_markdown_description from sqlglot.optimizer.scope import build_scope + +from sqlmesh.lsp.helpers import to_lsp_range, to_lsp_position from sqlmesh.lsp.uri import URI from sqlmesh.utils.pydantic import PydanticModel from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -624,24 +624,3 @@ def _position_within_range(position: Position, range: Range) -> bool: range.end.line > position.line or (range.end.line == position.line and range.end.character >= position.character) ) - - -def to_lsp_range( - range: SQLMeshRange, -) -> Range: - """ - Converts a SQLMesh Range to an LSP Range. - """ - return Range( - start=Position(line=range.start.line, character=range.start.character), - end=Position(line=range.end.line, character=range.end.character), - ) - - -def to_lsp_position( - position: SQLMeshPosition, -) -> Position: - """ - Converts a SQLMesh Position to an LSP Position. - """ - return Position(line=position.line, character=position.character) diff --git a/tests/lsp/test_completions.py b/tests/lsp/test_completions.py index e365873c19..8977d178ba 100644 --- a/tests/lsp/test_completions.py +++ b/tests/lsp/test_completions.py @@ -22,12 +22,27 @@ def test_get_sql_completions_no_context(): assert len(completions.models) == 0 +def test_get_macros(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + file_path = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") + with open(file_path, "r", encoding="utf-8") as f: + file_content = f.read() + + file_uri = URI.from_path(file_path) + completions = LSPContext.get_completions(lsp_context, file_uri, file_content) + + assert "each" in completions.macros + assert "add_one" in completions.macros + + def test_get_sql_completions_with_context_no_file_uri(): context = Context(paths=["examples/sushi"]) lsp_context = LSPContext(context) - completions = lsp_context.get_autocomplete(None) - assert len(completions.keywords) > len(TOKENIZER_KEYWORDS) + completions = LSPContext.get_completions(lsp_context, None) + assert len(completions.keywords) >= len(TOKENIZER_KEYWORDS) assert "sushi.active_customers" in completions.models assert "sushi.customers" in completions.models @@ -37,7 +52,7 @@ def test_get_sql_completions_with_context_and_file_uri(): lsp_context = LSPContext(context) file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") - completions = lsp_context.get_autocomplete(URI.from_path(file_uri)) + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri)) assert len(completions.keywords) > len(TOKENIZER_KEYWORDS) assert "sushi.active_customers" not in completions.models @@ -84,7 +99,7 @@ def test_get_sql_completions_with_file_content(): """ file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") - completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content) + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content) # Check that SQL keywords are included assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords) @@ -129,7 +144,7 @@ def test_get_sql_completions_with_partial_cte_query(): """ file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") - completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content) + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content) # Check that CTE names are included in the keywords keywords_list = completions.keywords diff --git a/vscode/extension/tests/completions.spec.ts b/vscode/extension/tests/completions.spec.ts index f0d167b91f..3c22e388f3 100644 --- a/vscode/extension/tests/completions.spec.ts +++ b/vscode/extension/tests/completions.spec.ts @@ -47,9 +47,118 @@ test('Autocomplete for model names', async () => { expect( await window.locator('text=sushi.waiter_as_customer_by_day').count(), ).toBe(1) + expect(await window.locator('text=SQLMesh Model').count()).toBe(1) await close() } finally { await fs.remove(tempDir) } }) + +// Skip the macro completions test as regular checks because they are flaky and +// covered in other non-integration tests. +test.describe('Macro Completions', () => { + test('Completion for inbuilt macros', async () => { + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-'), + ) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + try { + const { window, close } = await startVSCode(tempDir) + + // Wait for the models folder to be visible + await window.waitForSelector('text=models') + + // Click on the models folder + await window + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await window + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await window.waitForSelector('text=grain') + await window.waitForSelector('text=Loaded SQLMesh Context') + + await window.locator('text=grain').first().click() + + // Move to the end of the file + await window.keyboard.press('Control+End') + + // Add a new line + await window.keyboard.press('Enter') + + await window.waitForTimeout(500) + + // Hit the '@' key to trigger autocomplete for inbuilt macros + await window.keyboard.press('@') + await window.keyboard.type('eac') + + // Wait a moment for autocomplete to appear + await window.waitForTimeout(500) + + // Check if the autocomplete suggestion for inbuilt macros is visible + expect(await window.locator('text=@each').count()).toBe(1) + + await close() + } finally { + await fs.remove(tempDir) + } + }) + + test('Completion for custom macros', async () => { + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-'), + ) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + try { + const { window, close } = await startVSCode(tempDir) + + // Wait for the models folder to be visible + await window.waitForSelector('text=models') + + // Click on the models folder + await window + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await window + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await window.waitForSelector('text=grain') + await window.waitForSelector('text=Loaded SQLMesh Context') + + await window.locator('text=grain').first().click() + + // Move to the end of the file + await window.keyboard.press('Control+End') + + // Add a new line + await window.keyboard.press('Enter') + + // Type the beginning of a macro to trigger autocomplete + await window.keyboard.press('@') + await window.keyboard.type('add_o') + + // Wait a moment for autocomplete to appear + await window.waitForTimeout(500) + + // Check if the autocomplete suggestion for custom macros is visible + expect(await window.locator('text=@add_one').count()).toBe(1) + + await close() + } finally { + await fs.remove(tempDir) + } + }) +})