Skip to content

Commit f59c316

Browse files
committed
Feat(experimental): DBT project conversion
1 parent 8edfae4 commit f59c316

46 files changed

Lines changed: 3716 additions & 37 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

sqlmesh/cli/main.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"environments",
3131
"invalidate",
3232
"table_name",
33+
"dbt",
3334
)
3435
SKIP_CONTEXT_COMMANDS = ("init", "ui")
3536

@@ -1211,3 +1212,39 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
12111212
"""Import a state export file back into the state database"""
12121213
confirm = not no_confirm
12131214
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)
1215+
1216+
1217+
@cli.group(no_args_is_help=True, hidden=True)
1218+
def dbt() -> None:
1219+
"""Commands for doing dbt-specific things"""
1220+
pass
1221+
1222+
1223+
@dbt.command("convert")
1224+
@click.option(
1225+
"-i",
1226+
"--input-dir",
1227+
help="Path to the DBT project",
1228+
required=True,
1229+
type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path),
1230+
)
1231+
@click.option(
1232+
"-o",
1233+
"--output-dir",
1234+
required=True,
1235+
help="Path to write out the converted SQLMesh project",
1236+
type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path),
1237+
)
1238+
@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False)
1239+
@click.pass_obj
1240+
@error_handler
1241+
@cli_analytics
1242+
def dbt_convert(obj: Context, input_dir: Path, output_dir: Path, no_prompts: bool) -> None:
1243+
"""Convert a DBT project to a SQLMesh project"""
1244+
from sqlmesh.dbt.converter.convert import convert_project_files
1245+
1246+
convert_project_files(
1247+
input_dir.absolute(),
1248+
output_dir.absolute(),
1249+
no_prompts=no_prompts,
1250+
)

sqlmesh/core/config/root.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
scheduler_config_validator,
4040
)
4141
from sqlmesh.core.config.ui import UIConfig
42-
from sqlmesh.core.loader import Loader, SqlMeshLoader
42+
from sqlmesh.core.loader import Loader, SqlMeshLoader, MigratedDbtProjectLoader
4343
from sqlmesh.core.notification_target import NotificationTarget
4444
from sqlmesh.core.user import User
4545
from sqlmesh.utils.date import to_timestamp, now
@@ -219,6 +219,13 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
219219
f"^{k}$": v for k, v in physical_schema_override.items()
220220
}
221221

222+
if (
223+
(variables := data.get("variables", ""))
224+
and isinstance(variables, dict)
225+
and c.MIGRATED_DBT_PROJECT_NAME in variables
226+
):
227+
data["loader"] = MigratedDbtProjectLoader
228+
222229
return data
223230

224231
@model_validator(mode="after")

sqlmesh/core/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
MAX_MODEL_DEFINITION_SIZE = 10000
3232
"""Maximum number of characters in a model definition"""
3333

34+
MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__"
35+
MIGRATED_DBT_PACKAGES = "__dbt_packages__"
36+
3437

3538
# The maximum number of fork processes, used for loading projects
3639
# None means default to process pool, 1 means don't fork, :N is number of processes

sqlmesh/core/loader.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
3939
from sqlmesh.utils import UniqueKeyDict, sys_path
4040
from sqlmesh.utils.errors import ConfigError
41-
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
41+
from sqlmesh.utils.jinja import (
42+
JinjaMacroRegistry,
43+
MacroExtractor,
44+
SQLMESH_DBT_COMPATIBILITY_PACKAGE,
45+
)
4246
from sqlmesh.utils.metaprogramming import import_python_file
4347
from sqlmesh.utils.pydantic import validation_error_message
4448
from sqlmesh.utils.process import create_process_pool_executor
@@ -548,6 +552,7 @@ def _load_sql_models(
548552
signals: UniqueKeyDict[str, signal],
549553
cache: CacheBase,
550554
gateway: t.Optional[str],
555+
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
551556
) -> UniqueKeyDict[str, Model]:
552557
"""Loads the sql models into a Dict"""
553558
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -590,6 +595,7 @@ def _load_sql_models(
590595
infer_names=self.config.model_naming.infer_names,
591596
signal_definitions=signals,
592597
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
598+
**loading_default_kwargs or {},
593599
)
594600

595601
with create_process_pool_executor(
@@ -942,3 +948,104 @@ def _model_cache_entry_id(self, model_path: Path) -> str:
942948
self._loader.context.gateway or self._loader.config.default_gateway_name,
943949
]
944950
)
951+
952+
953+
class MigratedDbtProjectLoader(SqlMeshLoader):
954+
@property
955+
def migrated_dbt_project_name(self) -> str:
956+
return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME]
957+
958+
def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
959+
from sqlmesh.dbt.converter.common import infer_dbt_package_from_path
960+
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
961+
962+
# Store a copy of the macro registry
963+
standard_macros = macro.get_registry()
964+
965+
jinja_macros = JinjaMacroRegistry(
966+
create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE,
967+
top_level_packages=["dbt", self.migrated_dbt_project_name],
968+
)
969+
extractor = MacroExtractor()
970+
971+
macros_max_mtime: t.Optional[float] = None
972+
973+
for path in self._glob_paths(
974+
self.config_path / c.MACROS,
975+
ignore_patterns=self.config.ignore_patterns,
976+
extension=".py",
977+
):
978+
if import_python_file(path, self.config_path):
979+
self._track_file(path)
980+
macro_file_mtime = self._path_mtimes[path]
981+
macros_max_mtime = (
982+
max(macros_max_mtime, macro_file_mtime)
983+
if macros_max_mtime
984+
else macro_file_mtime
985+
)
986+
987+
for path in self._glob_paths(
988+
self.config_path / c.MACROS,
989+
ignore_patterns=self.config.ignore_patterns,
990+
extension=".sql",
991+
):
992+
self._track_file(path)
993+
macro_file_mtime = self._path_mtimes[path]
994+
macros_max_mtime = (
995+
max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime
996+
)
997+
998+
with open(path, "r", encoding="utf-8") as file:
999+
try:
1000+
package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name
1001+
1002+
jinja_macros.add_macros(
1003+
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect),
1004+
package=package,
1005+
)
1006+
except Exception as e:
1007+
raise ConfigError(f"Failed to load macro file: {path}", e)
1008+
1009+
self._macros_max_mtime = macros_max_mtime
1010+
1011+
macros = macro.get_registry()
1012+
macro.set_registry(standard_macros)
1013+
1014+
connection_config = self.context.connection_config
1015+
# this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work
1016+
if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_):
1017+
try:
1018+
jinja_macros.add_globals(
1019+
{
1020+
"target": dbt_config_type.from_sqlmesh(
1021+
connection_config,
1022+
name=self.config.default_gateway_name,
1023+
).attribute_dict()
1024+
}
1025+
)
1026+
except NotImplementedError:
1027+
raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}")
1028+
1029+
return macros, jinja_macros
1030+
1031+
def _load_sql_models(
1032+
self,
1033+
macros: MacroRegistry,
1034+
jinja_macros: JinjaMacroRegistry,
1035+
audits: UniqueKeyDict[str, ModelAudit],
1036+
signals: UniqueKeyDict[str, signal],
1037+
cache: CacheBase,
1038+
gateway: t.Optional[str],
1039+
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
1040+
) -> UniqueKeyDict[str, Model]:
1041+
return super()._load_sql_models(
1042+
macros=macros,
1043+
jinja_macros=jinja_macros,
1044+
audits=audits,
1045+
signals=signals,
1046+
cache=cache,
1047+
gateway=gateway,
1048+
loading_default_kwargs=dict(
1049+
migrated_dbt_project_name=self.migrated_dbt_project_name,
1050+
),
1051+
)

sqlmesh/core/model/definition.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,6 +2017,7 @@ def load_sql_based_model(
20172017
variables: t.Optional[t.Dict[str, t.Any]] = None,
20182018
infer_names: t.Optional[bool] = False,
20192019
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2020+
migrated_dbt_project_name: t.Optional[str] = None,
20202021
**kwargs: t.Any,
20212022
) -> Model:
20222023
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2190,6 +2191,7 @@ def load_sql_based_model(
21902191
query_or_seed_insert,
21912192
kind=kind,
21922193
time_column_format=time_column_format,
2194+
migrated_dbt_project_name=migrated_dbt_project_name,
21932195
**common_kwargs,
21942196
)
21952197

@@ -2397,6 +2399,7 @@ def _create_model(
23972399
signal_definitions: t.Optional[SignalRegistry] = None,
23982400
variables: t.Optional[t.Dict[str, t.Any]] = None,
23992401
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2402+
migrated_dbt_project_name: t.Optional[str] = None,
24002403
**kwargs: t.Any,
24012404
) -> Model:
24022405
validate_extra_and_required_fields(
@@ -2452,13 +2455,28 @@ def _create_model(
24522455

24532456
if jinja_macros:
24542457
jinja_macros = (
2455-
jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references)
2458+
jinja_macros
2459+
if jinja_macros.trimmed
2460+
else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name)
24562461
)
24572462
else:
24582463
jinja_macros = JinjaMacroRegistry()
24592464

2460-
for jinja_macro in jinja_macros.root_macros.values():
2461-
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
2465+
if migrated_dbt_project_name:
2466+
# extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific
2467+
# to a migrated DBT package and resolve them accordingly
2468+
# vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them
2469+
variables = variables or {}
2470+
2471+
nested_macro_used_variables, flattened_package_variables = (
2472+
_extract_migrated_dbt_variable_references(jinja_macros, variables)
2473+
)
2474+
2475+
used_variables.update(nested_macro_used_variables)
2476+
variables.update(flattened_package_variables)
2477+
else:
2478+
for jinja_macro in jinja_macros.root_macros.values():
2479+
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
24622480

24632481
model = klass(
24642482
name=name,
@@ -2841,7 +2859,7 @@ def render_expression(
28412859
"cron_tz": lambda value: exp.Literal.string(value),
28422860
"partitioned_by_": _single_expr_or_tuple,
28432861
"clustered_by": _single_expr_or_tuple,
2844-
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
2862+
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)) if value else "()",
28452863
"pre": _list_of_calls_to_exp,
28462864
"post": _list_of_calls_to_exp,
28472865
"audits": _list_of_calls_to_exp,
@@ -2912,4 +2930,37 @@ def clickhouse_partition_func(
29122930
)
29132931

29142932

2933+
def _extract_migrated_dbt_variable_references(
2934+
jinja_macros: JinjaMacroRegistry, project_variables: t.Dict[str, t.Any]
2935+
) -> t.Tuple[t.Set[str], t.Dict[str, t.Any]]:
2936+
if not jinja_macros.trimmed:
2937+
raise ValueError("Expecting a trimmed JinjaMacroRegistry")
2938+
2939+
used_variables = set()
2940+
# note: JinjaMacroRegistry is trimmed here so "all_macros" should just be all the macros used by this model
2941+
for _, _, jinja_macro in jinja_macros.all_macros:
2942+
_, extracted_variable_names = extract_macro_references_and_variables(jinja_macro.definition)
2943+
used_variables.update(extracted_variable_names)
2944+
2945+
flattened = {}
2946+
if (dbt_package_variables := project_variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance(
2947+
dbt_package_variables, dict
2948+
):
2949+
# flatten the nested dict structure from the migrated dbt package variables in the SQLMesh config into __dbt_packages.<package>.<variable>
2950+
# to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work
2951+
def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
2952+
acc = {}
2953+
for k, v in root.items():
2954+
key_with_prefix = f"{prefix}.{k}"
2955+
if isinstance(v, dict):
2956+
acc.update(_flatten(key_with_prefix, v))
2957+
else:
2958+
acc[key_with_prefix] = v
2959+
return acc
2960+
2961+
flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables)
2962+
2963+
return used_variables, flattened
2964+
2965+
29152966
TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func}

sqlmesh/core/model/kind.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,18 @@ def _merge_filter_validator(
491491

492492
return v.transform(d.replace_merge_table_aliases)
493493

494+
@field_validator("batch_concurrency", mode="before")
495+
def _batch_concurrency_validator(
496+
cls, v: t.Optional[exp.Expression], info: ValidationInfo
497+
) -> int:
498+
if isinstance(v, exp.Literal):
499+
return int(
500+
v.to_py()
501+
) # allow batch_concurrency = 1 to be specified in a Model definition without throwing a Pydantic error
502+
if isinstance(v, int):
503+
return v
504+
return 1
505+
494506
@property
495507
def data_hash_values(self) -> t.List[t.Optional[str]]:
496508
return [

sqlmesh/core/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def _resolve_table(table: str | exp.Table) -> str:
179179
)
180180

181181
render_kwargs = {
182+
"dialect": self._dialect,
182183
**date_dict(
183184
to_datetime(execution_time or c.EPOCH),
184185
start_time,

sqlmesh/dbt/adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __init__(
3838
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
3939
self.jinja_globals["adapter"] = self
4040
self.project_dialect = project_dialect
41+
self.jinja_globals["dialect"] = (
42+
project_dialect # so the dialect is available in the jinja env created by self.dispatch()
43+
)
4144
self.quote_policy = quote_policy or Policy()
4245

4346
@abc.abstractmethod

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class Var:
156156
def __init__(self, variables: t.Dict[str, t.Any]) -> None:
157157
self.variables = variables
158158

159-
def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
159+
def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any:
160160
return self.variables.get(name, default)
161161

162162
def has_var(self, name: str) -> bool:

sqlmesh/dbt/converter/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)