Skip to content

Commit f800529

Browse files
authored
feat(vscode): add completion provider (#4300)
1 parent d0a1d4d commit f800529

8 files changed

Lines changed: 237 additions & 0 deletions

File tree

sqlmesh/lsp/completions.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from functools import lru_cache
2+
from sqlglot import Dialect, Tokenizer
3+
from sqlmesh.lsp.custom import AllModelsResponse
4+
import typing as t
5+
from sqlmesh.lsp.context import LSPContext
6+
7+
8+
def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllModelsResponse:
9+
"""
10+
Return a list of completions for a given file.
11+
"""
12+
return AllModelsResponse(
13+
models=list(get_models(context, file_uri)),
14+
keywords=list(get_keywords(context, file_uri)),
15+
)
16+
17+
18+
def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
19+
"""
20+
Return a list of models for a given file.
21+
22+
If there is no context, return an empty list.
23+
If there is a context, return a list of all models bar the ones the file itself defines.
24+
"""
25+
if context is None:
26+
return set()
27+
all_models = set(model for models in context.map.values() for model in models)
28+
if file_uri is not None:
29+
models_file_refers_to = context.map[file_uri]
30+
for model in models_file_refers_to:
31+
all_models.discard(model)
32+
return all_models
33+
34+
35+
def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.Set[str]:
36+
"""
37+
Return a list of sql keywords for a given file.
38+
If no context is provided, return ANSI SQL keywords.
39+
40+
If a context is provided but no file_uri is provided, returns the keywords
41+
for the default dialect of the context.
42+
43+
If both a context and a file_uri are provided, returns the keywords
44+
for the dialect of the model that the file belongs to.
45+
"""
46+
if file_uri is not None and context is not None:
47+
models = context.map[file_uri]
48+
if models:
49+
model = models[0]
50+
model_from_context = context.context.get_model(model)
51+
if model_from_context is not None:
52+
if model_from_context.dialect:
53+
return get_keywords_from_tokenizer(model_from_context.dialect)
54+
if context is not None:
55+
return get_keywords_from_tokenizer(context.context.default_dialect)
56+
return get_keywords_from_tokenizer(None)
57+
58+
59+
@lru_cache()
60+
def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]:
61+
"""
62+
Return a list of sql keywords for a given dialect. This is separate from
63+
the direct use of Tokenizer.KEYWORDS.keys() because that returns a set of
64+
keywords that are expanded, e.g. "ORDER BY" -> ["ORDER", "BY"].
65+
"""
66+
tokenizer = Tokenizer
67+
if dialect is not None:
68+
try:
69+
tokenizer = Dialect.get_or_raise(dialect).tokenizer_class
70+
except Exception:
71+
pass
72+
73+
expanded_keywords = set()
74+
for keyword in tokenizer.KEYWORDS.keys():
75+
parts = keyword.split(" ")
76+
expanded_keywords.update(parts)
77+
return expanded_keywords

sqlmesh/lsp/custom.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from lsprotocol import types
2+
import typing as t
3+
from sqlmesh.utils.pydantic import PydanticModel
4+
5+
ALL_MODELS_FEATURE = "sqlmesh/all_models"
6+
7+
8+
class AllModelsRequest(PydanticModel):
9+
"""
10+
Request to get all the models that are in the current project.
11+
"""
12+
13+
textDocument: types.TextDocumentIdentifier
14+
15+
16+
class AllModelsResponse(PydanticModel):
17+
"""
18+
Response to get all the models that are in the current project.
19+
"""
20+
21+
models: t.List[str]
22+
keywords: t.List[str]

sqlmesh/lsp/main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from sqlmesh._version import __version__
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
14+
from sqlmesh.lsp.completions import get_sql_completions
1415
from sqlmesh.lsp.context import LSPContext
16+
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1517
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
1618

1719

@@ -38,6 +40,14 @@ def __init__(
3840
def _register_features(self) -> None:
3941
"""Register LSP features on the internal LanguageServer instance."""
4042

43+
@self.server.feature(ALL_MODELS_FEATURE)
44+
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
45+
try:
46+
context = self._context_get_or_load(params.textDocument.uri)
47+
return get_sql_completions(context, params.textDocument.uri)
48+
except Exception as e:
49+
return get_sql_completions(None, params.textDocument.uri)
50+
4151
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
4252
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
4353
context = self._context_get_or_load(params.text_document.uri)

tests/lsp/test_completions.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
from sqlglot import Tokenizer
3+
from sqlmesh.core.context import Context
4+
from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions
5+
from sqlmesh.lsp.context import LSPContext
6+
7+
8+
TOKENIZER_KEYWORDS = set(Tokenizer.KEYWORDS.keys())
9+
10+
11+
@pytest.mark.fast
12+
def test_get_keywords_from_tokenizer():
13+
assert len(get_keywords_from_tokenizer()) > len(TOKENIZER_KEYWORDS)
14+
15+
16+
@pytest.mark.fast
17+
def test_get_sql_completions_no_context():
18+
completions = get_sql_completions(None, None)
19+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
20+
assert len(completions.models) == 0
21+
22+
23+
@pytest.mark.fast
24+
def test_get_sql_completions_with_context_no_file_uri():
25+
context = Context(paths=["examples/sushi"])
26+
lsp_context = LSPContext(context)
27+
28+
completions = get_sql_completions(lsp_context, None)
29+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
30+
assert "sushi.active_customers" in completions.models
31+
assert "sushi.customers" in completions.models
32+
33+
34+
@pytest.mark.fast
35+
def test_get_sql_completions_with_context_and_file_uri():
36+
context = Context(paths=["examples/sushi"])
37+
lsp_context = LSPContext(context)
38+
39+
file_uri = next(
40+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
41+
)
42+
completions = get_sql_completions(lsp_context, file_uri)
43+
assert len(completions.keywords) > len(TOKENIZER_KEYWORDS)
44+
assert "sushi.active_customers" not in completions.models
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import * as vscode from 'vscode'
2+
import { LSPClient } from '../lsp/lsp'
3+
import { isErr } from '../utilities/functional/result'
4+
5+
export const selector: vscode.DocumentSelector = {
6+
pattern: '**/*.sql',
7+
}
8+
9+
export const completionProvider = (
10+
lsp: LSPClient,
11+
): vscode.CompletionItemProvider => {
12+
return {
13+
async provideCompletionItems(document) {
14+
const result = await lsp.call_custom_method('sqlmesh/all_models', {
15+
textDocument: {
16+
uri: document.uri.fsPath,
17+
},
18+
})
19+
if (isErr(result)) {
20+
return []
21+
}
22+
const modelCompletions = result.value.models.map(
23+
model =>
24+
new vscode.CompletionItem(model, vscode.CompletionItemKind.Reference),
25+
)
26+
const keywordCompletions = result.value.keywords.map(
27+
keyword =>
28+
new vscode.CompletionItem(keyword, vscode.CompletionItemKind.Keyword),
29+
)
30+
return new vscode.CompletionList([
31+
...modelCompletions,
32+
...keywordCompletions,
33+
])
34+
},
35+
}
36+
}

vscode/extension/src/extension.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import {
1818
handleSqlmeshLspNotFoundError,
1919
handleSqlmeshLspDependenciesMissingError,
2020
} from './utilities/errors'
21+
import { completionProvider } from './completion/completion'
22+
import { selector } from './completion/completion'
2123

2224
let lspClient: LSPClient | undefined
2325

@@ -82,6 +84,13 @@ export async function activate(context: vscode.ExtensionContext) {
8284
context.subscriptions.push(lspClient)
8385
}
8486

87+
context.subscriptions.push(
88+
vscode.languages.registerCompletionItemProvider(
89+
selector,
90+
completionProvider(lspClient),
91+
),
92+
)
93+
8594
const restart = async () => {
8695
if (lspClient) {
8796
traceVerbose('Restarting LSP client')

vscode/extension/src/lsp/custom.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
export interface AllModelsMethod {
2+
method: 'sqlmesh/all_models'
3+
request: AllModelsRequest
4+
response: AllModelsResponse
5+
}
6+
7+
// @eslint-disable-next-line @typescript-eslint/consistent-type-definition
8+
export type CustomLSPMethods = AllModelsMethod
9+
10+
interface AllModelsRequest {
11+
textDocument: {
12+
uri: string
13+
}
14+
}
15+
16+
interface AllModelsResponse {
17+
models: string[]
18+
keywords: string[]
19+
}

vscode/extension/src/lsp/lsp.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { err, isErr, ok, Result } from '../utilities/functional/result'
1010
import { getWorkspaceFolders } from '../utilities/common/vscodeapi'
1111
import { traceError } from '../utilities/common/log'
1212
import { ErrorType } from '../utilities/errors'
13+
import { CustomLSPMethods } from './custom'
1314

1415
let outputChannel: OutputChannel | undefined
1516

@@ -98,4 +99,23 @@ export class LSPClient implements Disposable {
9899
public async dispose() {
99100
await this.stop()
100101
}
102+
103+
public async call_custom_method<
104+
Method extends CustomLSPMethods['method'],
105+
Request extends Extract<CustomLSPMethods, { method: Method }>['request'],
106+
Response extends Extract<CustomLSPMethods, { method: Method }>['response'],
107+
>(method: Method, request: Request): Promise<Result<Response, string>> {
108+
if (!this.client) {
109+
return err('lsp client not ready')
110+
}
111+
try {
112+
const result = await this.client.sendRequest<Response>(method, request)
113+
return ok(result)
114+
} catch (error) {
115+
traceError(
116+
`lsp '${method}' request ${JSON.stringify(request)} failed: ${JSON.stringify(error)}`,
117+
)
118+
return err(JSON.stringify(error))
119+
}
120+
}
101121
}

0 commit comments

Comments
 (0)