Skip to content

Commit 5a5ff94

Browse files
committed
feat: add model autocomplete to vscode
- includes dialect keywords - includes models
1 parent 61e63f2 commit 5a5ff94

10 files changed

Lines changed: 238 additions & 39 deletions

File tree

examples/sushi/models/latest_order.sql

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,4 @@ MODEL (
1111

1212
SELECT id, customer_id, start_ts, end_ts, event_date
1313
FROM sushi.orders
14-
ORDER BY event_date DESC LIMIT 1
15-
14+
ORDER BY event_date DESC LIMIT 1

sqlmesh/lsp/completions.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from functools import lru_cache
2+
from sqlglot import Tokenizer
3+
from sqlmesh.lsp.custom import AllModelsResponse
4+
import typing as t
5+
from sqlmesh.lsp.context import LSPContext
6+
7+
8+
def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllModelsResponse:
9+
"""
10+
Return a list of completions for a given file.
11+
"""
12+
return AllModelsResponse(
13+
models=list(get_models(context, file_uri)),
14+
keywords=list(get_keywords(context, file_uri)),
15+
)
16+
17+
18+
def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
19+
"""
20+
Return a list of models for a given file.
21+
22+
If there is no context, return an empty list.
23+
If there is a context, return a list of all models bar the ones the file itself defines.
24+
"""
25+
if context is None:
26+
return set()
27+
all_models = set(model for models in context.map.values() for model in models)
28+
if file_uri is not None:
29+
models_file_refers_to = context.map[file_uri]
30+
for model in models_file_refers_to:
31+
all_models.discard(model)
32+
return all_models
33+
34+
35+
def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
36+
"""
37+
Return a list of sql keywords for a given file.
38+
If no context is provided, return ANSI SQL keywords.
39+
40+
If a context is provided but no file_uri is provided, returns the keywords
41+
for the default dialect of the context.
42+
43+
If both a context and a file_uri are provided, returns the keywords
44+
for the dialect of the model that the file belongs to.
45+
"""
46+
return get_keywords_from_tokenizer()
47+
48+
49+
@lru_cache(maxsize=1)
50+
def get_keywords_from_tokenizer() -> t.Set[str]:
51+
keywords = list(Tokenizer.KEYWORDS.keys())
52+
expanded_keywords = []
53+
54+
for keyword in keywords:
55+
parts = keyword.split(" ")
56+
expanded_keywords.extend(parts)
57+
58+
return set(expanded_keywords)

sqlmesh/lsp/custom.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from lsprotocol import types
2+
import typing as t
3+
from sqlmesh.utils.pydantic import PydanticModel
4+
5+
ALL_MODELS_FEATURE = "sqlmesh/all_models"
6+
7+
8+
class AllModelsRequest(PydanticModel):
9+
"""
10+
Request to get all the models that are in the current project.
11+
"""
12+
13+
textDocument: types.TextDocumentIdentifier
14+
15+
16+
class AllModelsResponse(PydanticModel):
17+
"""
18+
Response to get all the models that are in the current project.
19+
"""
20+
21+
models: t.List[str]
22+
keywords: t.List[str]

sqlmesh/lsp/main.py

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from sqlmesh._version import __version__
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
14+
from sqlmesh.lsp.completions import get_sql_completions
1415
from sqlmesh.lsp.context import LSPContext
16+
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1517
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
1618

1719

@@ -38,6 +40,14 @@ def __init__(
3840
def _register_features(self) -> None:
3941
"""Register LSP features on the internal LanguageServer instance."""
4042

43+
@self.server.feature(ALL_MODELS_FEATURE)
44+
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
45+
try:
46+
context = self._context_get_or_load(params.textDocument.uri)
47+
return get_sql_completions(context, params.textDocument.uri)
48+
except Exception as e:
49+
return get_sql_completions(None, params.textDocument.uri)
50+
4151
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
4252
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
4353
context = self._context_get_or_load(params.text_document.uri)
@@ -130,43 +140,6 @@ def formatting(
130140
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
131141
return []
132142

133-
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
134-
def goto_definition(
135-
ls: LanguageServer, params: types.DefinitionParams
136-
) -> t.List[types.LocationLink]:
137-
"""Jump to an object's definition."""
138-
try:
139-
self._ensure_context_for_document(params.text_document.uri)
140-
document = ls.workspace.get_document(params.text_document.uri)
141-
if self.lsp_context is None:
142-
raise RuntimeError(f"No context found for document: {document.path}")
143-
144-
references = get_model_definitions_for_a_path(
145-
self.lsp_context, params.text_document.uri
146-
)
147-
if len(references) == 0:
148-
return []
149-
150-
return [
151-
types.LocationLink(
152-
target_uri=reference.uri,
153-
target_selection_range=types.Range(
154-
start=types.Position(line=0, character=0),
155-
end=types.Position(line=0, character=0),
156-
),
157-
target_range=types.Range(
158-
start=types.Position(line=0, character=0),
159-
end=types.Position(line=0, character=0),
160-
),
161-
origin_selection_range=reference.range,
162-
)
163-
for reference in references
164-
]
165-
166-
except Exception as e:
167-
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
168-
return []
169-
170143
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
171144
def goto_definition(
172145
ls: LanguageServer, params: types.DefinitionParams

tests/lsp/test_completions.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
from sqlglot import Tokenizer
3+
from sqlmesh.core.context import Context
4+
from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions
5+
from sqlmesh.lsp.context import LSPContext
6+
7+
8+
TOKENIZER_KEYWORDS = set(Tokenizer.KEYWORDS.keys())
9+
10+
11+
@pytest.mark.fast
12+
def test_get_keywords_from_tokenizer():
13+
assert len(get_keywords_from_tokenizer()) > len(TOKENIZER_KEYWORDS)
14+
15+
16+
@pytest.mark.fast
17+
def test_get_sql_completions_no_context():
18+
completions = get_sql_completions(None, None)
19+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
20+
assert len(completions.models) == 0
21+
22+
23+
@pytest.mark.fast
24+
def test_get_sql_completions_with_context_no_file_uri():
25+
context = Context(paths=["examples/sushi"])
26+
lsp_context = LSPContext(context)
27+
28+
completions = get_sql_completions(lsp_context, None)
29+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
30+
assert "sushi.active_customers" in completions.models
31+
assert "sushi.customers" in completions.models
32+
33+
34+
@pytest.mark.fast
35+
def test_get_sql_completions_with_context_and_file_uri():
36+
context = Context(paths=["examples/sushi"])
37+
lsp_context = LSPContext(context)
38+
39+
file_uri = next(
40+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
41+
)
42+
completions = get_sql_completions(lsp_context, file_uri)
43+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
44+
assert "sushi.active_customers" not in completions.models

vscode/extension/package.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
"command": "sqlmesh.signout",
5858
"title": "Sign out from Tobiko Cloud",
5959
"description": "SQLMesh"
60+
},
61+
{
62+
"command": "sqlmesh.allModels",
63+
"title": "Get all models",
64+
"description": "SQLMesh"
6065
}
6166
]
6267
},
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import * as vscode from 'vscode'
2+
import { LSPClient } from '../lsp/lsp'
3+
import { isErr } from '../utilities/functional/result'
4+
5+
export const selector: vscode.DocumentSelector = {
6+
pattern: '**/*.sql',
7+
}
8+
9+
export const completionProvider = (
10+
lsp: LSPClient,
11+
): vscode.CompletionItemProvider => {
12+
return {
13+
async provideCompletionItems(document) {
14+
const result = await lsp.call_custom_method('sqlmesh/all_models', {
15+
textDocument: {
16+
uri: document.uri.fsPath,
17+
},
18+
})
19+
if (isErr(result)) {
20+
return []
21+
}
22+
const modelCompletions = result.value.models.map(
23+
model =>
24+
new vscode.CompletionItem(model, vscode.CompletionItemKind.Reference),
25+
)
26+
const keywordCompletions = result.value.keywords.map(
27+
keyword =>
28+
new vscode.CompletionItem(keyword, vscode.CompletionItemKind.Keyword),
29+
)
30+
return new vscode.CompletionList([
31+
...modelCompletions,
32+
...keywordCompletions,
33+
])
34+
},
35+
}
36+
}

vscode/extension/src/extension.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import {
1818
handleSqlmeshLspNotFoundError,
1919
handleSqlmeshLspDependenciesMissingError,
2020
} from './utilities/errors'
21+
import { completionProvider } from './completion/completion'
22+
import { selector } from './completion/completion'
2123

2224
let lspClient: LSPClient | undefined
2325

@@ -82,6 +84,13 @@ export async function activate(context: vscode.ExtensionContext) {
8284
context.subscriptions.push(lspClient)
8385
}
8486

87+
context.subscriptions.push(
88+
vscode.languages.registerCompletionItemProvider(
89+
selector,
90+
completionProvider(lspClient),
91+
),
92+
)
93+
8594
const restart = async () => {
8695
if (lspClient) {
8796
traceVerbose('Restarting LSP client')

vscode/extension/src/lsp/custom.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
export interface AllModelsMethod {
2+
method: 'sqlmesh/all_models'
3+
request: AllModelsRequest
4+
response: AllModelsResponse
5+
}
6+
7+
export interface TestMethod {
8+
method: 'sqlmesh/test'
9+
request: TestRequest
10+
response: TestResponse
11+
}
12+
13+
// @eslint-disable-next-line @typescript-eslint/consistent-type-definition
14+
export type CustomLSPMethods = AllModelsMethod | TestMethod
15+
16+
interface AllModelsRequest {
17+
textDocument: {
18+
uri: string
19+
}
20+
}
21+
22+
interface AllModelsResponse {
23+
models: string[]
24+
keywords: string[]
25+
}
26+
27+
interface TestRequest {
28+
foo: string
29+
}
30+
31+
interface TestResponse {
32+
bar: string
33+
}

vscode/extension/src/lsp/lsp.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { err, isErr, ok, Result } from '../utilities/functional/result'
1010
import { getWorkspaceFolders } from '../utilities/common/vscodeapi'
1111
import { traceError } from '../utilities/common/log'
1212
import { ErrorType } from '../utilities/errors'
13+
import { CustomLSPMethods } from './custom'
1314

1415
let outputChannel: OutputChannel | undefined
1516

@@ -98,4 +99,23 @@ export class LSPClient implements Disposable {
9899
public async dispose() {
99100
await this.stop()
100101
}
102+
103+
public async call_custom_method<
104+
Method extends CustomLSPMethods['method'],
105+
Request extends Extract<CustomLSPMethods, { method: Method }>['request'],
106+
Response extends Extract<CustomLSPMethods, { method: Method }>['response'],
107+
>(method: Method, request: Request): Promise<Result<Response, string>> {
108+
if (!this.client) {
109+
return err('lsp client not ready')
110+
}
111+
try {
112+
const result = await this.client.sendRequest<Response>(method, request)
113+
return ok(result)
114+
} catch (error) {
115+
traceError(
116+
`lsp '${method}' request ${JSON.stringify(request)} failed: ${JSON.stringify(error)}`,
117+
)
118+
return err(JSON.stringify(error))
119+
}
120+
}
101121
}

0 commit comments

Comments
 (0)