Skip to content

Commit 6b9cbf3

Browse files
committed
feat(lsp): add keywords in query to autocomplete
1 parent 8edfae4 commit 6b9cbf3

4 files changed

Lines changed: 222 additions & 8 deletions

File tree

sqlmesh/lsp/completions.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,25 @@
77

88

99
def get_sql_completions(
    context: t.Optional[LSPContext], file_uri: t.Optional[URI], content: t.Optional[str] = None
) -> AllModelsResponse:
    """
    Return the completions (models and keywords) for a given file.

    The keyword list starts with the keywords returned by ``get_keywords`` and is
    followed by any additional tokens extracted from ``content`` (when it is
    provided and non-empty) that are not already part of that set.
    """
    dialect_keywords = get_keywords(context, file_uri)

    # Only tokenize the buffer when the caller actually supplied some text.
    if content:
        extra_keywords = extract_keywords_from_content(content, get_dialect(context, file_uri))
    else:
        extra_keywords = set()

    # Dialect keywords first, then the file-specific ones not already present.
    combined_keywords = [*dialect_keywords, *(extra_keywords - dialect_keywords)]

    return AllModelsResponse(
        models=list(get_models(context, file_uri)),
        keywords=combined_keywords,
    )
1930

2031

@@ -97,3 +108,54 @@ def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]:
97108
parts = keyword.split(" ")
98109
expanded_keywords.update(parts)
99110
return expanded_keywords
111+
112+
113+
def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Optional[str]:
    """
    Get the SQL dialect for a given file.

    Resolution order:
      1. the dialect of the model or standalone audit that the file maps to,
         when the file is known to the context and that dialect is set;
      2. the default dialect of the context;
      3. ``None`` when no context is available.

    Args:
        context: The LSP context, or None if it could not be loaded.
        file_uri: The URI of the file, or None.

    Returns:
        The dialect name, or None when there is no context to resolve it from.
    """
    if context is None:
        return None

    if file_uri is not None:
        path = file_uri.to_path()
        if path in context.map:
            file_info = context.map[path]

            # Files that define models: use the dialect of the first model.
            if isinstance(file_info, ModelTarget) and file_info.names:
                model = context.context.get_model(file_info.names[0])
                # The lookup may come back empty; fall through to the default
                # dialect instead of failing on a missing model.
                if model is not None and model.dialect:
                    return model.dialect

            # Files that define a standalone audit.
            if isinstance(file_info, AuditTarget) and file_info.name:
                audit = context.context.standalone_audits.get(file_info.name)
                if audit is not None and audit.dialect:
                    return audit.dialect

    return context.context.default_dialect
136+
137+
138+
def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]:
    """
    Extract identifiers from SQL content using the dialect's tokenizer.

    Only identifier-like tokens (variable names, table names, column names,
    aliases, quoted identifiers) whose text is not a SQL keyword are returned.
    Punctuation, operators, string literals and numbers are skipped so they do
    not show up as completion candidates.

    Args:
        content: The SQL text to scan. May be an incomplete query.
        dialect: The dialect whose tokenizer should be used.

    Returns:
        The set of identifier texts found, or an empty set when ``content`` is
        empty or cannot be tokenized.
    """
    if not content:
        return set()

    # Imported locally so the block does not depend on a module-level import.
    from sqlglot.tokens import TokenType

    identifier_types = {TokenType.VAR, TokenType.IDENTIFIER}
    tokenizer_class = Dialect.get_or_raise(dialect).tokenizer_class

    try:
        tokens = tokenizer_class().tokenize(content)
    except Exception:
        # Best effort: the buffer is frequently mid-edit (e.g. an unterminated
        # string), and a tokenizer failure must not break autocomplete.
        return set()

    return {
        token.text
        for token in tokens
        if token.token_type in identifier_types
        and token.text.upper() not in tokenizer_class.KEYWORDS
    }

sqlmesh/lsp/context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,18 @@ def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
176176
if audit._path is not None
177177
]
178178

179-
def get_autocomplete(
    self, uri: t.Optional[URI], content: t.Optional[str] = None
) -> AllModelsResponse:
    """Get autocomplete suggestions for a file.

    Args:
        uri: The URI of the file to get autocomplete suggestions for.
        content: The current text of the file (optional). It is forwarded to
            ``get_sql_completions`` together with this context and the URI.

    Returns:
        AllModelsResponse containing models and keywords.
    """
    # Imported at call time rather than at module level
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from sqlmesh.lsp.completions import get_sql_completions

    return get_sql_completions(context=self, file_uri=uri, content=content)

sqlmesh/lsp/main.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,22 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
114114
@self.server.feature(ALL_MODELS_FEATURE)
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
    """Return models and keywords for the requested document.

    The document text is read from the workspace when available and passed
    along so completions can include tokens from the file itself. If the
    context cannot be loaded or queried, completions are computed without it.
    """
    document_uri = params.textDocument.uri
    uri = URI(document_uri)

    # Best effort: completions still work without the document text.
    try:
        content = ls.workspace.get_text_document(document_uri).source
    except Exception:
        content = None

    try:
        return self._context_get_or_load(uri).get_autocomplete(uri, content)
    except Exception:
        # Fall back to completions that do not need a loaded context.
        from sqlmesh.lsp.completions import get_sql_completions

        return get_sql_completions(None, URI(document_uri), content)
124133

125134
@self.server.feature(RENDER_MODEL_FEATURE)
126135
def render_model(ls: LanguageServer, params: RenderModelRequest) -> RenderModelResponse:
@@ -471,8 +480,16 @@ def completion(
471480
uri = URI(params.text_document.uri)
472481
context = self._context_get_or_load(uri)
473482

483+
# Get the document content
484+
content = None
485+
try:
486+
document = ls.workspace.get_text_document(params.text_document.uri)
487+
content = document.source
488+
except Exception:
489+
pass
490+
474491
# Get completions using the existing completions module
475-
completion_response = context.get_autocomplete(uri)
492+
completion_response = context.get_autocomplete(uri, content)
476493

477494
completion_items = []
478495
# Add model completions

tests/lsp/test_completions.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import pytest
22
from sqlglot import Tokenizer
33
from sqlmesh.core.context import Context
4-
from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions
4+
from sqlmesh.lsp.completions import (
5+
get_keywords_from_tokenizer,
6+
get_sql_completions,
7+
extract_keywords_from_content,
8+
)
59
from sqlmesh.lsp.context import LSPContext
610
from sqlmesh.lsp.uri import URI
711

@@ -41,3 +45,131 @@ def test_get_sql_completions_with_context_and_file_uri():
4145
completions = lsp_context.get_autocomplete(URI.from_path(file_uri))
4246
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
4347
assert "sushi.active_customers" not in completions.models
48+
49+
50+
@pytest.mark.fast
def test_extract_keywords_from_content():
    """Identifiers and aliases are extracted from a query; SQL keywords are not."""
    sql = """
SELECT customer_id, order_date, total_amount
FROM orders o
JOIN customers c ON o.customer_id = c.id
WHERE order_date > '2024-01-01'
"""

    extracted = extract_keywords_from_content(sql)

    # Columns, tables and aliases ("o", "c") must all be picked up.
    expected_identifiers = (
        "customer_id",
        "order_date",
        "total_amount",
        "orders",
        "customers",
        "o",
        "c",
        "id",
    )
    for identifier in expected_identifiers:
        assert identifier in extracted

    # SQL keywords must be filtered out.
    for sql_keyword in ("SELECT", "FROM", "JOIN", "WHERE", "ON"):
        assert sql_keyword not in extracted
79+
80+
@pytest.mark.fast
def test_get_sql_completions_with_file_content():
    """Completions contain SQL keywords first, followed by identifiers from the file."""
    context = Context(paths=["examples/sushi"])
    lsp_context = LSPContext(context)

    # SQL content with custom identifiers
    content = """
SELECT my_custom_column, another_identifier
FROM my_custom_table mct
JOIN some_other_table sot ON mct.id = sot.table_id
WHERE my_custom_column > 100
"""

    file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql")
    completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content)
    keywords_list = completions.keywords

    # Check that SQL keywords are included
    sql_keywords = ["SELECT", "FROM", "WHERE", "JOIN"]
    assert any(k in sql_keywords for k in keywords_list)

    # Check that file-specific identifiers (including aliases) are included
    for identifier in (
        "my_custom_column",
        "another_identifier",
        "my_custom_table",
        "some_other_table",
        "mct",
        "sot",
        "table_id",
    ):
        assert identifier in keywords_list

    # Check that file keywords come after SQL keywords
    sql_keyword_indices = [i for i, k in enumerate(keywords_list) if k in sql_keywords]
    file_keyword_indices = [
        i for i, k in enumerate(keywords_list) if k in ["my_custom_column", "my_custom_table"]
    ]

    # Assert both groups are present so the ordering check can never be
    # skipped silently.
    assert sql_keyword_indices and file_keyword_indices
    assert max(sql_keyword_indices) < min(file_keyword_indices), (
        "SQL keywords should come before file keywords"
    )
122+
123+
124+
@pytest.mark.fast
def test_get_sql_completions_with_partial_cte_query():
    """CTE names and identifiers from an unfinished query show up in completions."""
    lsp_context = LSPContext(Context(paths=["examples/sushi"]))

    # Partial SQL query with CTEs
    partial_sql = """
WITH _latest_complete_month AS (
SELECT MAX(date_trunc('month', order_date)) as month
FROM orders
),
_filtered AS (
SELECT * FROM
"""

    model_path = next(path for path in lsp_context.map.keys() if path.name == "active_customers.sql")
    suggested = lsp_context.get_autocomplete(URI.from_path(model_path), partial_sql).keywords

    # CTE names must be offered ...
    for cte_name in ("_latest_complete_month", "_filtered"):
        assert cte_name in suggested

    # ... as well as the other identifiers from the partial query.
    for identifier in ("month", "order_date", "orders"):
        assert identifier in suggested
151+
152+
153+
@pytest.mark.fast
def test_extract_keywords_from_partial_query():
    """Identifiers are still extracted when the SQL query is incomplete."""
    partial_sql = """
WITH cte1 AS (
SELECT col1, col2 FROM table1
),
cte2 AS (
SELECT * FROM cte1 WHERE
"""

    extracted = extract_keywords_from_content(partial_sql)

    # CTE names
    for cte_name in ("cte1", "cte2"):
        assert cte_name in extracted

    # Columns and tables
    for identifier in ("col1", "col2", "table1"):
        assert identifier in extracted

0 commit comments

Comments
 (0)