Skip to content

Commit 2673372

Browse files
committed
temp [ci skip]
1 parent 312e87a commit 2673372

3 files changed

Lines changed: 57 additions & 41 deletions

File tree

examples/sushi/models/customers.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ WITH current_marketing AS (
3232
)
3333
SELECT DISTINCT
3434
o.customer_id::INT AS customer_id, /* this comment should not be registered */
35-
m.status,
35+
m.status ,
3636
d.zip
3737
FROM sushi.orders AS o
3838
LEFT JOIN current_marketing AS m

sqlmesh/core/context.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,11 +1078,19 @@ def format(
10781078
) -> bool:
10791079
"""Format all SQL models and audits."""
10801080
filtered_targets = [
1081-
target for target in self._models.values() | self._audits.values()
1082-
if target._path is not None and target._path.suffix == ".sql"
1081+
target
1082+
for target in chain(self._models.values(), self._audits.values())
1083+
if target._path is not None
1084+
and target._path.suffix == ".sql"
10831085
and (not paths or any(target._path.samefile(p) for p in paths))
10841086
]
1087+
unformatted_file_paths = []
1088+
10851089
for target in filtered_targets:
1090+
if (
1091+
target._path is None
1092+
): # introduced to satisfy type checker as still want to pull filter out as many targets as possible before loop
1093+
continue
10861094
with open(target._path, "r+", encoding="utf-8") as file:
10871095
before = file.read()
10881096
expressions = parse(before, default_dialect=self.config_for_node(target).dialect)
@@ -1129,11 +1137,6 @@ def format(
11291137
return False
11301138

11311139
return True
1132-
1133-
1134-
1135-
1136-
11371140

11381141
@python_api_analytics
11391142
def plan(

sqlmesh/lsp/main.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import logging
55
import typing as t
6-
import weakref
76
from contextlib import suppress
87
from pathlib import Path
98

@@ -12,70 +11,84 @@
1211
from pygls.server import LanguageServer
1312
from pygls.workspace import TextDocument
1413
from sqlmesh._version import __version__
14+
from sqlmesh.core.dialect import Audit
1515
from sqlmesh.core.model import Model
1616

1717
logger = logging.getLogger(__name__)
1818

19-
CONTEXTS: t.Dict[str, Context] = {}
20-
"""A mapping of workspace paths to SQLMesh contexts."""
19+
GLOBAL_CONTEXT: t.Optional[
20+
t.Tuple[
21+
Context,
22+
t.Dict[str, t.Tuple[Context, t.Union[Model, Audit]]],
23+
]
24+
] = None
2125

22-
PATHS_TO_MODELS: t.Dict[str, t.Tuple[Context, Model]] = {}
23-
"""A mapping of file paths to SQLMesh (context, model) tuples."""
2426

2527
server = LanguageServer("sqlmesh_lsp", __version__)
2628

27-
_CACHE: t.Set[str] = set()
28-
"""A cache of URIs for which we have already ensured a context exists."""
29-
3029

3130
def ensure_context_for_document(document: TextDocument) -> TextDocument:
32-
"""Ensure that a context exists for the given document if applicable."""
33-
if document.uri in _CACHE:
34-
return document
35-
_CACHE.add(document.uri)
31+
"""Ensure that a context exists for the given document if applicable by searching for a config.py or config.yml file in the parent directories."""
32+
# If the context is already loaded, return the document, if it is part of the same context
33+
global GLOBAL_CONTEXT
34+
if GLOBAL_CONTEXT is not None:
35+
GLOBAL_CONTEXT[0].load()
36+
if document.uri in GLOBAL_CONTEXT[1]:
37+
return document
38+
else:
39+
for model in GLOBAL_CONTEXT[0]._models.values():
40+
if model._path is None:
41+
continue
42+
path = model._path.resolve()
43+
if path == document.path:
44+
GLOBAL_CONTEXT = (
45+
GLOBAL_CONTEXT[0],
46+
GLOBAL_CONTEXT[1] | {document.uri: (GLOBAL_CONTEXT[0], model)},
47+
)
48+
return document
49+
for audit in GLOBAL_CONTEXT[0]._audits.values():
50+
if audit._path is None:
51+
continue
52+
path = audit._path.resolve()
53+
if path == document.path:
54+
GLOBAL_CONTEXT = (
55+
GLOBAL_CONTEXT[0],
56+
GLOBAL_CONTEXT[1] | {document.uri: (GLOBAL_CONTEXT[0], audit)},
57+
)
58+
return document
59+
return document
60+
61+
# If there is no context, load the context and then call this function again
3662
path = Path(document.path).resolve()
3763
if path.suffix not in (".sql", ".py"):
3864
return document
39-
initial_path = path
40-
while path.parents:
41-
if str(path) in CONTEXTS:
42-
return document
43-
path = path.parent
44-
path = initial_path
4565
loaded = False
4666
while path.parents and not loaded:
4767
for ext in ("py", "yml", "yaml"):
4868
config_path = path / f"config.{ext}"
4969
if config_path.exists():
5070
with suppress(Exception):
51-
handle = Context(paths=[f"{path}"])
52-
CONTEXTS[str(path)] = handle
53-
PATHS_TO_MODELS.update(
54-
{
55-
str(model._path.resolve()): (handle, weakref.proxy(model))
56-
for model in handle.models.values()
57-
}
58-
)
71+
handle = Context(paths=[path])
72+
GLOBAL_CONTEXT = (handle, {})
5973
server.show_message(f"Context loaded for: {path}")
6074
loaded = True
61-
break
75+
return ensure_context_for_document(document)
6276
path = path.parent
77+
6378
return document
6479

6580

6681
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
6782
def formatting(
6883
ls: LanguageServer, params: types.DocumentFormattingParams
6984
) -> t.List[types.TextEdit]:
70-
"""Format the document based using SQLMesh format_model_expressions."""
85+
"""Format the document using SQLMesh format_model_expressions."""
7186
try:
72-
logger.info(f"Formatting document: {params.text_document.uri}")
7387
document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
74-
context, _ = PATHS_TO_MODELS.get(document.path, (None, None))
88+
context = GLOBAL_CONTEXT
7589
if context is None:
76-
logger.error(f"No context found for document: {document.path}")
77-
return []
78-
context.format(paths=(Path(document.path),))
90+
raise Exception(f"No context found for document: {document.path}")
91+
context[0].format(paths=(Path(document.path),))
7992
with open(document.path, "r+", encoding="utf-8") as file:
8093
return [
8194
types.TextEdit(

0 commit comments

Comments
 (0)