44import glob
55import itertools
66import linecache
7- import logging
7+ import multiprocessing as mp
88import os
99import re
1010import typing as t
1111from collections import Counter , defaultdict
1212from dataclasses import dataclass
1313from pathlib import Path
14+ from concurrent .futures import ProcessPoolExecutor , as_completed
1415
1516from sqlglot .errors import SqlglotError
1617from sqlglot import exp
2728from sqlmesh .core .model import (
2829 Model ,
2930 ModelCache ,
30- SeedModel ,
3131 create_external_model ,
3232 load_sql_based_models ,
3333)
4343
4444
4545if t .TYPE_CHECKING :
46+ from sqlmesh .core .config import Config
4647 from sqlmesh .core .context import GenericContext
4748
4849
49- logger = logging .getLogger (__name__ )
50-
5150GATEWAY_PATTERN = re .compile (r"gateway:\s*([^\s]+)" )
5251
5352
@@ -67,21 +66,80 @@ class LoadedProject:
6766
6867class CacheBase (abc .ABC ):
6968 @abc .abstractmethod
70- def get_or_load_models (
71- self , target_path : Path , loader : t .Callable [[], t .List [Model ]]
72- ) -> t .List [Model ]:
73- """Get or load all models from cache."""
69+ def put (self , models : t .List [Model ], path : Path ) -> bool :
70+ pass
71+
72+ @abc .abstractmethod
73+ def get (self , path : Path ) -> t .List [Model ]:
74+ pass
75+
76+
77+ _defaults : t .Optional [t .Dict [str , t .Any ]] = None
78+ _cache : t .Optional [CacheBase ] = None
79+ _config : t .Optional [Config ] = None
80+ _selected_gateway : t .Optional [str ] = None
81+
82+
83+ def _init_model_defaults (
84+ config : Config ,
85+ selected_gateway : t .Optional [str ],
86+ defaults : t .Optional [t .Dict [str , t .Any ]] = None ,
87+ cache : t .Optional [CacheBase ] = None ,
88+ ) -> None :
89+ global _defaults , _cache , _config , _selected_gateway
90+ _defaults = defaults
91+ _cache = cache
92+ _config = config
93+ _selected_gateway = selected_gateway
94+
95+
96+ def load_sql_models (path : Path ) -> t .Tuple [Path , list [Model ]]:
97+ assert _defaults
98+ assert _cache
99+
100+ with open (path , "r" , encoding = "utf-8" ) as file :
101+ expressions = parse (file .read (), default_dialect = _defaults ["dialect" ])
102+ models = load_sql_based_models (expressions , path = Path (path ).absolute (), ** _defaults )
103+
104+ return (path , [] if _cache .put (models , path ) else models )
105+
106+
107+ def get_variables (gateway_name : t .Optional [str ] = None ) -> t .Dict [str , t .Any ]:
108+ assert _config
109+
110+ gateway_name = gateway_name or _selected_gateway
111+
112+ try :
113+ gateway = _config .get_gateway (gateway_name )
114+ except ConfigError :
115+ from sqlmesh .core .console import get_console
116+
117+ get_console ().log_warning (
118+ f"Gateway '{ gateway_name } ' not found in project '{ _config .project } '."
119+ )
120+ gateway = None
121+
122+ return {
123+ ** _config .variables ,
124+ ** (gateway .variables if gateway else {}),
125+ c .GATEWAY : gateway_name ,
126+ }
74127
75128
76129class Loader (abc .ABC ):
77130 """Abstract base class to load macros and models for a context"""
78131
79132 def __init__ (self , context : GenericContext , path : Path ) -> None :
133+ from sqlmesh .core .console import get_console
134+
80135 self ._path_mtimes : t .Dict [Path , float ] = {}
81136 self .context = context
82137 self .config_path = path
83138 self .config = self .context .configs [self .config_path ]
84139 self ._variables_by_gateway : t .Dict [str , t .Dict [str , t .Any ]] = {}
140+ self ._console = get_console ()
141+
142+ _init_model_defaults (self .config , self .context .selected_gateway )
85143
86144 def load (self ) -> LoadedProject :
87145 """
@@ -223,30 +281,32 @@ def _load_external_models(
223281 if external_models_path .exists () and external_models_path .is_dir ():
224282 paths_to_load .extend (self ._glob_paths (external_models_path , extension = ".yaml" ))
225283
226- def _load () -> t .List [Model ]:
227- try :
228- with open (path , "r" , encoding = "utf-8" ) as file :
229- return [
230- create_external_model (
231- defaults = self .config .model_defaults .dict (),
232- path = path ,
233- project = self .config .project ,
234- audit_definitions = audits ,
235- ** {
236- "dialect" : self .config .model_defaults .dialect ,
237- "default_catalog" : self .context .default_catalog ,
238- ** row ,
239- },
240- )
241- for row in YAML ().load (file .read ())
242- ]
243- except Exception as ex :
244- raise ConfigError (f"Failed to load model definition at '{ path } '.\n { ex } " )
245-
246284 for path in paths_to_load :
247285 self ._track_file (path )
286+ external_models = cache .get (path )
287+
288+ if not external_models :
289+ try :
290+ with open (path , "r" , encoding = "utf-8" ) as file :
291+ external_models = [
292+ create_external_model (
293+ defaults = self .config .model_defaults .dict (),
294+ path = path ,
295+ project = self .config .project ,
296+ audit_definitions = audits ,
297+ ** {
298+ "dialect" : self .config .model_defaults .dialect ,
299+ "default_catalog" : self .context .default_catalog ,
300+ ** row ,
301+ },
302+ )
303+ for row in YAML ().load (file .read ())
304+ ]
305+
306+ cache .put (external_models , path )
307+ except Exception as ex :
308+ raise ConfigError (f"Failed to load model definition at '{ path } '.\n { ex } " )
248309
249- external_models = cache .get_or_load_models (path , _load )
250310 # external models with no explicit gateway defined form the base set
251311 for model in external_models :
252312 if model .gateway is None :
@@ -341,28 +401,6 @@ def _track_file(self, path: Path) -> None:
341401 """Project file to track for modifications"""
342402 self ._path_mtimes [path ] = path .stat ().st_mtime
343403
344- def _get_variables (self , gateway_name : t .Optional [str ] = None ) -> t .Dict [str , t .Any ]:
345- gateway_name = gateway_name or self .context .selected_gateway
346-
347- if gateway_name not in self ._variables_by_gateway :
348- try :
349- gateway = self .config .get_gateway (gateway_name )
350- except ConfigError :
351- from sqlmesh .core .console import get_console
352-
353- get_console ().log_warning (
354- f"Gateway '{ gateway_name } ' not found in project '{ self .config .project } '."
355- )
356- gateway = None
357-
358- self ._variables_by_gateway [gateway_name ] = {
359- ** self .config .variables ,
360- ** (gateway .variables if gateway else {}),
361- c .GATEWAY : gateway_name ,
362- }
363-
364- return self ._variables_by_gateway [gateway_name ]
365-
366404
367405class SqlMeshLoader (Loader ):
368406 """Loads macros and models for a context using the SQLMesh file formats"""
@@ -425,8 +463,14 @@ def _load_models(
425463 audits into a Dict and creates the dag
426464 """
427465 cache = SqlMeshLoader ._Cache (self , self .config_path )
428- sql_models = self ._load_sql_models (macros , jinja_macros , audits , signals , cache )
466+ import time
467+
468+ now = time .time ()
469+ sql_models = self ._load_sql_models (macros , jinja_macros , audits , signals , cache , gateway )
470+ print ("sql models" , time .time () - now )
471+ now = time .time ()
429472 external_models = self ._load_external_models (audits , cache , gateway )
473+ print ("external models" , time .time () - now )
430474 python_models = self ._load_python_models (macros , jinja_macros , audits , signals )
431475
432476 all_model_names = list (sql_models ) + list (external_models ) + list (python_models )
@@ -443,10 +487,13 @@ def _load_sql_models(
443487 audits : UniqueKeyDict [str , ModelAudit ],
444488 signals : UniqueKeyDict [str , signal ],
445489 cache : CacheBase ,
490+ gateway : t .Optional [str ],
446491 ) -> UniqueKeyDict [str , Model ]:
447492 """Loads the sql models into a Dict"""
448493 models : UniqueKeyDict [str , Model ] = UniqueKeyDict ("models" )
449494
495+ paths = set ()
496+
450497 for path in self ._glob_paths (
451498 self .config_path / c .MODELS ,
452499 ignore_patterns = self .config .ignore_patterns ,
@@ -456,43 +503,63 @@ def _load_sql_models(
456503 continue
457504
458505 self ._track_file (path )
506+ paths .add (path )
459507
460- def _load () -> t .List [Model ]:
461- try :
462- with open (path , "r" , encoding = "utf-8" ) as file :
463- expressions = parse (
464- file .read (), default_dialect = self .config .model_defaults .dialect
465- )
508+ for path in paths .copy ():
509+ cached_models = cache .get (path )
466510
467- return load_sql_based_models (
468- expressions ,
469- self ._get_variables ,
470- defaults = self .config .model_defaults .dict (),
471- macros = macros ,
472- jinja_macros = jinja_macros ,
473- audit_definitions = audits ,
474- default_audits = self .config .model_defaults .audits ,
475- path = Path (path ).absolute (),
476- module_path = self .config_path ,
477- dialect = self .config .model_defaults .dialect ,
478- time_column_format = self .config .time_column_format ,
479- physical_schema_mapping = self .config .physical_schema_mapping ,
480- project = self .config .project ,
481- default_catalog = self .context .default_catalog ,
482- infer_names = self .config .model_naming .infer_names ,
483- signal_definitions = signals ,
484- default_catalog_per_gateway = self .context .default_catalog_per_gateway ,
485- )
486- except Exception as ex :
487- raise ConfigError (f"Failed to load model definition at '{ path } '.\n { ex } " )
511+ if cached_models :
512+ paths .remove (path )
488513
489- for model in cache .get_or_load_models (path , _load ):
490- if model .enabled :
514+ for model in cached_models :
491515 models [model .fqn ] = model
492516
493- if isinstance (model , SeedModel ):
494- seed_path = model .seed_path
495- self ._track_file (seed_path )
517+ error = False
518+
519+ if paths :
520+ defaults = dict (
521+ get_variables = get_variables ,
522+ defaults = self .config .model_defaults .dict (),
523+ macros = macros ,
524+ jinja_macros = jinja_macros ,
525+ audit_definitions = audits ,
526+ default_audits = self .config .model_defaults .audits ,
527+ module_path = self .config_path ,
528+ dialect = self .config .model_defaults .dialect ,
529+ time_column_format = self .config .time_column_format ,
530+ physical_schema_mapping = self .config .physical_schema_mapping ,
531+ project = self .config .project ,
532+ default_catalog = self .context .default_catalog ,
533+ infer_names = self .config .model_naming .infer_names ,
534+ signal_definitions = signals ,
535+ default_catalog_per_gateway = self .context .default_catalog_per_gateway ,
536+ )
537+
538+ with ProcessPoolExecutor (
539+ mp_context = mp .get_context ("fork" ),
540+ initializer = _init_model_defaults ,
541+ initargs = (self .config , gateway , defaults , cache ),
542+ max_workers = c .MAX_FORK_WORKERS ,
543+ ) as pool :
544+ for fut in as_completed (pool .submit (load_sql_models , path ) for path in paths ):
545+ try :
546+ path , loaded = fut .result ()
547+
548+ if loaded :
549+ for model in loaded :
550+ model ._path = path
551+ models [model .fqn ] = model
552+ else :
553+ for model in cache .get (path ):
554+ models [model .fqn ] = model
555+ except Exception as ex :
556+ self ._console .log_error (
557+ f"Failed to load model definition at '{ path } '.\n { ex } "
558+ )
559+ error = True
560+
561+ if error :
562+ raise ConfigError ("Failed to load models" )
496563
497564 return models
498565
@@ -526,7 +593,7 @@ def _load_python_models(
526593 registered |= new
527594 for name in new :
528595 for model in registry [name ].models (
529- self . _get_variables ,
596+ get_variables ,
530597 path = path ,
531598 module_path = self .config_path ,
532599 defaults = self .config .model_defaults .dict (),
@@ -594,7 +661,7 @@ def _load_audits(
594661 """Loads all the model audits."""
595662 audits_by_name : UniqueKeyDict [str , Audit ] = UniqueKeyDict ("audits" )
596663 audits_max_mtime : t .Optional [float ] = None
597- variables = self . _get_variables ()
664+ variables = get_variables ()
598665
599666 for path in self ._glob_paths (
600667 self .config_path / c .AUDITS ,
@@ -670,7 +737,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
670737 module_path = self .config_path ,
671738 jinja_macro_references = None ,
672739 macros = macros ,
673- variables = self . _get_variables (),
740+ variables = get_variables (),
674741 path = self .config_path ,
675742 )
676743
@@ -705,7 +772,7 @@ def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
705772 gateway_line = GATEWAY_PATTERN .search (source )
706773 gateway = YAML ().load (gateway_line .group (0 ))["gateway" ] if gateway_line else None
707774
708- contents = yaml_load (source , variables = self . _get_variables (gateway ))
775+ contents = yaml_load (source , variables = get_variables (gateway ))
709776
710777 for test_name , value in contents .items ():
711778 model_test_metadata [test_name ] = ModelTestMetadata (
@@ -755,16 +822,21 @@ def __init__(self, loader: SqlMeshLoader, config_path: Path):
755822 self .config_path = config_path
756823 self ._model_cache = ModelCache (self .config_path / c .CACHE )
757824
758- def get_or_load_models (
759- self , target_path : Path , loader : t .Callable [[], t .List [Model ]]
760- ) -> t .List [Model ]:
761- models = self ._model_cache .get_or_load (
762- self ._cache_entry_name (target_path ),
763- self ._model_cache_entry_id (target_path ),
764- loader = loader ,
825+ def put (self , models : t .List [Model ], path : Path ) -> bool :
826+ return self ._model_cache .put (
827+ models ,
828+ self ._cache_entry_name (path ),
829+ self ._model_cache_entry_id (path ),
765830 )
831+
832+ def get (self , path : Path ) -> t .List [Model ]:
833+ models = self ._model_cache .get (
834+ self ._cache_entry_name (path ),
835+ self ._model_cache_entry_id (path ),
836+ )
837+
766838 for model in models :
767- model ._path = target_path
839+ model ._path = path
768840
769841 return models
770842
0 commit comments