Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions sqlmesh/lsp/completions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from functools import lru_cache
from sqlglot import Dialect, Tokenizer
from sqlmesh.lsp.custom import AllModelsResponse
from sqlmesh import macro
import typing as t
from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget
from sqlmesh.lsp.uri import URI


def get_sql_completions(
context: t.Optional[LSPContext], file_uri: t.Optional[URI], content: t.Optional[str] = None
context: t.Optional[LSPContext] = None,
file_uri: t.Optional[URI] = None,
content: t.Optional[str] = None,
) -> AllModelsResponse:
"""
Return a list of completions for a given file.
Expand All @@ -26,6 +29,7 @@ def get_sql_completions(
return AllModelsResponse(
models=list(get_models(context, file_uri)),
keywords=all_keywords,
macros=list(get_macros(context, file_uri)),
)


Expand Down Expand Up @@ -56,6 +60,17 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.
return all_models


def get_macros(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
"""Return a set of all macros with the ``@`` prefix."""
names = set(macro.get_registry())
Comment thread
benfdking marked this conversation as resolved.
try:
if context is not None:
names.update(context.context._macros)
except Exception:
pass
return names


def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]:
"""
Return a list of sql keywords for a given file.
Expand Down Expand Up @@ -138,6 +153,7 @@ def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t
def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]:
"""
Extract identifiers from SQL content using the tokenizer.

Only extracts identifiers (variable names, table names, column names, etc.)
that are not SQL keywords.
"""
Expand All @@ -155,7 +171,7 @@ def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None)
keywords.add(token.text)

except Exception:
# If tokenization fails, return empty set
# If tokenization fails, return an empty set
pass

return keywords
19 changes: 7 additions & 12 deletions sqlmesh/lsp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,13 @@ def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
if audit._path is not None
]

def get_autocomplete(
self, uri: t.Optional[URI], content: t.Optional[str] = None
@staticmethod
def get_completions(
self: t.Optional["LSPContext"] = None,
uri: t.Optional[URI] = None,
file_content: t.Optional[str] = None,
) -> AllModelsResponse:
"""Get autocomplete suggestions for a file.

Args:
uri: The URI of the file to get autocomplete suggestions for.
content: The content of the file (optional).

Returns:
AllModelsResponse containing models and keywords.
"""
"""Get completion suggestions for a file"""
from sqlmesh.lsp.completions import get_sql_completions

return get_sql_completions(self, uri, content)
return get_sql_completions(self, uri, file_content)
1 change: 1 addition & 0 deletions sqlmesh/lsp/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AllModelsResponse(PydanticModel):

models: t.List[str]
keywords: t.List[str]
macros: t.List[str]


RENDER_MODEL_FEATURE = "sqlmesh/render_model"
Expand Down
27 changes: 27 additions & 0 deletions sqlmesh/lsp/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from lsprotocol.types import Range, Position

from sqlmesh.core.linter.helpers import (
Range as SQLMeshRange,
Position as SQLMeshPosition,
)


def to_lsp_range(
range: SQLMeshRange,
) -> Range:
"""
Converts a SQLMesh Range to an LSP Range.
"""
return Range(
start=Position(line=range.start.line, character=range.start.character),
end=Position(line=range.end.line, character=range.end.character),
)


def to_lsp_position(
position: SQLMeshPosition,
) -> Position:
"""
Converts a SQLMesh Position to an LSP Position.
"""
return Position(line=position.line, character=position.character)
30 changes: 26 additions & 4 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsRespons

try:
context = self._context_get_or_load(uri)
return context.get_autocomplete(uri, content)
return LSPContext.get_completions(context, uri, content)
except Exception as e:
from sqlmesh.lsp.completions import get_sql_completions

Expand Down Expand Up @@ -565,7 +565,10 @@ def workspace_diagnostic(
)
return types.WorkspaceDiagnosticReport(items=[])

@self.server.feature(types.TEXT_DOCUMENT_COMPLETION)
@self.server.feature(
types.TEXT_DOCUMENT_COMPLETION,
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros
)
def completion(
ls: LanguageServer, params: types.CompletionParams
) -> t.Optional[types.CompletionList]:
Expand All @@ -583,7 +586,7 @@ def completion(
pass

# Get completions using the existing completions module
completion_response = context.get_autocomplete(uri, content)
completion_response = LSPContext.get_completions(context, uri, content)

completion_items = []
# Add model completions
Expand All @@ -595,7 +598,26 @@ def completion(
detail="SQLMesh Model",
)
)
# Add keyword completions
# Add macro completions
triggered_by_at = (
params.context is not None
and getattr(params.context, "trigger_character", None) == "@"
)

for macro_name in completion_response.macros:
insert_text = macro_name if triggered_by_at else f"@{macro_name}"

completion_items.append(
types.CompletionItem(
label=f"@{macro_name}",
insert_text=insert_text,
insert_text_format=types.InsertTextFormat.PlainText,
filter_text=macro_name,
kind=types.CompletionItemKind.Function,
detail="SQLMesh Macro",
)
)

for keyword in completion_response.keywords:
completion_items.append(
types.CompletionItem(
Expand Down
25 changes: 2 additions & 23 deletions sqlmesh/lsp/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from sqlmesh.core.dialect import normalize_model_name
from sqlmesh.core.linter.helpers import (
TokenPositionDetails,
Range as SQLMeshRange,
Position as SQLMeshPosition,
)
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
from sqlglot import exp
from sqlmesh.lsp.description import generate_markdown_description
from sqlglot.optimizer.scope import build_scope

from sqlmesh.lsp.helpers import to_lsp_range, to_lsp_position
from sqlmesh.lsp.uri import URI
from sqlmesh.utils.pydantic import PydanticModel
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
Expand Down Expand Up @@ -624,24 +624,3 @@ def _position_within_range(position: Position, range: Range) -> bool:
range.end.line > position.line
or (range.end.line == position.line and range.end.character >= position.character)
)


def to_lsp_range(
range: SQLMeshRange,
) -> Range:
"""
Converts a SQLMesh Range to an LSP Range.
"""
return Range(
start=Position(line=range.start.line, character=range.start.character),
end=Position(line=range.end.line, character=range.end.character),
)


def to_lsp_position(
position: SQLMeshPosition,
) -> Position:
"""
Converts a SQLMesh Position to an LSP Position.
"""
return Position(line=position.line, character=position.character)
25 changes: 20 additions & 5 deletions tests/lsp/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,27 @@ def test_get_sql_completions_no_context():
assert len(completions.models) == 0


def test_get_macros():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

file_path = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
with open(file_path, "r", encoding="utf-8") as f:
file_content = f.read()

file_uri = URI.from_path(file_path)
completions = LSPContext.get_completions(lsp_context, file_uri, file_content)

assert "each" in completions.macros
assert "add_one" in completions.macros


def test_get_sql_completions_with_context_no_file_uri():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

completions = lsp_context.get_autocomplete(None)
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
completions = LSPContext.get_completions(lsp_context, None)
assert len(completions.keywords) >= len(TOKENIZER_KEYWORDS)
assert "sushi.active_customers" in completions.models
assert "sushi.customers" in completions.models

Expand All @@ -37,7 +52,7 @@ def test_get_sql_completions_with_context_and_file_uri():
lsp_context = LSPContext(context)

file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
completions = lsp_context.get_autocomplete(URI.from_path(file_uri))
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri))
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
assert "sushi.active_customers" not in completions.models

Expand Down Expand Up @@ -84,7 +99,7 @@ def test_get_sql_completions_with_file_content():
"""

file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content)

# Check that SQL keywords are included
assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords)
Expand Down Expand Up @@ -129,7 +144,7 @@ def test_get_sql_completions_with_partial_cte_query():
"""

file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content)

# Check that CTE names are included in the keywords
keywords_list = completions.keywords
Expand Down
109 changes: 109 additions & 0 deletions vscode/extension/tests/completions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,118 @@ test('Autocomplete for model names', async () => {
expect(
await window.locator('text=sushi.waiter_as_customer_by_day').count(),
).toBe(1)
expect(await window.locator('text=SQLMesh Model').count()).toBe(1)

await close()
} finally {
await fs.remove(tempDir)
}
})

// Skip the macro completions test as regular checks because they are flaky and
// covered in other non-integration tests.
test.describe('Macro Completions', () => {
test('Completion for inbuilt macros', async () => {
const tempDir = await fs.mkdtemp(
path.join(os.tmpdir(), 'vscode-test-sushi-'),
)
await fs.copy(SUSHI_SOURCE_PATH, tempDir)

try {
const { window, close } = await startVSCode(tempDir)

// Wait for the models folder to be visible
await window.waitForSelector('text=models')

// Click on the models folder
await window
.getByRole('treeitem', { name: 'models', exact: true })
.locator('a')
.click()

// Open the top_waiters model
await window
.getByRole('treeitem', { name: 'customers.sql', exact: true })
.locator('a')
.click()

await window.waitForSelector('text=grain')
await window.waitForSelector('text=Loaded SQLMesh Context')

await window.locator('text=grain').first().click()

// Move to the end of the file
await window.keyboard.press('Control+End')

// Add a new line
await window.keyboard.press('Enter')

await window.waitForTimeout(500)

// Hit the '@' key to trigger autocomplete for inbuilt macros
await window.keyboard.press('@')
await window.keyboard.type('eac')

// Wait a moment for autocomplete to appear
await window.waitForTimeout(500)

// Check if the autocomplete suggestion for inbuilt macros is visible
expect(await window.locator('text=@each').count()).toBe(1)

await close()
} finally {
await fs.remove(tempDir)
}
})

test('Completion for custom macros', async () => {
const tempDir = await fs.mkdtemp(
path.join(os.tmpdir(), 'vscode-test-sushi-'),
)
await fs.copy(SUSHI_SOURCE_PATH, tempDir)

try {
const { window, close } = await startVSCode(tempDir)

// Wait for the models folder to be visible
await window.waitForSelector('text=models')

// Click on the models folder
await window
.getByRole('treeitem', { name: 'models', exact: true })
.locator('a')
.click()

// Open the top_waiters model
await window
.getByRole('treeitem', { name: 'customers.sql', exact: true })
.locator('a')
.click()

await window.waitForSelector('text=grain')
await window.waitForSelector('text=Loaded SQLMesh Context')

await window.locator('text=grain').first().click()

// Move to the end of the file
await window.keyboard.press('Control+End')

// Add a new line
await window.keyboard.press('Enter')

// Type the beginning of a macro to trigger autocomplete
await window.keyboard.press('@')
await window.keyboard.type('add_o')

// Wait a moment for autocomplete to appear
await window.waitForTimeout(500)

// Check if the autocomplete suggestion for custom macros is visible
expect(await window.locator('text=@add_one').count()).toBe(1)

await close()
} finally {
await fs.remove(tempDir)
}
})
})