Skip to content

Commit b9f5c3b

Browse files
address comments
1 parent 97e5f99 commit b9f5c3b

5 files changed

Lines changed: 44 additions & 41 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747

4848
if t.TYPE_CHECKING:
49-
from sqlmesh.core.config import Config
5049
from sqlmesh.core.context import GenericContext
5150

5251

@@ -104,21 +103,21 @@ def get(self, path: Path) -> t.List[Model]:
104103

105104
_defaults: t.Optional[t.Dict[str, t.Any]] = None
106105
_cache: t.Optional[CacheBase] = None
107-
_config: t.Optional[Config] = None
106+
_config_essentials: t.Optional[t.Dict[str, t.Any]] = None
108107
_selected_gateway: t.Optional[str] = None
109108

110109

111110
def _init_model_defaults(
112-
config: Config,
111+
config_essentials: t.Dict[str, t.Any],
113112
selected_gateway: t.Optional[str],
114113
model_loading_defaults: t.Optional[t.Dict[str, t.Any]] = None,
115114
cache: t.Optional[CacheBase] = None,
116115
console: t.Optional[Console] = None,
117116
) -> None:
118-
global _defaults, _cache, _config, _selected_gateway
117+
global _defaults, _cache, _config_essentials, _selected_gateway
119118
_defaults = model_loading_defaults
120119
_cache = cache
121-
_config = config
120+
_config_essentials = config_essentials
122121
_selected_gateway = selected_gateway
123122

124123
# Set the console passed from the parent process
@@ -140,22 +139,22 @@ def load_sql_models(path: Path) -> t.List[Model]:
140139

141140

142141
def get_variables(gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
143-
assert _config
142+
assert _config_essentials
144143

145144
gateway_name = gateway_name or _selected_gateway
146145

147146
try:
148-
gateway = _config.get_gateway(gateway_name)
147+
gateway = _config_essentials["gateways"].get(gateway_name)
149148
except ConfigError:
150149
from sqlmesh.core.console import get_console
151150

152151
get_console().log_warning(
153-
f"Gateway '{gateway_name}' not found in project '{_config.project}'."
152+
f"Gateway '{gateway_name}' not found in project '{_config_essentials['project']}'."
154153
)
155154
gateway = None
156155

157156
return {
158-
**_config.variables,
157+
**_config_essentials["variables"],
159158
**(gateway.variables if gateway else {}),
160159
c.GATEWAY: gateway_name,
161160
}
@@ -174,7 +173,12 @@ def __init__(self, context: GenericContext, path: Path) -> None:
174173
self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {}
175174
self._console = get_console()
176175

177-
_init_model_defaults(self.config, self.context.selected_gateway)
176+
self.config_essentials = {
177+
"project": self.config.project,
178+
"variables": self.config.variables,
179+
"gateways": self.config.gateways,
180+
}
181+
_init_model_defaults(self.config_essentials, self.context.selected_gateway)
178182

179183
def load(self) -> LoadedProject:
180184
"""
@@ -539,6 +543,7 @@ def _load_sql_models(
539543
"""Loads the sql models into a Dict"""
540544
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
541545
paths: t.Set[Path] = set()
546+
cached_paths: UniqueKeyDict[Path, t.List[Model]] = UniqueKeyDict("cached_paths")
542547

543548
for path in self._glob_paths(
544549
self.config_path / c.MODELS,
@@ -550,14 +555,14 @@ def _load_sql_models(
550555

551556
self._track_file(path)
552557
paths.add(path)
558+
if cached_models := cache.get(path):
559+
cached_paths[path] = cached_models
553560

554-
for path in paths.copy():
555-
cached_models = cache.get(path)
556-
if cached_models:
557-
paths.remove(path)
558-
for model in cached_models:
559-
if model.enabled:
560-
models[model.fqn] = model
561+
for path, cached_models in cached_paths.items():
562+
paths.remove(path)
563+
for model in cached_models:
564+
if model.enabled:
565+
models[model.fqn] = model
561566

562567
if paths:
563568
model_loading_defaults = dict(
@@ -578,10 +583,18 @@ def _load_sql_models(
578583
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
579584
)
580585

581-
errors: t.List[str] = []
586+
# if not c.MAX_FORK_WORKERS:
587+
# breakpoint()
588+
582589
with create_process_pool_executor(
583590
initializer=_init_model_defaults,
584-
initargs=(self.config, gateway, model_loading_defaults, cache, self._console),
591+
initargs=(
592+
self.config_essentials,
593+
gateway,
594+
model_loading_defaults,
595+
cache,
596+
self._console,
597+
),
585598
max_workers=c.MAX_FORK_WORKERS,
586599
) as pool:
587600
futures_to_paths = {pool.submit(load_sql_models, path): path for path in paths}
@@ -591,7 +604,7 @@ def _load_sql_models(
591604
loaded = future.result()
592605
for model in loaded or cache.get(path):
593606
if model.fqn in models:
594-
errors.append(
607+
raise ConfigError(
595608
self._failed_to_load_model_error(
596609
path, f"Duplicate SQL model name: '{model.name}'."
597610
)
@@ -600,11 +613,7 @@ def _load_sql_models(
600613
model._path = path
601614
models[model.fqn] = model
602615
except Exception as ex:
603-
errors.append(self._failed_to_load_model_error(path, str(ex)))
604-
605-
if errors:
606-
error_string = "\n".join(errors)
607-
raise ConfigError(error_string)
616+
raise ConfigError(self._failed_to_load_model_error(path, ex))
608617

609618
return models
610619

sqlmesh/core/model/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_or_load(
6565
return models
6666

6767
def put(self, models: t.List[Model], name: str, entry_id: str = "") -> bool:
68-
if isinstance(models, list) and isinstance(seq_get(models, 0), (SqlModel, ExternalModel)):
68+
if models and isinstance(seq_get(models, 0), (SqlModel, ExternalModel)):
6969
# make sure we preload full_depends_on
7070
for model in models:
7171
model.full_depends_on

sqlmesh/core/model/schema.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,10 @@ def process_models(completed_model: t.Optional[Model] = None) -> None:
8181
)
8282
)
8383

84-
errors: t.List[str] = []
8584
with optimized_query_cache_pool(optimized_query_cache) as executor:
8685
process_models()
8786

88-
while futures and not errors:
87+
while futures:
8988
for future in as_completed(futures):
9089
try:
9190
futures.remove(future)
@@ -99,8 +98,4 @@ def process_models(completed_model: t.Optional[Model] = None) -> None:
9998
_update_schema_with_model(schema, model)
10099
process_models(completed_model=model)
101100
except Exception as ex:
102-
errors.append(f"{ex}")
103-
104-
if errors:
105-
error_string = "\n".join(errors)
106-
raise SchemaError(f"Failed to update model schemas\n\n{error_string}")
101+
raise SchemaError(f"Failed to update model schemas\n\n{ex}")

sqlmesh/utils/process.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from concurrent.futures import Future, ProcessPoolExecutor
44
import typing as t
55
import multiprocessing as mp
6-
from sqlmesh.core import constants as c
76
from sqlmesh.utils.windows import IS_WINDOWS
87

98

@@ -26,7 +25,7 @@ def __enter__(self):
2625

2726
def __exit__(self, *args):
2827
self.shutdown(wait=True)
29-
return True
28+
return False
3029

3130
def shutdown(self, wait=True, cancel_futures=False):
3231
"""No-op method to match ProcessPoolExecutor API.
@@ -59,7 +58,7 @@ def map(self, fn, *iterables, timeout=None, chunksize=1):
5958

6059

6160
def create_process_pool_executor(
62-
initializer: t.Callable, initargs: t.Tuple, max_workers: t.Optional[int] = c.MAX_FORK_WORKERS
61+
initializer: t.Callable, initargs: t.Tuple, max_workers: t.Optional[int]
6362
) -> PoolExecutor:
6463
if max_workers == 1 or IS_WINDOWS:
6564
return SynchronousPoolExecutor(

tests/core/test_integration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4133,12 +4133,12 @@ def test_plan_repairs_unrenderable_snapshot_state(
41334133
f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'",
41344134
)
41354135

4136-
context.clear_caches()
4137-
4138-
target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[
4139-
target_snapshot.snapshot_id
4140-
]
41414136
with pytest.raises(Exception):
4137+
context_copy = context.copy()
4138+
context_copy.clear_caches()
4139+
target_snapshot_in_state = context_copy.state_sync.get_snapshots(
4140+
[target_snapshot.snapshot_id]
4141+
)[target_snapshot.snapshot_id]
41424142
target_snapshot_in_state.model.render_query_or_raise()
41434143

41444144
# Repair the snapshot by creating a new version of it

0 commit comments

Comments (0)