3838from sqlmesh .core .test import ModelTestMetadata , filter_tests_by_patterns
3939from sqlmesh .utils import UniqueKeyDict , sys_path
4040from sqlmesh .utils .errors import ConfigError
41- from sqlmesh .utils .jinja import JinjaMacroRegistry , MacroExtractor
41+ from sqlmesh .utils .jinja import (
42+ JinjaMacroRegistry ,
43+ MacroExtractor ,
44+ SQLMESH_JINJA_PACKAGE ,
45+ SQLMESH_DBT_COMPATIBILITY_PACKAGE ,
46+ )
4247from sqlmesh .utils .metaprogramming import import_python_file
4348from sqlmesh .utils .pydantic import validation_error_message
4449from sqlmesh .utils .yaml import YAML , load as yaml_load
@@ -384,15 +389,42 @@ def _raise_failed_to_load_model_error(self, path: Path, error: t.Union[str, Exce
384389class SqlMeshLoader (Loader ):
385390 """Loads macros and models for a context using the SQLMesh file formats"""
386391
392+ @property
393+ def is_migrated_dbt_project (self ) -> bool :
394+ return self .migrated_dbt_project_name is not None
395+
396+ @property
397+ def migrated_dbt_project_name (self ) -> t .Optional [str ]:
398+ return self .config .variables .get (c .MIGRATED_DBT_PROJECT_NAME )
399+
387400 def _load_scripts (self ) -> t .Tuple [MacroRegistry , JinjaMacroRegistry ]:
388401 """Loads all user defined macros."""
402+
403+ create_builtin_globals_module = (
404+ SQLMESH_DBT_COMPATIBILITY_PACKAGE
405+ if self .is_migrated_dbt_project
406+ else SQLMESH_JINJA_PACKAGE
407+ )
408+
389409 # Store a copy of the macro registry
390410 standard_macros = macro .get_registry ()
391- jinja_macros = JinjaMacroRegistry ()
411+
412+ top_level_packages = []
413+ if self .is_migrated_dbt_project :
414+ top_level_packages = ["dbt" ]
415+ if self .migrated_dbt_project_name :
416+ top_level_packages .append (self .migrated_dbt_project_name )
417+
418+ jinja_macros = JinjaMacroRegistry (
419+ create_builtins_module = create_builtin_globals_module ,
420+ top_level_packages = top_level_packages ,
421+ )
392422 extractor = MacroExtractor ()
393423
394424 macros_max_mtime : t .Optional [float ] = None
395425
426+ migrated_dbt_package_base_path = self .config_path / c .MACROS / c .MIGRATED_DBT_PACKAGES
427+
396428 for path in self ._glob_paths (
397429 self .config_path / c .MACROS ,
398430 ignore_patterns = self .config .ignore_patterns ,
@@ -417,16 +449,51 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
417449 macros_max_mtime = (
418450 max (macros_max_mtime , macro_file_mtime ) if macros_max_mtime else macro_file_mtime
419451 )
452+
420453 with open (path , "r" , encoding = "utf-8" ) as file :
421- jinja_macros .add_macros (
422- extractor .extract (file .read (), dialect = self .config .model_defaults .dialect )
423- )
454+ try :
455+ package : t .Optional [str ] = None
456+ if self .is_migrated_dbt_project :
457+ if path .is_relative_to (migrated_dbt_package_base_path ):
458+ package = str (
459+ path .relative_to (migrated_dbt_package_base_path ).parents [0 ]
460+ )
461+ else :
462+ package = self .migrated_dbt_project_name
463+
464+ jinja_macros .add_macros (
465+ extractor .extract (file .read (), dialect = self .config .model_defaults .dialect ),
466+ package = package ,
467+ )
468+ except :
469+ logger .error (f"Unable to read macro file: { path } " )
470+ raise
424471
425472 self ._macros_max_mtime = macros_max_mtime
426473
427474 macros = macro .get_registry ()
428475 macro .set_registry (standard_macros )
429476
477+ if self .is_migrated_dbt_project :
478+ from sqlmesh .dbt .target import TARGET_TYPE_TO_CONFIG_CLASS
479+
480+ connection_config = self .context ._connection_config
481+ # this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work
482+ if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS .get (connection_config .type_ ):
483+ try :
484+ jinja_macros .add_globals (
485+ {
486+ "target" : dbt_config_type .from_sqlmesh (
487+ self .context ._connection_config ,
488+ name = self .config .default_gateway_name ,
489+ ).attribute_dict ()
490+ }
491+ )
492+ except NotImplementedError :
493+ raise ConfigError (
494+ f"No DBT 'Target Type' mapping for connection type: { connection_config .type_ } "
495+ )
496+
430497 return macros , jinja_macros
431498
432499 def _load_models (
@@ -499,6 +566,7 @@ def _load() -> t.List[Model]:
499566 infer_names = self .config .model_naming .infer_names ,
500567 signal_definitions = signals ,
501568 default_catalog_per_gateway = self .context .default_catalog_per_gateway ,
569+ migrated_dbt_project_name = self .migrated_dbt_project_name ,
502570 )
503571 except Exception as ex :
504572 self ._raise_failed_to_load_model_error (path , ex )
0 commit comments