Skip to content

Commit a6a464e

Browse files
tobymao and themisvaltinos
authored and committed
feat: more cores for loading
1 parent 4b3e2bb commit a6a464e

3 files changed

Lines changed: 192 additions & 129 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 168 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,14 @@
44
import glob
55
import itertools
66
import linecache
7-
import logging
7+
import multiprocessing as mp
88
import os
99
import re
1010
import typing as t
1111
from collections import Counter, defaultdict
1212
from dataclasses import dataclass
1313
from pathlib import Path
14+
from concurrent.futures import ProcessPoolExecutor, as_completed
1415

1516
from sqlglot.errors import SqlglotError
1617
from sqlglot import exp
@@ -27,7 +28,6 @@
2728
from sqlmesh.core.model import (
2829
Model,
2930
ModelCache,
30-
SeedModel,
3131
create_external_model,
3232
load_sql_based_models,
3333
)
@@ -43,11 +43,10 @@
4343

4444

4545
if t.TYPE_CHECKING:
46+
from sqlmesh.core.config import Config
4647
from sqlmesh.core.context import GenericContext
4748

4849

49-
logger = logging.getLogger(__name__)
50-
5150
GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)")
5251

5352

@@ -67,21 +66,80 @@ class LoadedProject:
6766

6867
class CacheBase(abc.ABC):
6968
@abc.abstractmethod
70-
def get_or_load_models(
71-
self, target_path: Path, loader: t.Callable[[], t.List[Model]]
72-
) -> t.List[Model]:
73-
"""Get or load all models from cache."""
69+
def put(self, models: t.List[Model], path: Path) -> bool:
70+
pass
71+
72+
@abc.abstractmethod
73+
def get(self, path: Path) -> t.List[Model]:
74+
pass
75+
76+
77+
_defaults: t.Optional[t.Dict[str, t.Any]] = None
78+
_cache: t.Optional[CacheBase] = None
79+
_config: t.Optional[Config] = None
80+
_selected_gateway: t.Optional[str] = None
81+
82+
83+
def _init_model_defaults(
84+
config: Config,
85+
selected_gateway: t.Optional[str],
86+
defaults: t.Optional[t.Dict[str, t.Any]] = None,
87+
cache: t.Optional[CacheBase] = None,
88+
) -> None:
89+
global _defaults, _cache, _config, _selected_gateway
90+
_defaults = defaults
91+
_cache = cache
92+
_config = config
93+
_selected_gateway = selected_gateway
94+
95+
96+
def load_sql_models(path: Path) -> t.Tuple[Path, list[Model]]:
97+
assert _defaults
98+
assert _cache
99+
100+
with open(path, "r", encoding="utf-8") as file:
101+
expressions = parse(file.read(), default_dialect=_defaults["dialect"])
102+
models = load_sql_based_models(expressions, path=Path(path).absolute(), **_defaults)
103+
104+
return (path, [] if _cache.put(models, path) else models)
105+
106+
107+
def get_variables(gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
108+
assert _config
109+
110+
gateway_name = gateway_name or _selected_gateway
111+
112+
try:
113+
gateway = _config.get_gateway(gateway_name)
114+
except ConfigError:
115+
from sqlmesh.core.console import get_console
116+
117+
get_console().log_warning(
118+
f"Gateway '{gateway_name}' not found in project '{_config.project}'."
119+
)
120+
gateway = None
121+
122+
return {
123+
**_config.variables,
124+
**(gateway.variables if gateway else {}),
125+
c.GATEWAY: gateway_name,
126+
}
74127

75128

76129
class Loader(abc.ABC):
77130
"""Abstract base class to load macros and models for a context"""
78131

79132
def __init__(self, context: GenericContext, path: Path) -> None:
133+
from sqlmesh.core.console import get_console
134+
80135
self._path_mtimes: t.Dict[Path, float] = {}
81136
self.context = context
82137
self.config_path = path
83138
self.config = self.context.configs[self.config_path]
84139
self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {}
140+
self._console = get_console()
141+
142+
_init_model_defaults(self.config, self.context.selected_gateway)
85143

86144
def load(self) -> LoadedProject:
87145
"""
@@ -223,30 +281,32 @@ def _load_external_models(
223281
if external_models_path.exists() and external_models_path.is_dir():
224282
paths_to_load.extend(self._glob_paths(external_models_path, extension=".yaml"))
225283

226-
def _load() -> t.List[Model]:
227-
try:
228-
with open(path, "r", encoding="utf-8") as file:
229-
return [
230-
create_external_model(
231-
defaults=self.config.model_defaults.dict(),
232-
path=path,
233-
project=self.config.project,
234-
audit_definitions=audits,
235-
**{
236-
"dialect": self.config.model_defaults.dialect,
237-
"default_catalog": self.context.default_catalog,
238-
**row,
239-
},
240-
)
241-
for row in YAML().load(file.read())
242-
]
243-
except Exception as ex:
244-
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
245-
246284
for path in paths_to_load:
247285
self._track_file(path)
286+
external_models = cache.get(path)
287+
288+
if not external_models:
289+
try:
290+
with open(path, "r", encoding="utf-8") as file:
291+
external_models = [
292+
create_external_model(
293+
defaults=self.config.model_defaults.dict(),
294+
path=path,
295+
project=self.config.project,
296+
audit_definitions=audits,
297+
**{
298+
"dialect": self.config.model_defaults.dialect,
299+
"default_catalog": self.context.default_catalog,
300+
**row,
301+
},
302+
)
303+
for row in YAML().load(file.read())
304+
]
305+
306+
cache.put(external_models, path)
307+
except Exception as ex:
308+
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
248309

249-
external_models = cache.get_or_load_models(path, _load)
250310
# external models with no explicit gateway defined form the base set
251311
for model in external_models:
252312
if model.gateway is None:
@@ -341,28 +401,6 @@ def _track_file(self, path: Path) -> None:
341401
"""Project file to track for modifications"""
342402
self._path_mtimes[path] = path.stat().st_mtime
343403

344-
def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
345-
gateway_name = gateway_name or self.context.selected_gateway
346-
347-
if gateway_name not in self._variables_by_gateway:
348-
try:
349-
gateway = self.config.get_gateway(gateway_name)
350-
except ConfigError:
351-
from sqlmesh.core.console import get_console
352-
353-
get_console().log_warning(
354-
f"Gateway '{gateway_name}' not found in project '{self.config.project}'."
355-
)
356-
gateway = None
357-
358-
self._variables_by_gateway[gateway_name] = {
359-
**self.config.variables,
360-
**(gateway.variables if gateway else {}),
361-
c.GATEWAY: gateway_name,
362-
}
363-
364-
return self._variables_by_gateway[gateway_name]
365-
366404

367405
class SqlMeshLoader(Loader):
368406
"""Loads macros and models for a context using the SQLMesh file formats"""
@@ -425,8 +463,14 @@ def _load_models(
425463
audits into a Dict and creates the dag
426464
"""
427465
cache = SqlMeshLoader._Cache(self, self.config_path)
428-
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache)
466+
import time
467+
468+
now = time.time()
469+
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache, gateway)
470+
print("sql models", time.time() - now)
471+
now = time.time()
429472
external_models = self._load_external_models(audits, cache, gateway)
473+
print("external models", time.time() - now)
430474
python_models = self._load_python_models(macros, jinja_macros, audits, signals)
431475

432476
all_model_names = list(sql_models) + list(external_models) + list(python_models)
@@ -443,10 +487,13 @@ def _load_sql_models(
443487
audits: UniqueKeyDict[str, ModelAudit],
444488
signals: UniqueKeyDict[str, signal],
445489
cache: CacheBase,
490+
gateway: t.Optional[str],
446491
) -> UniqueKeyDict[str, Model]:
447492
"""Loads the sql models into a Dict"""
448493
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
449494

495+
paths = set()
496+
450497
for path in self._glob_paths(
451498
self.config_path / c.MODELS,
452499
ignore_patterns=self.config.ignore_patterns,
@@ -456,43 +503,63 @@ def _load_sql_models(
456503
continue
457504

458505
self._track_file(path)
506+
paths.add(path)
459507

460-
def _load() -> t.List[Model]:
461-
try:
462-
with open(path, "r", encoding="utf-8") as file:
463-
expressions = parse(
464-
file.read(), default_dialect=self.config.model_defaults.dialect
465-
)
508+
for path in paths.copy():
509+
cached_models = cache.get(path)
466510

467-
return load_sql_based_models(
468-
expressions,
469-
self._get_variables,
470-
defaults=self.config.model_defaults.dict(),
471-
macros=macros,
472-
jinja_macros=jinja_macros,
473-
audit_definitions=audits,
474-
default_audits=self.config.model_defaults.audits,
475-
path=Path(path).absolute(),
476-
module_path=self.config_path,
477-
dialect=self.config.model_defaults.dialect,
478-
time_column_format=self.config.time_column_format,
479-
physical_schema_mapping=self.config.physical_schema_mapping,
480-
project=self.config.project,
481-
default_catalog=self.context.default_catalog,
482-
infer_names=self.config.model_naming.infer_names,
483-
signal_definitions=signals,
484-
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
485-
)
486-
except Exception as ex:
487-
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
511+
if cached_models:
512+
paths.remove(path)
488513

489-
for model in cache.get_or_load_models(path, _load):
490-
if model.enabled:
514+
for model in cached_models:
491515
models[model.fqn] = model
492516

493-
if isinstance(model, SeedModel):
494-
seed_path = model.seed_path
495-
self._track_file(seed_path)
517+
error = False
518+
519+
if paths:
520+
defaults = dict(
521+
get_variables=get_variables,
522+
defaults=self.config.model_defaults.dict(),
523+
macros=macros,
524+
jinja_macros=jinja_macros,
525+
audit_definitions=audits,
526+
default_audits=self.config.model_defaults.audits,
527+
module_path=self.config_path,
528+
dialect=self.config.model_defaults.dialect,
529+
time_column_format=self.config.time_column_format,
530+
physical_schema_mapping=self.config.physical_schema_mapping,
531+
project=self.config.project,
532+
default_catalog=self.context.default_catalog,
533+
infer_names=self.config.model_naming.infer_names,
534+
signal_definitions=signals,
535+
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
536+
)
537+
538+
with ProcessPoolExecutor(
539+
mp_context=mp.get_context("fork"),
540+
initializer=_init_model_defaults,
541+
initargs=(self.config, gateway, defaults, cache),
542+
max_workers=c.MAX_FORK_WORKERS,
543+
) as pool:
544+
for fut in as_completed(pool.submit(load_sql_models, path) for path in paths):
545+
try:
546+
path, loaded = fut.result()
547+
548+
if loaded:
549+
for model in loaded:
550+
model._path = path
551+
models[model.fqn] = model
552+
else:
553+
for model in cache.get(path):
554+
models[model.fqn] = model
555+
except Exception as ex:
556+
self._console.log_error(
557+
f"Failed to load model definition at '{path}'.\n{ex}"
558+
)
559+
error = True
560+
561+
if error:
562+
raise ConfigError("Failed to load models")
496563

497564
return models
498565

@@ -526,7 +593,7 @@ def _load_python_models(
526593
registered |= new
527594
for name in new:
528595
for model in registry[name].models(
529-
self._get_variables,
596+
get_variables,
530597
path=path,
531598
module_path=self.config_path,
532599
defaults=self.config.model_defaults.dict(),
@@ -594,7 +661,7 @@ def _load_audits(
594661
"""Loads all the model audits."""
595662
audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits")
596663
audits_max_mtime: t.Optional[float] = None
597-
variables = self._get_variables()
664+
variables = get_variables()
598665

599666
for path in self._glob_paths(
600667
self.config_path / c.AUDITS,
@@ -670,7 +737,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
670737
module_path=self.config_path,
671738
jinja_macro_references=None,
672739
macros=macros,
673-
variables=self._get_variables(),
740+
variables=get_variables(),
674741
path=self.config_path,
675742
)
676743

@@ -705,7 +772,7 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
705772
gateway_line = GATEWAY_PATTERN.search(source)
706773
gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None
707774

708-
contents = yaml_load(source, variables=self._get_variables(gateway))
775+
contents = yaml_load(source, variables=get_variables(gateway))
709776

710777
for test_name, value in contents.items():
711778
model_test_metadata[test_name] = ModelTestMetadata(
@@ -755,16 +822,21 @@ def __init__(self, loader: SqlMeshLoader, config_path: Path):
755822
self.config_path = config_path
756823
self._model_cache = ModelCache(self.config_path / c.CACHE)
757824

758-
def get_or_load_models(
759-
self, target_path: Path, loader: t.Callable[[], t.List[Model]]
760-
) -> t.List[Model]:
761-
models = self._model_cache.get_or_load(
762-
self._cache_entry_name(target_path),
763-
self._model_cache_entry_id(target_path),
764-
loader=loader,
825+
def put(self, models: t.List[Model], path: Path) -> bool:
826+
return self._model_cache.put(
827+
models,
828+
self._cache_entry_name(path),
829+
self._model_cache_entry_id(path),
765830
)
831+
832+
def get(self, path: Path) -> t.List[Model]:
833+
models = self._model_cache.get(
834+
self._cache_entry_name(path),
835+
self._model_cache_entry_id(path),
836+
)
837+
766838
for model in models:
767-
model._path = target_path
839+
model._path = path
768840

769841
return models
770842

0 commit comments

Comments
 (0)