diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 51eb0c3432..b83daf1724 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -33,6 +33,7 @@ "rollback", "run", "table_name", + "dbt", ) SKIP_CONTEXT_COMMANDS = ("init", "ui") @@ -1219,3 +1220,39 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool """Import a state export file back into the state database""" confirm = not no_confirm obj.import_state(input_file=input_file, clear=replace, confirm=confirm) + + +@cli.group(no_args_is_help=True, hidden=True) +def dbt() -> None: + """Commands for doing dbt-specific things""" + pass + + +@dbt.command("convert") +@click.option( + "-i", + "--input-dir", + help="Path to the DBT project", + required=True, + type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path), +) +@click.option( + "-o", + "--output-dir", + required=True, + help="Path to write out the converted SQLMesh project", + type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path), +) +@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False) +@click.pass_obj +@error_handler +@cli_analytics +def dbt_convert(obj: Context, input_dir: Path, output_dir: Path, no_prompts: bool) -> None: + """Convert a DBT project to a SQLMesh project""" + from sqlmesh.dbt.converter.convert import convert_project_files + + convert_project_files( + input_dir.absolute(), + output_dir.absolute(), + no_prompts=no_prompts, + ) diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index 0f132680a8..1d53235f73 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -39,7 +39,7 @@ scheduler_config_validator, ) from sqlmesh.core.config.ui import UIConfig -from sqlmesh.core.loader import Loader, SqlMeshLoader +from sqlmesh.core.loader import Loader, SqlMeshLoader, MigratedDbtProjectLoader from sqlmesh.core.notification_target import NotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.date import to_timestamp, now @@ -219,6 +219,13 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any: f"^{k}$": v for k, v in physical_schema_override.items() } + if ( + (variables := data.get("variables", "")) + and isinstance(variables, dict) + and c.MIGRATED_DBT_PROJECT_NAME in variables + ): + data["loader"] = MigratedDbtProjectLoader + return data @model_validator(mode="after") diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index 60c6a3eedf..2ab592f368 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -31,6 +31,9 @@ MAX_MODEL_DEFINITION_SIZE = 10000 """Maximum number of characters in a model definition""" +MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__" +MIGRATED_DBT_PACKAGES = "__dbt_packages__" + # The maximum number of fork processes, used for loading projects # None means default to process pool, 1 means don't fork, :N is number of processes diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index e7df315768..7f90c0de63 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -38,7 +38,11 @@ from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + MacroExtractor, + SQLMESH_DBT_COMPATIBILITY_PACKAGE, +) from sqlmesh.utils.metaprogramming import 
import_python_file from sqlmesh.utils.pydantic import validation_error_message from sqlmesh.utils.process import create_process_pool_executor @@ -548,6 +552,7 @@ def _load_sql_models( signals: UniqueKeyDict[str, signal], cache: CacheBase, gateway: t.Optional[str], + loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, ) -> UniqueKeyDict[str, Model]: """Loads the sql models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -590,6 +595,7 @@ def _load_sql_models( infer_names=self.config.model_naming.infer_names, signal_definitions=signals, default_catalog_per_gateway=self.context.default_catalog_per_gateway, + **loading_default_kwargs or {}, ) with create_process_pool_executor( @@ -942,3 +948,104 @@ def _model_cache_entry_id(self, model_path: Path) -> str: self._loader.context.gateway or self._loader.config.default_gateway_name, ] ) + + +class MigratedDbtProjectLoader(SqlMeshLoader): + @property + def migrated_dbt_project_name(self) -> str: + return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME] + + def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: + from sqlmesh.dbt.converter.common import infer_dbt_package_from_path + from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS + + # Store a copy of the macro registry + standard_macros = macro.get_registry() + + jinja_macros = JinjaMacroRegistry( + create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE, + top_level_packages=["dbt", self.migrated_dbt_project_name], + ) + extractor = MacroExtractor() + + macros_max_mtime: t.Optional[float] = None + + for path in self._glob_paths( + self.config_path / c.MACROS, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if import_python_file(path, self.config_path): + self._track_file(path) + macro_file_mtime = self._path_mtimes[path] + macros_max_mtime = ( + max(macros_max_mtime, macro_file_mtime) + if macros_max_mtime + else macro_file_mtime + ) + + for path in self._glob_paths( + self.config_path / c.MACROS, + ignore_patterns=self.config.ignore_patterns, + extension=".sql", + ): + self._track_file(path) + macro_file_mtime = self._path_mtimes[path] + macros_max_mtime = ( + max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime + ) + + with open(path, "r", encoding="utf-8") as file: + try: + package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name + + jinja_macros.add_macros( + extractor.extract(file.read(), dialect=self.config.model_defaults.dialect), + package=package, + ) + except Exception as e: + raise ConfigError(f"Failed to load macro file: {path}", e) + + self._macros_max_mtime = macros_max_mtime + + macros = macro.get_registry() + macro.set_registry(standard_macros) + + connection_config = self.context.connection_config + # this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work + if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_): + try: + jinja_macros.add_globals( + { + "target": dbt_config_type.from_sqlmesh( + connection_config, + name=self.config.default_gateway_name, + ).attribute_dict() + } + ) + except NotImplementedError: + raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}") + + return macros, jinja_macros + + def _load_sql_models( + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], + cache: CacheBase, + gateway: t.Optional[str], + 
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + ) -> UniqueKeyDict[str, Model]: + return super()._load_sql_models( + macros=macros, + jinja_macros=jinja_macros, + audits=audits, + signals=signals, + cache=cache, + gateway=gateway, + loading_default_kwargs=dict( + migrated_dbt_project_name=self.migrated_dbt_project_name, + ), + ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index db61b09c8e..f42a3ebfdc 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2017,6 +2017,7 @@ def load_sql_based_model( variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + migrated_dbt_project_name: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: """Load a model from a parsed SQLMesh model SQL file. @@ -2193,6 +2194,7 @@ def load_sql_based_model( query_or_seed_insert, kind=kind, time_column_format=time_column_format, + migrated_dbt_project_name=migrated_dbt_project_name, **common_kwargs, ) @@ -2400,6 +2402,7 @@ def _create_model( signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + migrated_dbt_project_name: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: validate_extra_and_required_fields( @@ -2455,13 +2458,28 @@ def _create_model( if jinja_macros: jinja_macros = ( - jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references) + jinja_macros + if jinja_macros.trimmed + else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name) ) else: jinja_macros = JinjaMacroRegistry() - for jinja_macro in jinja_macros.root_macros.values(): - used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + if migrated_dbt_project_name: + # extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific + # to a migrated DBT package and resolve them accordingly + # vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them + variables = variables or {} + + nested_macro_used_variables, flattened_package_variables = ( + _extract_migrated_dbt_variable_references(jinja_macros, variables) + ) + + used_variables.update(nested_macro_used_variables) + variables.update(flattened_package_variables) + else: + for jinja_macro in jinja_macros.root_macros.values(): + used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) model = klass( name=name, @@ -2844,7 +2862,7 @@ def render_expression( "cron_tz": lambda value: exp.Literal.string(value), "partitioned_by_": _single_expr_or_tuple, "clustered_by": _single_expr_or_tuple, - "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)), + "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)) if value else "()", "pre": _list_of_calls_to_exp, "post": _list_of_calls_to_exp, "audits": _list_of_calls_to_exp, @@ -2915,4 +2933,37 @@ def clickhouse_partition_func( ) +def _extract_migrated_dbt_variable_references( + jinja_macros: JinjaMacroRegistry, project_variables: t.Dict[str, t.Any] +) -> t.Tuple[t.Set[str], t.Dict[str, t.Any]]: + if not jinja_macros.trimmed: + raise ValueError("Expecting a trimmed JinjaMacroRegistry") + + used_variables = set() + # note: JinjaMacroRegistry is trimmed here so "all_macros" should be just be all the macros used by this model + for _, _, jinja_macro in 
jinja_macros.all_macros: + _, extracted_variable_names = extract_macro_references_and_variables(jinja_macro.definition) + used_variables.update(extracted_variable_names) + + flattened = {} + if (dbt_package_variables := project_variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance( + dbt_package_variables, dict + ): + # flatten the nested dict structure from the migrated dbt package variables in the SQLMesh config into __dbt_packages__.<package>.<variable> keys + # to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work + def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + acc = {} + for k, v in root.items(): + key_with_prefix = f"{prefix}.{k}" + if isinstance(v, dict): + acc.update(_flatten(key_with_prefix, v)) + else: + acc[key_with_prefix] = v + return acc + + flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables) + + return used_variables, flattened + + TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func} diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index f58127dcdf..86eb6e665c 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -4,7 +4,7 @@ from enum import Enum from typing_extensions import Self -from pydantic import Field +from pydantic import Field, BeforeValidator from sqlglot import exp from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -33,6 +33,7 @@ field_validator, get_dialect, validate_string, + positive_int_validator, ) @@ -455,7 +456,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy): unique_key: SQLGlotListOfFields when_matched: t.Optional[exp.Whens] = None merge_filter: t.Optional[exp.Expression] = None - batch_concurrency: t.Literal[1] = 1 + batch_concurrency: t.Annotated[t.Literal[1], BeforeValidator(positive_int_validator)] = 1 @field_validator("when_matched", mode="before") def _when_matched_validator( diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index c683fc5862..6622094da3 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -179,6 +179,7 @@ def _resolve_table(table: str | exp.Table) -> str: ) render_kwargs = { + "dialect": self._dialect, **date_dict( to_datetime(execution_time or c.EPOCH), start_time, diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index cfff977a96..92719abacc 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -38,6 +38,9 @@ def __init__( self.jinja_globals = jinja_globals.copy() if jinja_globals else {} self.jinja_globals["adapter"] = self self.project_dialect = project_dialect + self.jinja_globals["dialect"] = ( + project_dialect # so the dialect is available in the jinja env created by self.dispatch() + ) self.quote_policy = quote_policy or Policy() @abc.abstractmethod diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index e07e00c961..4646011d57 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -156,7 +156,7 @@ class Var: def __init__(self, variables: t.Dict[str, t.Any]) -> None: self.variables = variables - def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: + def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any: return self.variables.get(name, default) def has_var(self, name: str) -> bool: diff --git a/sqlmesh/dbt/converter/__init__.py b/sqlmesh/dbt/converter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/sqlmesh/dbt/converter/common.py b/sqlmesh/dbt/converter/common.py new file mode 100644 index 0000000000..2bf4131065 --- /dev/null +++ b/sqlmesh/dbt/converter/common.py @@ -0,0 +1,40 @@ +from __future__ import annotations +import jinja2.nodes as j +from sqlglot import exp +import typing as t +import sqlmesh.core.constants as c +from pathlib import Path + + +# jinja transform is a function that takes (current node, previous node, parent node) and returns a new Node or None +# returning None means the current node is removed from the tree +# returning a different Node means the current node is replaced with the new Node +JinjaTransform = t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node]], t.Optional[j.Node]] +SQLGlotTransform = t.Callable[[exp.Expression], t.Optional[exp.Expression]] + + +def _sqlmesh_predefined_macro_variables() -> t.Set[str]: + def _gen() -> t.Iterable[str]: + for suffix in ("dt", "date", "ds", "ts", "tstz", "hour", "epoch", "millis"): + for prefix in ("start", "end", "execution"): + yield f"{prefix}_{suffix}" + + for item in ("runtime_stage", "gateway", "this_model", "this_env", "model_kind_name"): + yield item + + return set(_gen()) + + +SQLMESH_PREDEFINED_MACRO_VARIABLES = _sqlmesh_predefined_macro_variables() + + +def infer_dbt_package_from_path(path: Path) -> t.Optional[str]: + """ + Given a path like "sqlmesh-project/macros/__dbt_packages__/foo/bar.sql" + + Infer that 'foo' is the DBT package + """ + if c.MIGRATED_DBT_PACKAGES in path.parts: + idx = path.parts.index(c.MIGRATED_DBT_PACKAGES) + return path.parts[idx + 1] + return None diff --git a/sqlmesh/dbt/converter/console.py b/sqlmesh/dbt/converter/console.py new file mode 100644 index 0000000000..3fb12bcbc5 --- /dev/null +++ b/sqlmesh/dbt/converter/console.py @@ -0,0 +1,117 @@ +from __future__ import annotations +import typing as t +from pathlib import Path +from rich.console import Console as RichConsole +from rich.tree import Tree +from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn +from sqlmesh.core.console import PROGRESS_BAR_WIDTH +from sqlmesh.utils import columns_to_types_all_known +from sqlmesh.utils import rich as srich +import logging +from rich.prompt import Confirm + +logger = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + from sqlmesh.dbt.converter.convert import ConversionReport + + +def make_progress_bar( + console: t.Optional[RichConsole] = None, + justify: t.Literal["default", "left", "center", "right", "full"] = "right", +) -> Progress: + return Progress( + TextColumn("[bold blue]{task.description}", justify=justify), + BarColumn(bar_width=PROGRESS_BAR_WIDTH), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + console=console, + ) + + +class DbtConversionConsole: + """Console for displaying DBT project conversion progress""" + + def __init__(self, console: t.Optional[RichConsole] = None) -> None: + self.console: RichConsole = console or srich.console + + def log_message(self, message: str) -> None: + self.console.print(message) + + def start_project_conversion(self, input_path: Path) -> None: + self.log_message(f"DBT project loaded from {input_path}; starting conversion") + + def prompt_clear_directory(self, prefix: str, path: Path) -> bool: + return Confirm.ask( + f"{prefix}'{path}' is not empty.\nWould you like to clear it?", console=self.console + ) + + # Models + def start_models_conversion(self, model_count: int) -> None: + self.progress_bar = 
make_progress_bar(justify="left", console=self.console) + self.progress_bar.start() + self.models_progress_task_id = self.progress_bar.add_task( + "Converting models", total=model_count + ) + + def start_model_conversion(self, model_name: str) -> None: + logger.debug(f"Converting model {model_name}") + self.progress_bar.update(self.models_progress_task_id, description=None, refresh=True) + + def complete_model_conversion(self) -> None: + self.progress_bar.update(self.models_progress_task_id, refresh=True, advance=1) + + def complete_models_conversion(self) -> None: + self.progress_bar.update(self.models_progress_task_id, description=None, refresh=True) + + # Audits + + def start_audits_conversion(self, audit_count: int) -> None: + self.audits_progress_task_id = self.progress_bar.add_task( + "Converting audits", total=audit_count + ) + + def start_audit_conversion(self, audit_name: str) -> None: + self.progress_bar.update(self.audits_progress_task_id, description=None, refresh=True) + + def complete_audit_conversion(self) -> None: + self.progress_bar.update(self.audits_progress_task_id, refresh=True, advance=1) + + def complete_audits_conversion(self) -> None: + self.progress_bar.update(self.audits_progress_task_id, description=None, refresh=True) + + # Macros + + def start_macros_conversion(self, macro_count: int) -> None: + self.macros_progress_task_id = self.progress_bar.add_task( + "Converting macros", total=macro_count + ) + + def start_macro_conversion(self, macro_name: str) -> None: + self.progress_bar.update(self.macros_progress_task_id, description=None, refresh=True) + + def complete_macro_conversion(self) -> None: + self.progress_bar.update(self.macros_progress_task_id, refresh=True, advance=1) + + def complete_macros_conversion(self) -> None: + self.progress_bar.update(self.macros_progress_task_id, description=None, refresh=True) + self.progress_bar.stop() + + def output_report(self, report: ConversionReport) -> None: + tree = Tree( + "[blue]The following models are self-referencing and their column types could not be statically inferred:" + ) + + for output_path, model in report.self_referencing_models: + if not model.columns_to_types or not columns_to_types_all_known(model.columns_to_types): + tree_node = tree.add(f"[green]{model.name}") + tree_node.add(output_path.as_posix()) + + self.console.print(tree) + + self.log_message( + "[red]These will need to be manually fixed.[/red]\nEither specify the column types in the MODEL block or ensure the outer SELECT lists all columns" + ) diff --git a/sqlmesh/dbt/converter/convert.py b/sqlmesh/dbt/converter/convert.py new file mode 100644 index 0000000000..f097a83884 --- /dev/null +++ b/sqlmesh/dbt/converter/convert.py @@ -0,0 +1,414 @@ +import typing as t +from pathlib import Path +import shutil +import os + +from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader, DbtContext, Project +from sqlmesh.core.context import Context +import sqlmesh.core.dialect as d +from sqlmesh.core import constants as c + +from sqlmesh.core.model.kind import SeedKind +from sqlmesh.core.model import SqlModel, SeedModel +from sqlmesh.dbt.converter.jinja import convert_jinja_query, convert_jinja_macro +from sqlmesh.dbt.converter.common import infer_dbt_package_from_path +import dataclasses +from dataclasses import dataclass + +from sqlmesh.dbt.converter.console import DbtConversionConsole +from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables +from sqlmesh.utils import yaml + + +@dataclass +class ConversionReport: + 
self_referencing_models: t.List[t.Tuple[Path, SqlModel]] = dataclasses.field( + default_factory=list + ) + + +@dataclass +class InputPaths: + # todo: read paths from DBT project yaml + + base: Path + + @property + def models(self) -> Path: + return self.base / "models" + + @property + def seeds(self) -> Path: + return self.base / "seeds" + + @property + def tests(self) -> Path: + return self.base / "tests" + + @property + def macros(self) -> Path: + return self.base / "macros" + + @property + def snapshots(self) -> Path: + return self.base / "snapshots" + + @property + def packages(self) -> Path: + return self.base / "dbt_packages" + + +@dataclass +class OutputPaths: + base: Path + + @property + def models(self) -> Path: + return self.base / "models" + + @property + def seeds(self) -> Path: + return self.base / "seeds" + + @property + def audits(self) -> Path: + return self.base / "audits" + + @property + def macros(self) -> Path: + return self.base / "macros" + + +def convert_project_files(src: Path, dest: Path, no_prompts: bool = True) -> None: + console = DbtConversionConsole() + report = ConversionReport() + + console.log_message(f"Converting project at '{src}' to '{dest}'") + + ctx, dbt_project = _load_project(src) + dbt_load_context = dbt_project.context + + console.start_project_conversion(src) + + input_paths, output_paths = _ensure_paths(src, dest, console, no_prompts) + + model_count = len(ctx.models) + + # DBT Models -> SQLMesh Models + console.start_models_conversion(model_count) + _convert_models(ctx, input_paths, output_paths, report, console) + console.complete_models_conversion() + + # DBT Tests -> Standalone Audits + console.start_audits_conversion(len(ctx.standalone_audits)) + _convert_standalone_audits(ctx, input_paths, output_paths, console) + console.complete_audits_conversion() + + # DBT Macros -> SQLMesh Jinja Macros + all_macros = list( + iterate_macros(input_paths.macros, output_paths.macros, dbt_load_context, ctx) + ) + console.start_macros_conversion(len(all_macros)) + for package, macro_text, input_id, output_file_path, should_transform in all_macros: + console.start_macro_conversion(input_id) + + output_file_path.parent.mkdir(parents=True, exist_ok=True) + converted = ( + convert_jinja_macro(ctx, macro_text, package) if should_transform else macro_text + ) + output_file_path.write_text(converted, encoding="utf8") + + console.complete_macro_conversion() + + console.complete_macros_conversion() + + # Generate SQLMesh config + # TODO: read all profiles from config and convert to gateways instead of just the current profile? 
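+ # A minimal sketch of the config.yml emitted below, with hypothetical names (a project 'my_project' pulling in the 'dbt_utils' package); only the gateways, model_defaults and variables sections are written: + # gateways: {...} + # model_defaults: {dialect: duckdb} + # variables: + # __dbt_project_name__: my_project + # __dbt_packages__: + # dbt_utils: {...}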
+ console.log_message("Writing SQLMesh config") + new_config = _generate_sqlmesh_config(ctx, dbt_project, dbt_load_context) + (dest / "config.yml").write_text(yaml.dump(new_config)) + + if report.self_referencing_models: + console.output_report(report) + + console.log_message("All done") + + +def _load_project(src: Path) -> t.Tuple[Context, Project]: + config = sqlmesh_config(project_root=src) + + ctx = Context(config=config, paths=src) + + dbt_loader = ctx._loaders[0] + assert isinstance(dbt_loader, DbtLoader) + + dbt_project = dbt_loader._projects[0] + + return ctx, dbt_project + + +def _ensure_paths( + src: Path, dest: Path, console: DbtConversionConsole, no_prompts: bool +) -> t.Tuple[InputPaths, OutputPaths]: + if not dest.exists(): + console.log_message(f"Creating output directory: {dest}") + dest.mkdir() + + if dest.is_file(): + raise ValueError(f"Output path must be a directory") + + if any(dest.iterdir()): + if not no_prompts and console.prompt_clear_directory("Output directory ", dest): + for path in dest.glob("**/*"): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(path) + console.log_message(f"Output directory '{dest}' cleared") + else: + raise ValueError("Please ensure the output directory is empty") + + input_paths = InputPaths(src) + output_paths = OutputPaths(dest) + + for dir in (output_paths.models, output_paths.seeds, output_paths.audits, output_paths.macros): + dir.mkdir() + + return input_paths, output_paths + + +def _convert_models( + ctx: Context, + input_paths: InputPaths, + output_paths: OutputPaths, + report: ConversionReport, + console: DbtConversionConsole, +) -> None: + # Iterating in DAG order helps minimize re-rendering when the fingerprint cache is busted when we call upsert_model() to check if + # a self-referencing model has all its columns_to_types known or not + for fqn in ctx.dag: + model = ctx.models.get(fqn) + + if not model: + # some entries in the dag are not models + continue + + model_name = fqn + + # todo: support DBT model_paths[] being not `models` or being a list + # todo: write out column_descriptions() into model block + console.start_model_conversion(model_name) + + if model.kind.is_external: + # skip external models + # they can be created with `sqlmesh create_external_models` post-conversion + console.complete_model_conversion() # still advance the progress bar + continue + + if model.kind.is_seed: + # this will produce the original seed file, eg "items.csv" + seed_filename = model._path.relative_to(input_paths.seeds) + + # seed definition - rename "items.csv" -> "items.sql" + model_filename = seed_filename.with_suffix(".sql") + + # copy the seed data itself to the seeds dir + shutil.copyfile(model._path, output_paths.seeds / seed_filename) + + # monkeypatch the model kind to have a relative reference to the seed file + assert isinstance(model.kind, SeedKind) + model.kind.path = str(Path("../seeds", seed_filename)) + else: + if input_paths.models in model._path.parents: + model_filename = model._path.relative_to(input_paths.models) + elif input_paths.snapshots in model._path.parents: + # /base/path/snapshots/foo.sql -> /output/path/models/dbt_snapshots/foo.sql + model_filename = "dbt_snapshots" / model._path.relative_to(input_paths.snapshots) + elif input_paths.packages in model._path.parents: + model_filename = c.MIGRATED_DBT_PACKAGES / model._path.relative_to( + input_paths.packages + ) + else: + raise ValueError(f"Unhandled model path: {model._path}") + + # todo: a SQLGLot transform on `audits` in the model 
definition to lowercase the names? + model_output_path = output_paths.models / model_filename + model_output_path.parent.mkdir(parents=True, exist_ok=True) + model_package = infer_dbt_package_from_path(model_output_path) + + def _render(e: d.exp.Expression) -> str: + if isinstance(e, d.Jinja): + e = convert_jinja_query(ctx, model, e, model_package) + rendered = e.sql(dialect=model.dialect, pretty=True) + if not isinstance(e, d.Jinja): + rendered += ";" + return rendered + + model_to_render = model.model_copy( + update=dict(depends_on_=None if len(model.depends_on) > 0 else set()) + ) + if isinstance(model, (SqlModel, SeedModel)): + # Keep depends_on for SQL Models because sometimes the entire query is a macro call. + # If we clear it and rely on inference, the SQLMesh native loader will throw: + # - ConfigError: Dependencies must be provided explicitly for models that can be rendered only at runtime + model_to_render = model.model_copy( + update=dict(depends_on_=resolve_fqns_to_model_names(ctx, model.depends_on)) + ) + + rendered_queries = [ + _render(q) + for q in model_to_render.render_definition(render_query=False, include_python=False) + ] + + # add inline audits + # todo: handle these better + # maybe output generic audits for the 4 DBT audits (not_null, unique, accepted_values, relationships) and emit definitions for them? + for _, audit in model.audit_definitions.items(): + rendered_queries.append("\n" + _render(d.parse_one(f"AUDIT (name {audit.name})"))) + # todo: or do we want the original? + rendered_queries.append(_render(model.render_audit_query(audit))) + + model_definition = "\n".join(rendered_queries) + + model_output_path.write_text(model_definition) + + console.complete_model_conversion() + + +def _convert_standalone_audits( + ctx: Context, input_paths: InputPaths, output_paths: OutputPaths, console: DbtConversionConsole +) -> None: + for _, audit in ctx.standalone_audits.items(): + console.start_audit_conversion(audit.name) + audit_definition = audit.render_definition(include_python=False) + + stringified = [] + for expression in audit_definition: + if isinstance(expression, d.JinjaQuery): + expression = convert_jinja_query(ctx, audit, expression) + stringified.append(expression.sql(dialect=audit.dialect, pretty=True)) + + audit_definition_string = ";\n".join(stringified) + + audit_filename = audit._path.relative_to(input_paths.tests) + audit_output_path = output_paths.audits / audit_filename + audit_output_path.write_text(audit_definition_string) + console.complete_audit_conversion() + return None + + +def _generate_sqlmesh_config( + ctx: Context, dbt_project: Project, dbt_load_context: DbtContext +) -> t.Dict[str, t.Any]: + DEFAULT_ARGS: t.Dict[str, t.Any] + from sqlmesh.utils.pydantic import DEFAULT_ARGS + + base_config = ctx.config.model_dump( + mode="json", include={"gateways", "model_defaults", "variables"}, **DEFAULT_ARGS + ) + # Extend with the variables loaded from DBT + if "variables" not in base_config: + base_config["variables"] = {} + if c.MIGRATED_DBT_PACKAGES not in base_config["variables"]: + base_config["variables"][c.MIGRATED_DBT_PACKAGES] = {} + + # this is used when loading with the native loader to set the package name for top level macros + base_config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = dbt_project.context.project_name + + migrated_package_names = [] + for package in dbt_project.packages.values(): + dbt_load_context.set_and_render_variables(package.variables, package.name) + + if package.name == dbt_project.context.project_name: + 
base_config["variables"].update(dbt_load_context.variables) + else: + base_config["variables"][c.MIGRATED_DBT_PACKAGES][package.name] = ( + dbt_load_context.variables + ) + migrated_package_names.append(package.name) + + for package_name in migrated_package_names: + # these entries are duplicates because the DBT loader already applies any project-specific overrides to the + # package-level variables + base_config["variables"].pop(package_name, None) + + return base_config + + +def iterate_macros( + input_macros_dir: Path, output_macros_dir: Path, dbt_load_context: DbtContext, ctx: Context +) -> t.Iterator[t.Tuple[t.Optional[str], str, str, Path, bool]]: + """ + Return an iterator over all the macros that need to be migrated + + The main project-level ones are read from the source macros directory (it's assumed these are written by the user) + + The rest (library-level ones) are read from the DBT manifest by merging together all the model JinjaMacroRegistry objects from the SQLMesh context + """ + + all_macro_references = set() + + for dirpath, _, files in os.walk( + input_macros_dir + ): # note: pathlib doesn't have a walk function until Python 3.12 + for name in files: + if name.lower().endswith(".sql"): + input_file_path = Path(dirpath) / name + + output_file_path = output_macros_dir / ( + input_file_path.relative_to(input_macros_dir) + ) + + input_file_contents = input_file_path.read_text(encoding="utf8") + + # as we migrate user-defined macros, keep track of other macros they reference from other packages/libraries + # so we can be sure they're included + # (since there is no guarantee a model references a user-defined macro, which means the dependencies may not be pulled in automatically) + macro_refs, _ = extract_macro_references_and_variables( + input_file_contents, dbt_target_name=dbt_load_context.target_name + ) + all_macro_references.update(macro_refs) + + yield ( + None, + input_file_contents, + str(input_file_path), + output_file_path, + True, + ) + + jmr = JinjaMacroRegistry() + for model in ctx.models.values(): + jmr = jmr.merge(model.jinja_macros) + + # add any macros that are referenced in user macros but not necessarily directly in models + # this can happen if a user has defined a macro that is currently unused in a model but we still want to migrate it + jmr = jmr.merge( + dbt_load_context.jinja_macros.trim( + all_macro_references, package=dbt_load_context.project_name + ) + ) + + for package, name, macro in jmr.all_macros: + if package and package != dbt_load_context.project_name: + output_file_path = output_macros_dir / c.MIGRATED_DBT_PACKAGES / package / f"{name}.sql" + + yield ( + package, + macro.definition, + f"{package}.{name}", + output_file_path, + "var(" in macro.definition, # todo: check for ref() etc as well? + ) + + +def resolve_fqns_to_model_names(ctx: Context, fqns: t.Set[str]) -> t.Set[str]: + # model.depends_on is provided by the DbtLoader as a list of fully qualified table name strings + # if we output them verbatim, when loading them back we get errors like: + # - ConfigError: Failed to load model definition: 'Dot' object has no attribute 'catalog' + # So we need to resolve them to model names instead. 
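+ # e.g. a depends_on entry like '"memory"."sqlmesh_example"."orders"' (hypothetical names) resolves to the model name 'sqlmesh_example.orders'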
+ # External models also need to be excluded because the "name" is still a FQN string so cause the above error + + return { + ctx.models[i].name for i in fqns if i in ctx.models and not ctx.models[i].kind.is_external + } diff --git a/sqlmesh/dbt/converter/jinja.py b/sqlmesh/dbt/converter/jinja.py new file mode 100644 index 0000000000..783ae5a74f --- /dev/null +++ b/sqlmesh/dbt/converter/jinja.py @@ -0,0 +1,604 @@ +import typing as t +import jinja2.nodes as j +import sqlmesh.core.dialect as d +from sqlmesh.core.context import Context +from sqlmesh.core.snapshot import Node +from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlglot import exp +from sqlmesh.dbt.converter.common import JinjaTransform +from inspect import signature +from more_itertools import windowed +from itertools import chain +from sqlmesh.dbt.context import DbtContext +import sqlmesh.dbt.converter.jinja_transforms as jt +from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.jinja import SQLMESH_DBT_COMPATIBILITY_PACKAGE + +# for j.Operand.op +OPERATOR_MAP = { + "eq": "==", + "ne": "!=", + "lt": "<", + "gt": ">", + "lteq": "<=", + "gteq": ">=", + "in": "in", + "notin": "not in", +} + + +def lpad_windowed(iterable: t.Iterable[j.Node]) -> t.Iterator[t.Tuple[t.Optional[j.Node], j.Node]]: + for prev, curr in windowed(chain([None], iterable), 2): + if curr is None: + raise ValueError("Current item cannot be None") + yield prev, curr + + +class JinjaGenerator: + def generate( + self, node: j.Node, prev: t.Optional[j.Node] = None, parent: t.Optional[j.Node] = None + ) -> str: + if not isinstance(node, j.Node): + raise ValueError(f"Generator only works with Jinja AST nodes, not: {type(node)}") + + acc = "" + + node_type = type(node) + generator_fn_name = f"_generate_{node_type.__name__.lower()}" + + if generator_fn := getattr(self, generator_fn_name, None): + sig = signature(generator_fn) + kwargs: t.Dict[str, t.Optional[j.Node]] = {"node": node} + if "prev" in sig.parameters: + kwargs["prev"] = prev + if "parent" in sig.parameters: + kwargs["parent"] = parent + acc += generator_fn(**kwargs) + else: + raise NotImplementedError(f"Generator for node type '{type(node)}' is not implemented") + + return acc + + def _generate_template(self, node: j.Template) -> str: + acc = [] + for prev, curr in lpad_windowed(node.body): + if curr: + acc.append(self.generate(curr, prev, node)) + + return "".join(acc) + + def _generate_output(self, node: j.Output) -> str: + acc = [] + for prev, curr in lpad_windowed(node.nodes): + acc.append(self.generate(curr, prev, node)) + + return "".join(acc) + + def _generate_templatedata(self, node: j.TemplateData) -> str: + return node.data + + def _generate_name( + self, node: j.Name, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + return self._wrap_in_expression_if_necessary(node.name, prev, parent) + + def _generate_getitem( + self, node: j.Getitem, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + item_name = self.generate(node.node, parent=node) + if node.arg: + if node.node.find(j.Filter): + # for when someone has {{ (foo | bar | baz)[0] }} + item_name = f"({item_name})" + item_name = f"{item_name}[{self.generate(node.arg, parent=node)}]" + + return self._wrap_in_expression_if_necessary(item_name, prev, parent) + + def _generate_getattr( + self, node: j.Getattr, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + what_str = self.generate(node.node, parent=node) + + return 
self._wrap_in_expression_if_necessary(f"{what_str}.{node.attr}", prev, parent) + + def _generate_const( + self, node: j.Const, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + quotechar = "" + node_value: str + if isinstance(node.value, str): + quotechar = "'" if "'" not in node.value else '"' + node_value = node.value + else: + node_value = str(node.value) + + const_value = quotechar + node_value + quotechar + + return self._wrap_in_expression_if_necessary(const_value, prev, parent) + + def _generate_keyword(self, node: j.Keyword) -> str: + return node.key + "=" + self.generate(node.value, parent=node) + + def _generate_test(self, node: j.Test, parent: t.Optional[j.Node]) -> str: + var_name = self.generate(node.node, parent=node) + test = "is" if not isinstance(parent, j.Not) else "is not" + if node.name: + return f"{var_name} {test} {node.name}" + return var_name + + def _generate_assign(self, node: j.Assign) -> str: + target_str = self.generate(node.target, parent=node) + what_str = self.generate(node.node, parent=node) + return "{% set " + target_str + " = " + what_str + " %}" + + def _generate_assignblock(self, node: j.AssignBlock) -> str: + target_str = self.generate(node.target, parent=node) + body_str = "".join(self.generate(c, parent=node) for c in node.body) + # todo: node.filter? + return "{% set " + target_str + " %}" + body_str + "{% endset %}" + + def _generate_call( + self, node: j.Call, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + call_name = self.generate(node.node, parent=node) + call_args = ", ".join(self.generate(a, parent=node) for a in node.args) + call_kwargs = ", ".join(self.generate(a, parent=node) for a in node.kwargs) + sep = ", " if call_args and call_kwargs else "" + call_str = call_name + f"({call_args}{sep}{call_kwargs})" + + return self._wrap_in_expression_if_necessary(call_str, prev, parent) + + def _generate_if(self, node: j.If, parent: t.Optional[j.Node]) -> str: + test_str = self.generate(node.test, parent=node) + body_str = "".join(self.generate(c, parent=node) for c in node.body) + elifs_str = "".join(self.generate(c, parent=node) for c in node.elif_) + elses_str = "".join(self.generate(c, parent=node) for c in node.else_) + + end_block_name: t.Optional[str] + block_name, end_block_name = "if", "endif" + if isinstance(parent, j.If): + if node in parent.elif_: + block_name, end_block_name = "elif", None + + end_block = "{% " + end_block_name + " %}" if end_block_name else "" + + elses_str = "{% else %}" + elses_str if elses_str else "" + + return ( + "{% " + + block_name + + " " + + test_str + + " %}" + + body_str + + elifs_str + + elses_str + + end_block + ) + + def _generate_macro(self, node: j.Macro, prev: t.Optional[j.Node]) -> str: + name_str = node.name + rendered_defaults = list(reversed([self.generate(d, parent=node) for d in node.defaults])) + rendered_args = [self.generate(a, parent=node) for a in node.args] + + # the defaults, if they exist, line up with the last arguments in the list + # so we reverse the lists to match the arrays and then reverse the result to get the original order + args_with_defaults = [ + (arg, next(iter(rendered_defaults[idx : idx + 1]), None)) + for idx, arg in enumerate(reversed(rendered_args)) + ] + args_with_defaults = list(reversed(args_with_defaults)) + + args_str = ", ".join(f"{a}={d}" if d is not None else a for a, d in args_with_defaults) + body_str = "".join(self.generate(c, parent=node) for c in node.body) + + # crude sql comment detection that will cause false 
positives that hopefully shouldnt matter + # this is to work around a WONTFIX bug in the SQLGlot tokenizer that if the macro body contains a SQL comment + # and {% endmacro %} is on the same line, it gets included as comment instead of a proper token + # the bug also occurs if the {% macro %} tag is on a line that starts with a SQL comment + start_tag = "{% macro " + if prev: + prev_str = self.generate(prev) + if "--" in prev_str and not prev_str.rstrip(" ").endswith("\n"): + start_tag = "\n" + start_tag + + end_tag = "{% endmacro %}" + if "--" in body_str and not body_str.rstrip(" ").endswith("\n"): + end_tag = "\n" + end_tag + + return start_tag + name_str + "(" + args_str + ")" + " %}" + body_str + end_tag + + def _generate_for(self, node: j.For) -> str: + target_str = self.generate(node.target, parent=node) + iter_str = self.generate(node.iter, parent=node) + test_str = "if " + self.generate(node.test, parent=node) if node.test else None + body_str = "".join(self.generate(c, parent=node) for c in node.body) + + acc = "{% for " + target_str + " in " + iter_str + if test_str: + acc += f" {test_str}" + acc += " %}" + acc += body_str + acc += "{% endfor %}" + + return acc + + def _generate_list(self, node: j.List, parent: t.Optional[j.Node]) -> str: + items_str_array = [self.generate(i, parent=node) for i in node.items] + items_on_newline = ( + not isinstance(parent, j.Pair) + and len(items_str_array) > 1 + and any(len(i) > 50 for i in items_str_array) + ) + item_separator = "\n\t" if items_on_newline else " " + items_str = f",{item_separator}".join(items_str_array) + start_separator = "\n\t" if items_on_newline else "" + end_separator = "\n" if items_on_newline else "" + return f"[{start_separator}{items_str}{end_separator}]" + + def _generate_dict(self, node: j.Dict) -> str: + items_str = ", ".join(self.generate(c, parent=node) for c in node.items) + return "{ " + items_str + " }" + + def _generate_pair(self, node: j.Pair) -> str: + key_str = self.generate(node.key, parent=node) + value_str = self.generate(node.value, parent=node) + return f"{key_str}: {value_str}" + + def _generate_not(self, node: j.Not) -> str: + if isinstance(node.node, j.Test): + return self.generate(node.node, parent=node) + + return self.__generate_unaryexp(node) + + def _generate_neg(self, node: j.Neg) -> str: + return self.__generate_unaryexp(node) + + def _generate_pos(self, node: j.Pos) -> str: + return self.__generate_unaryexp(node) + + def _generate_compare(self, node: j.Compare) -> str: + what_str = self.generate(node.expr, parent=node) + + # todo: is this correct? 
need to test with multiple ops + ops_str = "".join(self.generate(o, parent=node) for o in node.ops) + + return f"{what_str} {ops_str}" + + def _generate_slice(self, node: j.Slice) -> str: + start_str = self.generate(node.start, parent=node) if node.start else "" + stop_str = self.generate(node.stop, parent=node) if node.stop else "" + # todo: need a syntax example of step + return f"{start_str}:{stop_str}" + + def _generate_operand(self, node: j.Operand) -> str: + assert isinstance(node, j.Operand) + value_str = self.generate(node.expr, parent=node) + + return f"{OPERATOR_MAP[node.op]} " + value_str + + def _generate_add(self, node: j.Add, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_mul(self, node: j.Mul, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_div(self, node: j.Div, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_sub(self, node: j.Sub, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_floordiv(self, node: j.FloorDiv, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_mod(self, node: j.Mod, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_pow(self, node: j.Pow, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_or(self, node: j.Or, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_and(self, node: j.And, parent: t.Optional[j.Node]) -> str: + return self.__generate_binexp(node, parent) + + def _generate_concat(self, node: j.Concat) -> str: + return " ~ ".join(self.generate(c, parent=node) for c in node.nodes) + + def _generate_tuple(self, node: j.Tuple, parent: t.Optional[j.Node]) -> str: + parenthesis = isinstance(parent, (j.Operand, j.Call)) + items_str = ", ".join(self.generate(i, parent=node) for i in node.items) + return items_str if not parenthesis else f"({items_str})" + + def _generate_filter( + self, node: j.Filter, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + # node.node may be None if this Filter is part of a FilterBlock + what_str = self.generate(node.node, parent=node) if node.node else None + if isinstance(node.node, j.CondExpr): + what_str = f"({what_str})" + + args_str = ", ".join(self.generate(a, parent=node) for a in node.args + node.kwargs) + if args_str: + args_str = f"({args_str})" + + filter_expr = f"{node.name}{args_str}" + if what_str: + filter_expr = f"{what_str} | {filter_expr}" + + return self._wrap_in_expression_if_necessary(filter_expr, prev=prev, parent=parent) + + def _generate_filterblock(self, node: j.FilterBlock) -> str: + filter_str = self.generate(node.filter, parent=node) + body_str = "".join(self.generate(c, parent=node) for c in node.body) + return "{% filter " + filter_str + " %}" + body_str + "{% endfilter %}" + + def _generate_exprstmt(self, node: j.ExprStmt) -> str: + node_str = self.generate(node.node, parent=node) + return "{% do " + node_str + " %}" + + def _generate_condexpr( + self, node: j.CondExpr, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + test_sql = self.generate(node.test, parent=node) + expr1_sql = self.generate(node.expr1, parent=node) + + if node.expr2 is None: + raise ValueError("CondExpr lacked an 'else', not sure how to handle this") + + expr2_sql = self.generate(node.expr2, 
parent=node) + return self._wrap_in_expression_if_necessary( + f"{expr1_sql} if {test_sql} else {expr2_sql}", prev, parent + ) + + def __generate_binexp(self, node: j.BinExpr, parent: t.Optional[j.Node]) -> str: + left_str = self.generate(node.left, parent=node) + right_str = self.generate(node.right, parent=node) + + wrap_left = isinstance(node.left, j.BinExpr) + wrap_right = isinstance(node.right, j.BinExpr) + + acc = f"({left_str})" if wrap_left else left_str + acc += f" {node.operator} " + acc += f"({right_str})" if wrap_right else right_str + + return acc + + def __generate_unaryexp(self, node: j.UnaryExpr) -> str: + body_str = self.generate(node.node, parent=node) + return f"{node.operator} {body_str}" + + def _generate_nsref(self, node: j.NSRef) -> str: + return f"{node.name}.{node.attr}" + + def _generate_callblock(self, node: j.CallBlock) -> str: + call = self.generate(node.call, parent=node) + body = "".join(self.generate(e, parent=node) for e in node.body) + args = ", ".join(self.generate(arg, parent=node) for arg in node.args) + + open_tag = "{% call" + + if args: + open_tag += "(" + args + ")" + + if len(node.defaults) > 0: + raise NotImplementedError("Not sure how to handle CallBlock.defaults") + + return open_tag + " " + call + " %}" + body + "{% endcall %}" + + def _wrap_in_expression_if_necessary( + self, string: str, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> str: + wrap = False + if isinstance(prev, j.TemplateData): + wrap = True + elif prev is None and isinstance(parent, j.Output): + wrap = True + elif parent: + # if the node is nested inside eg an {% if %} block, dont wrap it in {{ }} + wrap = not any(isinstance(parent, t) for t in (j.Operand, j.Stmt, j.Expr, j.Helper)) + + return "{{ " + string + " }}" if wrap else string + + +def _contains_jinja(query: str) -> bool: + if "{{" in query: + return True + if "{%" in query: + return True + return False + + +def transform(base: j.Node, handler: JinjaTransform) -> j.Node: + sig = signature(handler) + + def _build_handler_kwargs( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Dict[str, t.Any]: + kwargs: t.Dict[str, t.Optional[j.Node]] = {"node": node} + if "prev" in sig.parameters: + kwargs["prev"] = prev + if "parent" in sig.parameters: + kwargs["parent"] = parent + return kwargs + + def _transform( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + transformed_node: t.Optional[j.Node] = handler(**_build_handler_kwargs(node, prev, parent)) # type: ignore + + if not transformed_node: + return None + + node = transformed_node + + new_children: t.Dict[j.Node, t.Optional[j.Node]] = {} + prev = None + for child in list(node.iter_child_nodes()): + transformed_child = _transform(node=child, prev=prev, parent=node) + if transformed_child != child: + new_children[child] = transformed_child + prev = child + + if new_children: + replacement_fields: t.Dict[str, t.Union[j.Node, t.List[j.Node]]] = {} + for name, value in node.iter_fields(): + assert isinstance(name, str) + + if isinstance(value, list): + replacement_value_list = [new_children.get(i, i) for i in value] + replacement_fields[name] = [r for r in replacement_value_list if r is not None] + elif isinstance(value, j.Node): + replacement_value = new_children.get(value) or value + replacement_fields[name] = replacement_value + for name, value in replacement_fields.items(): + setattr(node, name, value) + + return node + + transformed = _transform(node=base, prev=None, parent=None) + if 
transformed is None: + raise ValueError( + f"Transform '{handler.__name__}' consumed the entire AST; this indicates a bug" + ) + return transformed + + +def convert_jinja_query( + context: Context, + node: Node, + query: d.Jinja, + package: t.Optional[str] = None, + exclude: t.Optional[t.List[t.Callable]] = None, +) -> t.Union[d.JinjaQuery, d.JinjaStatement, exp.Query, exp.DDL]: + jinja_env = node.jinja_macros.build_environment() + + ast: j.Node = jinja_env.parse(query.text("this")) # type: ignore + + transforms = [ + # transform {{ ref("foo") }} -> schema.foo (NOT "fully_qualified"."schema"."foo") + jt.resolve_dbt_ref_to_model_name(context.models, jinja_env, node.dialect), + # Rewrite ref() calls that can't be converted to strings (maybe they're macro arguments) to __migrated_ref() calls + jt.rewrite_dbt_ref_to_migrated_ref(context.models, jinja_env, node.dialect), + # transform {{ source("upstream", "foo") }} -> upstream.foo (NOT "fully_qualified"."upstream"."foo") + jt.resolve_dbt_source_to_model_name(context.models, jinja_env, node.dialect), + # Rewrite source() calls that can't be converted to strings (maybe they're macro arguments) to __migrated_source() calls + jt.rewrite_dbt_source_to_migrated_source(context.models, jinja_env, node.dialect), + # transform {{ this }} -> model.name + jt.resolve_dbt_this_to_model_name(node.name), + # deduplicate where both {% if sqlmesh_incremental %} and {% if is_incremental() %} are used + jt.deduplicate_incremental_checks(), + # unpack {% if is_incremental() %} blocks because they aren't necessary when running a native project + jt.unpack_incremental_checks(), + ] + + if package: + transforms.append(jt.append_dbt_package_kwarg_to_var_calls(package)) + + transforms = [ + t for t in transforms if not any(e.__name__ in t.__name__ for e in (exclude or [])) + ] + + for handler in transforms: + ast = transform(ast, handler) + + generator = JinjaGenerator() + pre_post_processing = generator.generate(ast) + if isinstance(node, SqlModel) and isinstance(query, d.JinjaQuery) and not node.depends_on_self: + # is it self-referencing now that is_incremental() has been removed? 
+ # if so, and columns_to_types are not all known, then we can't remove is_incremental() or we will get a load error + + # try to load the converted model with the native loader + model_definition = node.copy(update=dict(audits=[])).render_definition()[0].sql() + + # we need the Jinja builtins that include the compatibility shims because the transforms may have created e.g. __migrated_ref() calls + jinja_macros = node.jinja_macros.copy( + update=dict(create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE) + ) + + converted_node = load_sql_based_model( + expressions=[d.parse_one(model_definition), d.JinjaQuery(this=pre_post_processing)], + jinja_macros=jinja_macros, + defaults=context.config.model_defaults.dict(), + default_catalog=node.default_catalog, + ) + original_model = context.models[node.fqn] + + if converted_node.depends_on_self: + try: + # we need to upsert the model into the context to trigger columns_to_types inference + # note that this can sometimes bust the optimized query cache, which can lead to long pauses converting some models in large projects + context.upsert_model(converted_node) + except ConfigError as e: + if "Self-referencing models require inferrable column types" in str(e): + # we have a self-referencing model where the columns_to_types cannot be inferred + # run the conversion again without the unpack_incremental_checks transform + return convert_jinja_query( + context, node, query, exclude=[jt.unpack_incremental_checks] + ) + raise + except Exception: + # todo: perhaps swallow this so that we just continue on with the original logic + raise + finally: + context.upsert_model(original_model) # put the original model definition back + + ast = transform(ast, jt.rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax()) + post_processed = generator.generate(ast) + + # post processing - have we removed all the jinja so this can effectively be a normal SQL query? 
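+ # for example (hypothetical model name): SELECT * FROM {{ ref('orders') }} WHERE ds >= '{{ start_ds }}' should by now read SELECT * FROM sqlmesh_example.orders WHERE ds >= '@start_ds', which the code below parses and unquotes to @start_ds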
+ if not _contains_jinja(post_processed): + parsed = d.parse_one(post_processed, dialect=node.dialect) + + # converting DBT '{{ start_ds }}' to a SQLMesh macro results in single quoted '@start_ds' but we really need unquoted @start_ds + transformed = parsed.transform(jt.unwrap_macros_in_string_literals()) + if isinstance(transformed, (exp.Query, exp.DDL)): + return transformed + + raise ValueError( + f"Transformation resulted in a {type(transformed)} node instead of Query / DDL statement" + ) + + if isinstance(query, d.JinjaQuery): + return d.JinjaQuery(this=pre_post_processing) + if isinstance(query, d.JinjaStatement): + return d.JinjaStatement(this=pre_post_processing) + + raise ValueError(f"Not sure how to handle: {type(query)}") + + +def convert_jinja_macro(context: Context, src: str, package: t.Optional[str] = None) -> str: + jinja_macros = DbtContext().jinja_macros # ensures the correct create_builtins_module is set + jinja_macros = jinja_macros.merge(context._jinja_macros) + + jinja_env = jinja_macros.build_environment() + + dialect = context.default_dialect + if not dialect: + raise ValueError("No project dialect configured?") + + transforms = [ + # transform {{ ref("foo") }} -> schema.foo (NOT "fully_qualified"."schema"."foo") + jt.resolve_dbt_ref_to_model_name(context.models, jinja_env, dialect), + # Rewrite ref() calls that cant be converted to strings (maybe theyre macro aguments) to __migrated_ref() calls + jt.rewrite_dbt_ref_to_migrated_ref(context.models, jinja_env, dialect), + # transform {{ source("foo", "bar") }} -> `qualified`.`foo`.`bar` + jt.resolve_dbt_source_to_model_name(context.models, jinja_env, dialect), + # transform {{ var('foo') }} -> {{ var('foo', __dbt_package='') }} + jt.append_dbt_package_kwarg_to_var_calls(package), + # deduplicate where both {% if sqlmesh_incremental %} and {% if is_incremental() %} are used + jt.deduplicate_incremental_checks(), + # unpack {% if sqlmesh_incremental %} blocks because they arent necessary when running a native project + jt.unpack_incremental_checks(), + ] + + ast: j.Node = jinja_env.parse(src) + + for handler in transforms: + ast = transform(ast, handler) + + generator = JinjaGenerator() + + return generator.generate(ast) diff --git a/sqlmesh/dbt/converter/jinja_builtins.py b/sqlmesh/dbt/converter/jinja_builtins.py new file mode 100644 index 0000000000..59303ad344 --- /dev/null +++ b/sqlmesh/dbt/converter/jinja_builtins.py @@ -0,0 +1,109 @@ +import typing as t +import functools +from sqlmesh.utils.jinja import JinjaMacroRegistry +from dbt.adapters.base.relation import BaseRelation +from sqlmesh.dbt.builtin import Api +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.utils.errors import ConfigError +from dbt.adapters.base import BaseRelation +from sqlglot import exp + +from dbt.adapters.base import BaseRelation + + +def migrated_ref( + dbt_api: Api, + database: t.Optional[str] = None, + schema: t.Optional[str] = None, + identifier: t.Optional[str] = None, + version: t.Optional[int] = None, + sqlmesh_model_name: t.Optional[str] = None, +) -> BaseRelation: + if version: + raise ValueError("dbt model versions are not supported in converted projects.") + + return dbt_api.Relation.create(database=database, schema=schema, identifier=identifier) + + +def migrated_source( + dbt_api: Api, + database: t.Optional[str] = None, + schema: t.Optional[str] = None, + identifier: t.Optional[str] = None, +) -> BaseRelation: + return dbt_api.Relation.create(database=database, schema=schema, identifier=identifier) + + +def 
create_builtin_globals( + jinja_macros: JinjaMacroRegistry, + global_vars: t.Dict[str, t.Any], + engine_adapter: t.Optional[EngineAdapter], + *args: t.Any, + **kwargs: t.Any, +) -> t.Dict[str, t.Any]: + import sqlmesh.utils.jinja as sqlmesh_native_jinja + import sqlmesh.dbt.builtin as sqlmesh_dbt_jinja + + # Capture dialect before the dbt builtins pop it + dialect = global_vars.get("dialect") + + sqlmesh_native_globals = sqlmesh_native_jinja.create_builtin_globals( + jinja_macros, global_vars, *args, **kwargs + ) + + if this_model := global_vars.get("this_model"): + # create a DBT-compatible version of @this_model for {{ this }} + if isinstance(this_model, str): + if not dialect: + raise ConfigError("No dialect available to parse this_model") + + # in audits, `this_model` is a SQL SELECT query that selects from the current table + # elsewhere, it's a fqn string + parsed: exp.Expression = exp.maybe_parse(this_model, dialect=dialect) + + table: t.Optional[exp.Table] = None + if isinstance(parsed, exp.Column): + table = exp.to_table(this_model, dialect=dialect) + elif isinstance(parsed, exp.Query): + table = parsed.find(exp.Table) + else: + raise ConfigError(f"Unable to handle this_model: {this_model}") + + if table: + # sqlmesh_dbt_jinja.create_builtin_globals() will construct a Relation for {{ this }} based on the supplied dict + global_vars["this"] = { + "database": table.catalog, + "schema": table.db, + "identifier": table.name, + } + + else: + raise ConfigError(f"Unhandled this_model type: {type(this_model)}") + + sqlmesh_dbt_globals = sqlmesh_dbt_jinja.create_builtin_globals( + jinja_macros, global_vars, engine_adapter, *args, **kwargs + ) + + def source(dbt_api: Api, source_name: str, table_name: str) -> BaseRelation: + # some source() calls can't be converted to __migrated_source() calls because they contain dynamic parameters + # this is a fallback and will be wrong in some situations because `sources` in DBT can be aliased in config + # TODO: maybe we migrate sources into the SQLMesh variables so we can look them up here? + return dbt_api.Relation.create(database=source_name, identifier=table_name) + + def ref(dbt_api: Api, ref_name: str, package: t.Optional[str] = None) -> BaseRelation: + # some ref() calls can't be converted to __migrated_ref() calls because they contain dynamic parameters + raise NotImplementedError( + f"Unable to resolve ref: {ref_name}. Please replace it with an actual model name or use a SQLMesh macro to generate a dynamic model name." + ) + + dbt_compatibility_shims = { + "dialect": dialect, + "__migrated_ref": functools.partial(migrated_ref, sqlmesh_dbt_globals["api"]), + "__migrated_source": functools.partial(migrated_source, sqlmesh_dbt_globals["api"]), + "source": functools.partial(source, sqlmesh_dbt_globals["api"]), + "ref": functools.partial(ref, sqlmesh_dbt_globals["api"]), + # make {{ config(...)
}} a no-op, some macros call it but it's meaningless in a SQLMesh Native project + "config": lambda *_args, **_kwargs: None, + } + + return {**sqlmesh_native_globals, **sqlmesh_dbt_globals, **dbt_compatibility_shims} diff --git a/sqlmesh/dbt/converter/jinja_transforms.py b/sqlmesh/dbt/converter/jinja_transforms.py new file mode 100644 index 0000000000..4c4cf03edc --- /dev/null +++ b/sqlmesh/dbt/converter/jinja_transforms.py @@ -0,0 +1,465 @@ +import typing as t +from types import MappingProxyType +from sqlmesh.core.model import Model +from jinja2 import Environment +import jinja2.nodes as j +from sqlmesh.dbt.converter.common import ( + SQLMESH_PREDEFINED_MACRO_VARIABLES, + JinjaTransform, + SQLGlotTransform, +) +from dbt.adapters.base.relation import BaseRelation +from sqlmesh.core.dialect import normalize_model_name +from sqlglot import exp +import sqlmesh.core.dialect as d +from functools import wraps + + +def _make_standalone_call_transform(fn_name: str, handler: JinjaTransform) -> JinjaTransform: + """ + Creates a transform that identifies standalone Call nodes (that aren't nested in other Call nodes) and replaces them with nodes + containing the result of the handler() function + """ + + def _handle( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if isinstance(node, j.Call): + if isinstance(parent, (j.Call, j.List, j.Keyword)): + return node + + if (name := node.find(j.Name)) and name.name == fn_name: + return handler(node, prev, parent) + + return node + + return _handle + + +def _make_single_expression_transform( + mapping: t.Union[ + t.Dict[str, str], + t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node], str], t.Optional[str]], + ], +) -> JinjaTransform: + """ + Creates a transform that looks for standalone {{ expression }} nodes + It then looks up 'expression' in the provided mapping and replaces it with a TemplateData node containing the value + """ + + def _handle(node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node]) -> j.Node: + # the assumption is that individual expressions are nested in between TemplateData + if prev and not isinstance(prev, j.TemplateData): + return node + + if isinstance(node, j.Name) and not isinstance(parent, j.Getattr): + if isinstance(mapping, dict): + result = mapping.get(node.name) + else: + result = mapping(node, prev, parent, node.name) + if result is not None: + return j.TemplateData(result) + + return node + + return _handle + + +def _dbt_relation_to_model_name( + models: MappingProxyType[str, t.Union[Model, str]], relation: BaseRelation, dialect: str +) -> t.Optional[str]: + model_fqn = normalize_model_name( + table=relation.render(), default_catalog=relation.database, dialect=dialect + ) + if resolved_value := models.get(model_fqn): + return resolved_value if isinstance(resolved_value, str) else resolved_value.name + return None + + +def _dbt_relation_to_kwargs(relation: BaseRelation) -> t.List[j.Keyword]: + kwargs = [] + if database := relation.database: + kwargs.append(j.Keyword("database", j.Const(database))) + if schema := relation.schema: + kwargs.append(j.Keyword("schema", j.Const(schema))) + if identifier := relation.identifier: + kwargs.append(j.Keyword("identifier", j.Const(identifier))) + return kwargs + + +ASTTransform = t.TypeVar("ASTTransform", JinjaTransform, SQLGlotTransform) + + +def ast_transform(fn: t.Callable[..., ASTTransform]) -> t.Callable[..., ASTTransform]: + """ + Decorator to mark functions as being Jinja or SQLGlot AST transforms + + The purpose is
to set __name__ to be the outer function name so that the transforms have stable names for an exclude list + The function itself as well as the ASTTransform returned by the function should have the same __name__ for this to work + """ + + @wraps(fn) + def wrapper(*args: t.Any, **kwargs: t.Any) -> ASTTransform: + result = fn(*args, **kwargs) + result.__name__ = fn.__name__ + return result + + return wrapper + + +@ast_transform +def resolve_dbt_ref_to_model_name( + models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str +) -> JinjaTransform: + """ + Takes an expression like "{{ ref('foo') }}" + And turns it into "sqlmesh.foo" based on the provided list of models and resolver() function + + Args: + models: A dict of models (or model names) keyed by model fqn + env: Should contain an implementation of {{ ref() }} to turn a DBT relation name into a DBT relation object + + Returns: + A string containing the **model name** (not fqn) of the model referenced by the DBT "{{ ref() }}" call + """ + + ref: t.Callable = env.globals["ref"] # type: ignore + + def _resolve( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if isinstance(node, j.Call) and node.args and isinstance(node.args[0], j.Const): + ref_name = node.args[0].value + version = None + if version_kwarg := next((k for k in node.kwargs if k.key in ("version", "v")), None): + if isinstance(version_kwarg.value, j.Const): + version = version_kwarg.value.value + else: + # the version arg is present but it's some kind of dynamic runtime value + # this means we can't resolve the ref to a model + return node + + if relation := ref(ref_name, version=version): + if not isinstance(relation, BaseRelation): + raise ValueError( + f"ref() returned non-relation type for '{ref_name}': {relation}" + ) + if model_name := _dbt_relation_to_model_name(models, relation, dialect): + return j.TemplateData(model_name) + return j.TemplateData(f"__unresolved_ref__.{ref_name}") + + return node + + return _make_standalone_call_transform("ref", _resolve) + + +@ast_transform +def rewrite_dbt_ref_to_migrated_ref( + models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str +) -> JinjaTransform: + """ + Takes an expression like "{{ ref('foo') }}" + And turns it into "{{ __migrated_ref(database='foo', schema='bar', identifier='baz', sqlmesh_model_name='') }}" + so that the SQLMesh Native loader can construct a Relation instance without needing the Context + + Args: + models: A dict of models (or model names) keyed by model fqn + env: Should contain an implementation of {{ ref() }} to turn a DBT relation name into a DBT relation object + + Returns: + A new Call node with enough data to reconstruct the Relation + """ + + ref: t.Callable = env.globals["ref"] # type: ignore + + def _rewrite( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if isinstance(node, j.Call) and isinstance(node.node, j.Name) and node.node.name == "ref": + if node.args and isinstance(node.args[0], j.Const): + ref_name = node.args[0].value + version_kwarg = next((k for k in node.kwargs if k.key == "version"), None) + if (relation := ref(ref_name)) and isinstance(relation, BaseRelation): + if model_name := _dbt_relation_to_model_name(models, relation, dialect): + kwargs = _dbt_relation_to_kwargs(relation) + if version_kwarg: + kwargs.append(version_kwarg) + kwargs.append(j.Keyword("sqlmesh_model_name", j.Const(model_name))) + return
j.Call(j.Name("__migrated_ref", "load"), [], kwargs, None, None) + + return node + + return _rewrite + + +@ast_transform +def resolve_dbt_source_to_model_name( + models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str +) -> JinjaTransform: + """ + Takes an expression like "{{ source('foo', 'bar') }}" + And turns it into "foo.bar" based on the provided list of models and resolver() function + + Args: + models: A dict of models (or model names) keyed by model fqn + env: Should contain an implementation of {{ source() }} to turn a DBT source name / table name into a DBT relation object + + Returns: + A string containing the table fqn of the external table referenced by the DBT "{{ source() }}" call + """ + source: t.Callable = env.globals["source"] # type: ignore + + def _resolve( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if isinstance(node, j.Call) and isinstance(parent, (j.TemplateData, j.Output)): + if ( + len(node.args) == 2 + and isinstance(node.args[0], j.Const) + and isinstance(node.args[1], j.Const) + ): + source_name = node.args[0].value + table_name = node.args[1].value + if relation := source(source_name, table_name): + if not isinstance(relation, BaseRelation): + raise ValueError( + f"source() returned non-relation type for '{source_name}.{table_name}': {relation}" + ) + if model_name := _dbt_relation_to_model_name(models, relation, dialect): + return j.TemplateData(model_name) + return j.TemplateData(relation.render()) + # source() didn't resolve anything, just pass through the arguments verbatim + return j.TemplateData(f"{source_name}.{table_name}") + + return node + + return _make_standalone_call_transform("source", _resolve) + + +@ast_transform +def rewrite_dbt_source_to_migrated_source( + models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str +) -> JinjaTransform: + """ + Takes an expression like "{{ source('foo', 'bar') }}" + And turns it into "{{ __migrated_source(database='foo', identifier='bar') }}" + so that the SQLMesh Native loader can construct a Relation instance without needing the Context + + Args: + models: A dict of models (or model names) keyed by model fqn + env: Should contain an implementation of {{ source() }} to turn a DBT source name / table name into a DBT relation object + + Returns: + A new Call node with enough data to reconstruct the Relation + """ + + source: t.Callable = env.globals["source"] # type: ignore + + def _rewrite( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if ( + isinstance(node, j.Call) + and isinstance(node.node, j.Name) + and node.node.name == "source" + ): + if ( + len(node.args) == 2 + and isinstance(node.args[0], j.Const) + and isinstance(node.args[1], j.Const) + ): + source_name = node.args[0].value + table_name = node.args[1].value + if (relation := source(source_name, table_name)) and isinstance( + relation, BaseRelation + ): + kwargs = _dbt_relation_to_kwargs(relation) + return j.Call(j.Name("__migrated_source", "load"), [], kwargs, None, None) + + return node + + return _rewrite + + +@ast_transform +def resolve_dbt_this_to_model_name(model_name: str) -> JinjaTransform: + """ + Takes an expression like "{{ this }}" and turns it into the provided "model_name" string + """ + return _make_single_expression_transform({"this": model_name}) + + +@ast_transform +def deduplicate_incremental_checks() -> JinjaTransform: + """ + Some files may have been designed to
run with both the SQLMesh DBT loader and DBT itself and contain sections like: + + --- + select * from foo + where + {% if is_incremental() %}ds > (select max(ds) from {{ this }}){% endif %} + {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} + --- + + This transform detects usages of {% if sqlmesh_incremental ... %} + If it finds them, it: + - removes occurrences of {% if is_incremental() %} in favour of the {% if sqlmesh_incremental %} check + + If no instances of {% if sqlmesh_incremental %} are found, nothing changes + + For example, the above will be transformed into: + --- + select * from foo + where + ds BETWEEN {{ start_ds }} and {{ end_ds }} + --- + + But if it didn't contain the {% if sqlmesh_incremental %} block, this transform would output: + --- + select * from foo + where + {% if is_incremental() %}ds > (select max(ds) from {{ this }}){% endif %} + --- + + """ + has_sqlmesh_incremental = False + + def _handle( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + nonlocal has_sqlmesh_incremental + + if isinstance(node, j.Template): + for if_node in node.find_all(j.If): + if test_name := if_node.test.find(j.Name): + if test_name.name == "sqlmesh_incremental": + has_sqlmesh_incremental = True + + # only remove the {% if is_incremental() %} checks in the presence of {% if sqlmesh_incremental is defined %} checks + if has_sqlmesh_incremental: + if isinstance(node, j.If) and node.test: + if test_name := node.test.find(j.Name): + if test_name.name == "is_incremental": + return None + + return node + + return _handle + + +@ast_transform +def unpack_incremental_checks() -> JinjaTransform: + """ + This takes queries like: + + > select * from foo where {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} + > select * from foo where {% if is_incremental() %}ds > (select max(ds) from foo.table){% endif %} + + And, if possible, removes the {% if sqlmesh_incremental is defined %} / {% if is_incremental() %} block to achieve: + + > select * from foo where ds BETWEEN {{ start_ds }} and {{ end_ds }} + > select * from foo where ds > (select max(ds) from foo.table) + + Note that if there is a {% else %} portion to the block, there is no SQLMesh equivalent so in that case the check is untouched.
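+ For example (illustrative), a block like {% if is_incremental() %}where ds > (select max(ds) from foo.table){% else %}where 1 = 1{% endif %} is left unchanged because the {% else %} branch has no native equivalent.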
+ + Also, if both may be present in a model, run the deduplicate_incremental_checks() transform first so that only one gets unpacked by this transform + """ + + def _handle(node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node]) -> j.Node: + if isinstance(node, j.If) and node.test: + if test_name := node.test.find(j.Name): + if ( + test_name.name in ("is_incremental", "sqlmesh_incremental") + and not node.elif_ + and not node.else_ + ): + return j.Output(node.body) + + return node + + return _handle + + +@ast_transform +def rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax() -> JinjaTransform: + """ + If there are SQLMesh predefined variables in Jinja form, eg "{{ start_dt }}" + Rewrite them to eg "@start_dt" + + Example: + + select * from foo where ds between {{ start_dt }} and {{ end_dt }} + + > select * from foo where ds between @start_dt and @end_dt + """ + + mapping = {v: f"@{v}" for v in SQLMESH_PREDEFINED_MACRO_VARIABLES} + + literal_remapping = {"dt": "ts", "date": "ds"} + + def _mapping_func( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node], name: str + ) -> t.Optional[str]: + wrapped_in_literal = False + if prev and isinstance(prev, j.TemplateData): + data = prev.data.strip() + if data.endswith("'"): + wrapped_in_literal = True + + if wrapped_in_literal: + for original, new in literal_remapping.items(): + if name.endswith(original): + name = name.removesuffix(original) + new + + return mapping.get(name) + + return _make_single_expression_transform(_mapping_func) + + +@ast_transform +def append_dbt_package_kwarg_to_var_calls(package_name: t.Optional[str]) -> JinjaTransform: + """ + If there are calls like: + + > {% if 'col_name' in var('history_columns') %} + + Assuming package_name=foo, change it to: + + > {% if 'col_name' in var('history_columns', __dbt_package="foo") %} + + The point of this is to give a hint to the "var" shim in SQLMesh Native so it knows which key + under "__dbt_packages__" in the project variables to look for + """ + + def _append( + node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] + ) -> t.Optional[j.Node]: + if package_name and isinstance(node, j.Call): + node.kwargs.append(j.Keyword("__dbt_package", j.Const(package_name))) + return node + + return _make_standalone_call_transform("var", _append) + + +@ast_transform +def unwrap_macros_in_string_literals() -> SQLGlotTransform: + """ + Given a query containing string literals *that match SQLMesh predefined macro variables* like: + + > select * from foo where ds between '@start_dt' and '@end_dt' + + Unwrap them into: + + > select * from foo where ds between @start_dt and @end_dt + """ + values_to_check = {f"@{var}": var for var in SQLMESH_PREDEFINED_MACRO_VARIABLES} + + def _transform(e: exp.Expression) -> exp.Expression: + if isinstance(e, exp.Literal) and e.is_string: + if (value := e.text("this")) and value in values_to_check: + return d.MacroVar( + this=values_to_check[value] + ) # MacroVar adds in the @ so we don't want to add it twice + return e + + return _transform diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 4f4100f092..672ad1ac3e 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -99,11 +99,12 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: for file in macro_files: self._track_file(file) - # This doesn't do anything, the actual content will be loaded from the manifest - return ( - macro.get_registry(), - JinjaMacroRegistry(), - ) + jinja_macros = JinjaMacroRegistry() + for project in
self._load_projects(): + jinja_macros = jinja_macros.merge(project.context.jinja_macros) + jinja_macros.add_globals(project.context.jinja_globals) + + return (macro.get_registry(), jinja_macros) def _load_models( self, diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 51cfd06c88..4cbca09aee 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -567,6 +567,7 @@ def to_sqlmesh( kind=kind, start=self.start, audit_definitions=audit_definitions, + path=model_kwargs.pop("path", self.path), # This ensures that we bypass query rendering that would otherwise be required to extract additional # dependencies from the model's SQL. # Note: any table dependencies that are not referenced using the `ref` macro will not be included. diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index e7603232e8..05985d8762 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -83,26 +83,8 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig: The configuration of the provided profile target """ db_type = data["type"] - if db_type == "databricks": - return DatabricksConfig(**data) - if db_type == "duckdb": - return DuckDbConfig(**data) - if db_type == "postgres": - return PostgresConfig(**data) - if db_type == "redshift": - return RedshiftConfig(**data) - if db_type == "snowflake": - return SnowflakeConfig(**data) - if db_type == "bigquery": - return BigQueryConfig(**data) - if db_type == "sqlserver": - return MSSQLConfig(**data) - if db_type == "trino": - return TrinoConfig(**data) - if db_type == "clickhouse": - return ClickhouseConfig(**data) - if db_type == "athena": - return AthenaConfig(**data) + if config_class := TARGET_TYPE_TO_CONFIG_CLASS.get(db_type): + return config_class(**data) raise ConfigError(f"{db_type} not supported.") @@ -114,6 +96,10 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: """Converts target config to SQLMesh connection config""" raise NotImplementedError + @classmethod + def from_sqlmesh(cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any]) -> "TargetConfig": + raise NotImplementedError + def attribute_dict(self) -> AttributeDict: fields = self.dict(include=SERIALIZABLE_FIELDS).copy() fields["target_name"] = self.name @@ -202,6 +188,18 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) + @classmethod + def from_sqlmesh(cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any]) -> "DuckDbConfig": + if not isinstance(config, DuckDBConnectionConfig): + raise ValueError(f"Incorrect config type: {type(config)}") + + return cls( + path=config.database, + extensions=config.extensions, + settings=config.connector_config, + **kwargs, + ) + class SnowflakeConfig(TargetConfig): """ @@ -372,6 +370,28 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) + @classmethod + def from_sqlmesh( + cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any] + ) -> "PostgresConfig": + if not isinstance(config, PostgresConnectionConfig): + raise ValueError(f"Incorrect config type: {type(config)}") + + return cls( + schema="public", + host=config.host, + user=config.user, + password=config.password, + port=config.port, + dbname=config.database, + keepalives_idle=config.keepalives_idle, + threads=config.concurrent_tasks, + connect_timeout=config.connect_timeout, + role=config.role, + sslmode=config.sslmode, + **kwargs, + ) + class RedshiftConfig(TargetConfig): """ @@ -613,6 +633,39 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) + @classmethod + def from_sqlmesh( + cls, 
config: ConnectionConfig, **kwargs: t.Dict[str, t.Any] + ) -> "BigQueryConfig": + if not isinstance(config, BigQueryConnectionConfig): + raise ValueError(f"Incorrect config type: {type(config)}") + + return cls( + schema="__unknown__", + method=config.method, + project=config.project, + execution_project=config.execution_project, + quota_project=config.quota_project, + location=config.location, + threads=config.concurrent_tasks, + keyfile=config.keyfile, + keyfile_json=config.keyfile_json, + token=config.token, + refresh_token=config.refresh_token, + client_id=config.client_id, + client_secret=config.client_secret, + token_uri=config.token_uri, + scopes=config.scopes, + impersonated_service_account=config.impersonated_service_account, + job_creation_timeout_seconds=config.job_creation_timeout_seconds, + job_execution_timeout_seconds=config.job_execution_timeout_seconds, + job_retries=config.job_retries, + job_retry_deadline_seconds=config.job_retry_deadline_seconds, + priority=config.priority, + maximum_bytes_billed=config.maximum_bytes_billed, + **kwargs, + ) + class MSSQLConfig(TargetConfig): """ diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index dcb09296b8..711f760b7b 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -22,6 +22,7 @@ CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]] SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" +SQLMESH_DBT_COMPATIBILITY_PACKAGE = "sqlmesh.dbt.converter.jinja_builtins" def environment(**kwargs: t.Any) -> Environment: @@ -94,7 +95,11 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: macro_str = self._find_sql(macro_start, self._next) macros[name] = MacroInfo( definition=macro_str, - depends_on=list(extract_macro_references_and_variables(macro_str)[0]), + depends_on=list( + extract_macro_references_and_variables(macro_str, dbt_target_name=dialect)[ + 0 + ] + ), ) self._advance() @@ -166,18 +171,86 @@ def parse() -> t.List[CallNames]: return parse() +def extract_dbt_adapter_dispatch_targets(jinja_str: str) -> t.List[t.Tuple[str, t.Optional[str]]]: + """ + Given a jinja string, identify {{ adapter.dispatch('foo','bar') }} calls and extract the (foo, bar) part as a tuple + """ + ast = ENVIRONMENT.parse(jinja_str) + + extracted = [] + + def _extract(node: nodes.Node, parent: t.Optional[nodes.Node] = None) -> None: + if ( + isinstance(node, nodes.Getattr) + and isinstance(parent, nodes.Call) + and (node_name := node.find(nodes.Name)) + ): + if node_name.name == "adapter" and node.attr == "dispatch": + call_args = [arg.value for arg in parent.args if isinstance(arg, nodes.Const)][0:2] + if len(call_args) == 1: + call_args.append(None) + macro_name, package = call_args + extracted.append((macro_name, package)) + + for child_node in node.iter_child_nodes(): + _extract(child_node, parent=node) + + _extract(ast) + + return extracted + + def extract_macro_references_and_variables( - *jinja_strs: str, + *jinja_strs: str, dbt_target_name: t.Optional[str] = None ) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: macro_references = set() variables = set() for jinja_str in jinja_strs: + if dbt_target_name and "adapter.dispatch" in jinja_str: + for dispatch_target_name, package in extract_dbt_adapter_dispatch_targets(jinja_str): + # here we are guessing at the macro names that the {{ adapter.dispatch() }} call will invoke + # there is a defined resolution order: https://docs.getdbt.com/reference/dbt-jinja-functions/dispatch + # we rely on JinjaMacroRegistry.trim() to tune the dependencies 
down into just the ones that actually exist + macro_references.add( + MacroReference(package=package, name=f"default__{dispatch_target_name}") + ) + macro_references.add( + MacroReference( + package=package, name=f"{dbt_target_name}__{dispatch_target_name}" + ) + ) + if package and package.startswith("dbt"): + # handle the case where macros like `current_timestamp()` in the `dbt` package expect an implementation in eg the `dbt_bigquery` package + macro_references.add( + MacroReference( + package=f"dbt_{dbt_target_name}", + name=f"{dbt_target_name}__{dispatch_target_name}", + ) + ) + for call_name, node in extract_call_names(jinja_str): if call_name[0] == c.VAR: assert isinstance(node, nodes.Call) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: - variables.add(args[0].lower()) + variable_name = args[0].lower() + + # check if this {{ var() }} reference is from a migrated DBT package + # if it is, there will be a __dbt_package= kwarg + dbt_package = next( + ( + kwarg.value + for kwarg in node.kwargs + if isinstance(kwarg, nodes.Keyword) and kwarg.key == "__dbt_package" + ), + None, + ) + if dbt_package and isinstance(dbt_package, nodes.Const): + dbt_package = dbt_package.value + # this convention is a flat way of referencing the nested values under `__dbt_packages__` in the SQLMesh project variables + variable_name = f"{c.MIGRATED_DBT_PACKAGES}.{dbt_package}.{variable_name}" + + variables.add(variable_name) elif call_name[0] == c.GATEWAY: variables.add(c.GATEWAY) elif len(call_name) == 1: @@ -255,6 +328,19 @@ def _convert( def trimmed(self) -> bool: return self._trimmed + @property + def all_macros(self) -> t.Iterable[t.Tuple[t.Optional[str], str, MacroInfo]]: + """ + Returns (package, macro_name, MacroInfo) tuples for every macro in this registry + Root macros will have package=None + """ + for name, macro in self.root_macros.items(): + yield None, name, macro + + for package, macros in self.packages.items(): + for name, macro in macros.items(): + yield (package, name, macro) + def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: """Adds macros to the target package. 
@@ -593,7 +679,12 @@ def jinja_call_arg_name(node: nodes.Node) -> str: def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: - def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + def _var( + var_name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any + ) -> t.Optional[t.Any]: + if dbt_package := kwargs.get("__dbt_package"): + var_name = f"{c.MIGRATED_DBT_PACKAGES}.{dbt_package}.{var_name}" + value = variables.get(var_name.lower(), default) if isinstance(value, SqlValue): return value.sql diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 44ef495737..dea9fb16da 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -28,6 +28,7 @@ from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter +from sqlmesh.core.loader import MigratedDbtProjectLoader from sqlmesh.core.notification_target import ConsoleNotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.errors import ConfigError @@ -1028,3 +1029,29 @@ def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) assert conn.project == "unit-test" assert conn.scopes == ("a", "b", "c") assert conn.keyfile_json == {"foo": "bar"} + + +def test_loader_for_migrated_dbt_project(tmp_path: Path): + config_path = tmp_path / "config.yaml" + config_path.write_text(""" + gateways: + bigquery: + connection: + type: bigquery + project: unit-test + + default_gateway: bigquery + + model_defaults: + dialect: bigquery + + variables: + __dbt_project_name__: sushi +""") + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.loader == MigratedDbtProjectLoader diff --git a/tests/core/test_loader.py b/tests/core/test_loader.py index 2c648e7718..a616f520ef 100644 --- a/tests/core/test_loader.py +++ b/tests/core/test_loader.py @@ -4,6 +4,9 @@ from sqlmesh.core.config import Config, ModelDefaultsConfig from sqlmesh.core.context import Context from sqlmesh.utils.errors import ConfigError +import sqlmesh.core.constants as c +from sqlmesh.core.config import load_config_from_yaml +from sqlmesh.utils.yaml import dump @pytest.fixture @@ -201,3 +204,129 @@ def my_model(context, **kwargs): assert model.description == "model_payload_a" path_b.write_text(model_payload_b) context.load() # raise no error to duplicate key if the functions are identical (by registry class_method) + + +def test_load_migrated_dbt_adapter_dispatch_macros(tmp_path: Path): + init_example_project(tmp_path, dialect="duckdb") + + migrated_package_path = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt_utils" + migrated_package_path.mkdir(parents=True) + + (migrated_package_path / "deduplicate.sql").write_text(""" + {%- macro deduplicate(relation) -%} + {{ return(adapter.dispatch('deduplicate', 'dbt_utils')(relation)) }} + {% endmacro %} + """) + + (migrated_package_path / "default__deduplicate.sql").write_text(""" + {%- macro default__deduplicate(relation) -%} + select 'default impl' from {{ relation }} + {% endmacro %} + """) + + (migrated_package_path / "duckdb__deduplicate.sql").write_text(""" + {%- macro duckdb__deduplicate(relation) -%} + select 'duckdb impl' from {{ relation }} + {% endmacro %} + """) + + # this should be pruned from the JinjaMacroRegistry because the target is duckdb, not bigquery + (migrated_package_path / "bigquery__deduplicate.sql").write_text(""" + {%- macro 
bigquery__deduplicate(relation) -%} + select 'bigquery impl' from {{ relation }} + {% endmacro %} + """) + + (tmp_path / "models" / "test_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.test, + kind FULL, + ); +JINJA_QUERY_BEGIN; +{{ dbt_utils.deduplicate(__migrated_ref(schema='sqlmesh_example', identifier='full_model')) }} +JINJA_END; + """) + + config_path = tmp_path / "config.yaml" + assert config_path.exists() + config = load_config_from_yaml(config_path) + config["variables"] = {} + config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = "test" + + config_path.write_text(dump(config)) + + ctx = Context(paths=tmp_path) + + model = ctx.models['"db"."sqlmesh_example"."test"'] + assert model.dialect == "duckdb" + assert {(package, name) for package, name, _ in model.jinja_macros.all_macros} == { + ("dbt_utils", "deduplicate"), + ("dbt_utils", "default__deduplicate"), + ("dbt_utils", "duckdb__deduplicate"), + } + + assert ( + model.render_query_or_raise().sql(dialect="duckdb") + == """SELECT \'duckdb impl\' AS "duckdb impl" FROM "db"."sqlmesh_example"."full_model" AS "full_model\"""" + ) + + +def test_load_migrated_dbt_adapter_dispatch_macros_in_different_packages(tmp_path: Path): + # some things like dbt.current_timestamp() dispatch to macros in a different package + init_example_project(tmp_path, dialect="duckdb") + + migrated_package_path_dbt = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt" + migrated_package_path_dbt_duckdb = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt_duckdb" + migrated_package_path_dbt.mkdir(parents=True) + migrated_package_path_dbt_duckdb.mkdir(parents=True) + + (migrated_package_path_dbt / "current_timestamp.sql").write_text(""" + {%- macro current_timestamp(relation) -%} + {{ return(adapter.dispatch('current_timestamp', 'dbt')()) }} + {% endmacro %} + """) + + (migrated_package_path_dbt / "default__current_timestamp.sql").write_text(""" + {% macro default__current_timestamp() -%} + {{ exceptions.raise_not_implemented('current_timestamp macro not implemented') }} + {%- endmacro %} + """) + + (migrated_package_path_dbt_duckdb / "duckdb__current_timestamp.sql").write_text(""" + {%- macro duckdb__current_timestamp() -%} + 'duckdb current_timestamp impl' + {% endmacro %} + """) + + (tmp_path / "models" / "test_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.test, + kind FULL, + ); +JINJA_QUERY_BEGIN; +select {{ dbt.current_timestamp() }} as a +JINJA_END; + """) + + config_path = tmp_path / "config.yaml" + assert config_path.exists() + config = load_config_from_yaml(config_path) + config["variables"] = {} + config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = "test" + + config_path.write_text(dump(config)) + + ctx = Context(paths=tmp_path) + + model = ctx.models['"db"."sqlmesh_example"."test"'] + assert model.dialect == "duckdb" + assert {(package, name) for package, name, _ in model.jinja_macros.all_macros} == { + ("dbt", "current_timestamp"), + ("dbt", "default__current_timestamp"), + ("dbt_duckdb", "duckdb__current_timestamp"), + } + + assert ( + model.render_query_or_raise().sql(dialect="duckdb") + == "SELECT 'duckdb current_timestamp impl' AS \"a\"" + ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 7a2f808e12..ad083e9ae2 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -15,6 +15,7 @@ from sqlmesh.cli.example_project import init_example_project, ProjectTemplate from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.model.kind import TimeColumn, ModelKindName +from 
pydantic import ValidationError from sqlmesh import CustomMaterialization, CustomKind from pydantic import model_validator, ValidationError @@ -42,6 +43,7 @@ FullKind, IncrementalByTimeRangeKind, IncrementalUnmanagedKind, + IncrementalByUniqueKeyKind, ModelCache, ModelMeta, SeedKind, @@ -63,7 +65,13 @@ from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + MacroInfo, + MacroExtractor, + MacroReference, + SQLMESH_DBT_COMPATIBILITY_PACKAGE, +) from sqlmesh.utils.metaprogramming import Executable, SqlValue from sqlmesh.core.macros import RuntimeStage from tests.utils.test_helpers import use_terminal_console @@ -6171,6 +6179,58 @@ def model_with_variables(context, **kwargs): assert df.to_dict(orient="records") == [{"a": "test_value", "b": "default_value", "c": None}] +def test_variables_migrated_dbt_package_macro(): + expressions = parse( + """ + MODEL( + name test_model, + kind FULL, + ); + + JINJA_QUERY_BEGIN; + SELECT '{{ var('TEST_VAR_A') }}' as a, '{{ test.test_macro_var() }}' as b + JINJA_END; + """, + default_dialect="bigquery", + ) + + jinja_macros = JinjaMacroRegistry( + create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE, + packages={ + "test": { + "test_macro_var": MacroInfo( + definition=""" + {% macro test_macro_var() %} + {{- var('test_var_b', __dbt_package='test') }} + {%- endmacro %}""", + depends_on=[MacroReference(name="var")], + ) + } + }, + ) + + model = load_sql_based_model( + expressions, + variables={ + "test_var_a": "test_var_a_value", + c.MIGRATED_DBT_PACKAGES: { + "test": {"test_var_b": "test_var_b_value", "unused": "unused_value"}, + }, + "test_var_c": "test_var_c_value", + }, + jinja_macros=jinja_macros, + migrated_dbt_project_name="test", + dialect="bigquery", + ) + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_var_a_value", "__dbt_packages__.test.test_var_b": "test_var_b_value"} + ) + assert ( + model.render_query().sql(dialect="bigquery") + == "SELECT 'test_var_a_value' AS `a`, 'test_var_b_value' AS `b`" + ) + + def test_load_external_model_python(sushi_context) -> None: @model( "test_load_external_model_python", @@ -7727,6 +7787,37 @@ def test_model_kind_to_expression(): ) +def test_incremental_by_unique_key_batch_concurrency(): + with pytest.raises(ValidationError, match=r"Input should be 1"): + load_sql_based_model( + d.parse(""" + MODEL ( + name db.table, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key a, + batch_concurrency 2 + ) + ); + select 1; + """) + ) + + model = load_sql_based_model( + d.parse(""" + MODEL ( + name db.table, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key a, + batch_concurrency 1 + ) + ); + select 1; + """) + ) + assert isinstance(model.kind, IncrementalByUniqueKeyKind) + assert model.kind.batch_concurrency == 1 + + def test_bad_model_kind(): with pytest.raises( SQLMeshError, diff --git a/tests/dbt/converter/conftest.py b/tests/dbt/converter/conftest.py new file mode 100644 index 0000000000..e8dffeb263 --- /dev/null +++ b/tests/dbt/converter/conftest.py @@ -0,0 +1,21 @@ +from pathlib import Path +import typing as t +import pytest +from sqlmesh.core.context import Context + + +@pytest.fixture +def sushi_dbt_context(copy_to_temp_path: t.Callable) -> Context: + return 
Context(paths=copy_to_temp_path("examples/sushi_dbt")) + + +@pytest.fixture +def empty_dbt_context(copy_to_temp_path: t.Callable) -> Context: + fixture_path = Path(__file__).parent / "fixtures" / "empty_dbt_project" + assert fixture_path.exists() + + actual_path = copy_to_temp_path(fixture_path)[0] + + ctx = Context(paths=actual_path) + + return ctx diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore b/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore new file mode 100644 index 0000000000..232ccd1d8c --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore @@ -0,0 +1,2 @@ +target/ +logs/ diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/analyses/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/analyses/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/config.py b/tests/dbt/converter/fixtures/empty_dbt_project/config.py new file mode 100644 index 0000000000..e7e28c98e4 --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/config.py @@ -0,0 +1,7 @@ +from pathlib import Path + +from sqlmesh.dbt.loader import sqlmesh_config + +config = sqlmesh_config(Path(__file__).parent) + +test_config = config diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml b/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml new file mode 100644 index 0000000000..007649e553 --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml @@ -0,0 +1,22 @@ + +name: 'test' +version: '1.0.0' +config-version: 2 +profile: 'test' + +model-paths: ["models"] +analysis-paths: ["analyses"] +test-paths: ["tests"] +seed-paths: ["seeds"] +macro-paths: ["macros"] +snapshot-paths: ["snapshots"] + +target-path: "target" + +models: + +start: Jan 1 2022 + +seeds: + +schema: raw + +vars: {} diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/macros/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/macros/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/models/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/models/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml b/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml new file mode 100644 index 0000000000..49354831f4 --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml @@ -0,0 +1,6 @@ +version: 2 + +sources: + - name: external + tables: + - name: orders diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/packages/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/packages/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml b/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml new file mode 100644 index 0000000000..6d91ecbe65 --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml @@ -0,0 +1,6 @@ +test: + outputs: + in_memory: + type: duckdb + schema: project + target: in_memory diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv new file mode 100644 index 
0000000000..0f87cb2507 --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv @@ -0,0 +1,94 @@ +id,name,price,ds +0,Maguro,4.34,2022-01-01 +1,Ika,7.35,2022-01-01 +2,Aji,6.06,2022-01-01 +3,Hotate,8.5,2022-01-01 +4,Escolar,8.46,2022-01-01 +5,Sake,4.91,2022-01-01 +6,Tamago,4.94,2022-01-01 +7,Umi Masu,8.61,2022-01-01 +8,Bincho,9.71,2022-01-01 +9,Toro,9.13,2022-01-01 +10,Aoyagi,5.5,2022-01-01 +11,Hamachi,6.51,2022-01-01 +12,Tobiko,7.78,2022-01-01 +13,Unagi,7.99,2022-01-01 +14,Tako,5.59,2022-01-01 +0,Kani,8.22,2022-01-02 +1,Amaebi,9.14,2022-01-02 +2,Uni,4.55,2022-01-02 +3,Sake Toro,5.01,2022-01-02 +4,Maguro,9.95,2022-01-02 +5,Katsuo,9.03,2022-01-02 +6,Hamachi Toro,3.76,2022-01-02 +7,Iwashi,5.56,2022-01-02 +8,Tamago,6.96,2022-01-02 +9,Tai,5.84,2022-01-02 +10,Ika,3.23,2022-01-02 +0,Hirame,7.74,2022-01-03 +1,Uni,3.98,2022-01-03 +2,Tai,4.09,2022-01-03 +3,Kanpachi,7.55,2022-01-03 +4,Tobiko,9.87,2022-01-03 +5,Hotate,7.86,2022-01-03 +6,Iwashi,8.33,2022-01-03 +7,Ikura,5.98,2022-01-03 +8,Maguro,3.97,2022-01-03 +9,Tsubugai,4.51,2022-01-03 +10,Tako,8.35,2022-01-03 +11,Sake,3.38,2022-01-03 +12,Tamago,6.43,2022-01-03 +13,Ika,4.26,2022-01-03 +14,Unagi,7.42,2022-01-03 +0,Ikura,5.02,2022-01-04 +1,Tobiko,9.15,2022-01-04 +2,Hamachi,6.66,2022-01-04 +3,Bincho,8.4,2022-01-04 +4,Tsubugai,5.26,2022-01-04 +5,Hotate,8.92,2022-01-04 +6,Toro,7.52,2022-01-04 +7,Aji,7.49,2022-01-04 +8,Ebi,5.67,2022-01-04 +9,Kanpachi,7.51,2022-01-04 +10,Kani,6.97,2022-01-04 +11,Hirame,4.51,2022-01-04 +0,Saba,7.41,2022-01-05 +1,Unagi,8.45,2022-01-05 +2,Uni,3.67,2022-01-05 +3,Maguro,8.76,2022-01-05 +4,Katsuo,5.99,2022-01-05 +5,Bincho,9.15,2022-01-05 +6,Sake Toro,3.67,2022-01-05 +7,Aji,9.55,2022-01-05 +8,Umi Masu,9.88,2022-01-05 +9,Hamachi,6.53,2022-01-05 +10,Tai,6.83,2022-01-05 +11,Tsubugai,4.62,2022-01-05 +12,Ikura,4.86,2022-01-05 +13,Ahi,9.66,2022-01-05 +14,Hotate,7.85,2022-01-05 +0,Hamachi Toro,4.87,2022-01-06 +1,Ika,3.26,2022-01-06 +2,Kanpachi,8.63,2022-01-06 +3,Hirame,5.34,2022-01-06 +4,Katsuo,9.24,2022-01-06 +5,Iwashi,8.67,2022-01-06 +6,Sake Toro,9.75,2022-01-06 +7,Bincho,9.7,2022-01-06 +8,Aji,7.14,2022-01-06 +9,Hokigai,5.18,2022-01-06 +10,Umi Masu,9.43,2022-01-06 +11,Unagi,3.35,2022-01-06 +12,Sake,4.58,2022-01-06 +13,Aoyagi,5.54,2022-01-06 +0,Amaebi,6.94,2022-01-07 +1,Ebi,7.84,2022-01-07 +2,Saba,5.28,2022-01-07 +3,Anago,4.53,2022-01-07 +4,Escolar,7.28,2022-01-07 +5,Ahi,6.48,2022-01-07 +6,Katsuo,5.16,2022-01-07 +7,Umi Masu,6.09,2022-01-07 +8,Maguro,7.7,2022-01-07 +9,Hokigai,7.37,2022-01-07 +10,Sake Toro,6.99,2022-01-07 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml new file mode 100644 index 0000000000..86ce6964fe --- /dev/null +++ b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml @@ -0,0 +1,13 @@ +version: 2 + +seeds: + - name: items + columns: + - name: id + description: Item id + - name: name + description: Name of the item + - name: price + description: Price of the item + - name: ds + description: Date \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/snapshots/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/snapshots/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/tests/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/tests/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dbt/converter/fixtures/jinja_nested_if.sql 
b/tests/dbt/converter/fixtures/jinja_nested_if.sql new file mode 100644 index 0000000000..e7a1bed137 --- /dev/null +++ b/tests/dbt/converter/fixtures/jinja_nested_if.sql @@ -0,0 +1,15 @@ +{% if foo == 'bar' %} + baz + {% if baz == 'bing' %} + bong + {% else %} + qux + {% endif %} +{% elif a == fn(b) %} + {% if c == 'f' and fn1(a, c, 'foo') == 'test' %} + output1 + {% elif z is defined %} + output2 + {% endif %} + output +{% endif %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/macro_dbt_incremental.sql b/tests/dbt/converter/fixtures/macro_dbt_incremental.sql new file mode 100644 index 0000000000..a76f60713b --- /dev/null +++ b/tests/dbt/converter/fixtures/macro_dbt_incremental.sql @@ -0,0 +1,11 @@ +{% macro incremental_by_time(col, time_type) %} + {% if is_incremental() %} + WHERE + {{ col }} > (select max({{ col }}) from {{ this }}) + {% endif %} + {% if sqlmesh_incremental is defined %} + {% set dates = incremental_dates_by_time_type(time_type) %} + WHERE + {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' + {% endif %} +{% endmacro %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/macro_func_with_params.sql b/tests/dbt/converter/fixtures/macro_func_with_params.sql new file mode 100644 index 0000000000..06bb757ef9 --- /dev/null +++ b/tests/dbt/converter/fixtures/macro_func_with_params.sql @@ -0,0 +1,17 @@ +{% macro func_with_params(amount, category) %} + case + {% for row in [ + { 'category': '1', 'range': [0, 10], 'consider': True }, + { 'category': '2', 'range': [11, 20], 'consider': None } + ] %} + when {{ category }} = '{{ row.category }}' + and {{ amount }} >= {{ row.range[0] }} + {% if row.consider is not none %} + and {{ amount }} < {{ row.range[1] }} + {% endif %} + then + ({{ amount }} * {{ row.range[0] }} + {{ row.range[1] }}) * 4 + {% endfor %} + else null + end +{% endmacro %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/model_query_incremental.sql b/tests/dbt/converter/fixtures/model_query_incremental.sql new file mode 100644 index 0000000000..a9603dbcbb --- /dev/null +++ b/tests/dbt/converter/fixtures/model_query_incremental.sql @@ -0,0 +1,34 @@ +WITH cte AS ( + SELECT + oi.order_id AS order_id, + FROM {{ ref('order_items') }} AS oi + LEFT JOIN {{ ref('items') }} AS i + ON oi.item_id = i.id AND oi.ds = i.ds +{% if is_incremental() %} +WHERE + oi.ds > (select max(ds) from {{ this }}) +{% endif %} +{% if sqlmesh_incremental is defined %} +WHERE + oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' +{% endif %} +GROUP BY + oi.order_id, + oi.ds +) +SELECT + o.customer_id::INT AS customer_id, /* Customer id */ + SUM(ot.total)::NUMERIC AS revenue, /* Revenue from orders made by this customer */ + o.ds::TEXT AS ds /* Date */ +FROM {{ ref('orders') }} AS o + LEFT JOIN order_total AS ot + ON o.id = ot.order_id AND o.ds = ot.ds +{% if is_incremental() %} + WHERE o.ds > (select max(ds) from {{ this }}) +{% endif %} +{% if sqlmesh_incremental is defined %} + WHERE o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' +{% endif %} +GROUP BY + o.customer_id, + o.ds \ No newline at end of file diff --git a/tests/dbt/converter/test_convert.py b/tests/dbt/converter/test_convert.py new file mode 100644 index 0000000000..001b1f82cc --- /dev/null +++ b/tests/dbt/converter/test_convert.py @@ -0,0 +1,105 @@ +from pathlib import Path +from sqlmesh.core.context import Context +from sqlmesh.dbt.converter.convert import convert_project_files, resolve_fqns_to_model_names +import uuid +import sqlmesh.core.constants as c + + +def 
test_convert_project_files(sushi_dbt_context: Context, tmp_path: Path) -> None: + src_context = sushi_dbt_context + src_path = sushi_dbt_context.path + output_path = tmp_path / f"output_{uuid.uuid4().hex}" + + convert_project_files(src_path, output_path) + + target_context = Context(paths=output_path) + + assert src_context.models.keys() == target_context.models.keys() + + target_context.plan(auto_apply=True) + + +def test_convert_project_files_includes_library_macros( + sushi_dbt_context: Context, tmp_path: Path +) -> None: + src_path = sushi_dbt_context.path + output_path = tmp_path / f"output_{uuid.uuid4().hex}" + + (src_path / "macros" / "call_library.sql").write_text(""" +{% macro call_library() %} + {{ dbt.current_timestamp() }} +{% endmacro %} +""") + + convert_project_files(src_path, output_path) + + migrated_output_macros_path = output_path / "macros" / c.MIGRATED_DBT_PACKAGES + assert (migrated_output_macros_path / "dbt" / "current_timestamp.sql").exists() + # note: the DBT manifest is smart enough to prune "dbt / default__current_timestamp.sql" from the list so it is not migrated + assert (migrated_output_macros_path / "dbt_duckdb" / "duckdb__current_timestamp.sql").exists() + + +def test_resolve_fqns_to_model_names(empty_dbt_context: Context) -> None: + ctx = empty_dbt_context + + # macro that uses a property of {{ ref() }} and also creates another ref() + (ctx.path / "macros" / "foo.sql").write_text( + """ +{% macro foo(relation) %} + {{ relation.name }} r + left join {{ source('external', 'orders') }} et + on r.id = et.id +{% endmacro %} +""" + ) + + # model 1 - can be fully unwrapped + (ctx.path / "models" / "model1.sql").write_text( + """ +{{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + time_column='ds' + ) +}} + +select * from {{ ref('items') }} +{% if is_incremental() %} + where ds > (select max(ds) from {{ this }}) +{% endif %} +""" + ) + + # model 2 - has ref passed to macro as parameter and also another ref nested in macro + (ctx.path / "models" / "model2.sql").write_text( + """ +select * from {{ foo(ref('model1')) }} union select * from {{ ref('items') }} +""" + ) + + ctx.load() + + assert len(ctx.models) == 3 + + model1 = ctx.models['"memory"."project"."model1"'] + model2 = ctx.models['"memory"."project"."model2"'] + + assert model1.depends_on == {'"memory"."project_raw"."items"'} + assert model2.depends_on == { + '"memory"."project"."model1"', + '"memory"."external"."orders"', + '"memory"."project_raw"."items"', + } + + # All dependencies in model 1 can be tracked by the native loader but it's very difficult to cover all the edge cases at conversion time + # so we still populate depends_on() + assert resolve_fqns_to_model_names(ctx, model1.depends_on) == {"project_raw.items"} + + # For model 2, the external model "external.orders" should be removed from depends_on + # If it was output verbatim as depends_on ("memory"."external"."orders"), the native loader would throw an error like: + # - Error: Failed to load model definition, 'Dot' object is not iterable + assert resolve_fqns_to_model_names(ctx, model2.depends_on) == { + "project.model1", + "project_raw.items", + } diff --git a/tests/dbt/converter/test_jinja.py b/tests/dbt/converter/test_jinja.py new file mode 100644 index 0000000000..5d9e8f3d73 --- /dev/null +++ b/tests/dbt/converter/test_jinja.py @@ -0,0 +1,439 @@ +import pytest +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor +from sqlmesh.dbt.converter.jinja import JinjaGenerator, convert_jinja_query,
convert_jinja_macro +import sqlmesh.dbt.converter.jinja_transforms as jt +from pathlib import Path +from sqlmesh.core.context import Context +import sqlmesh.core.dialect as d +from sqlglot import exp +from _pytest.mark.structures import ParameterSet +from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.utils import columns_to_types_all_known + + +def _load_fixture(name: str) -> ParameterSet: + return pytest.param( + (Path(__file__).parent / "fixtures" / name).read_text(encoding="utf8"), id=name + ) + + +@pytest.mark.parametrize( + "original_jinja", + [ + "select 1", + "select bar from {{ ref('foo') }} as f", + "select max(ds) from {{ this }}", + "{% if is_incremental() %}where ds > (select max(ds) from {{ this }}){% endif %}", + "foo {% if sqlmesh_incremental is defined %} bar {% endif %} bar", + "foo between '{{ start_ds }}' and '{{ end_ds }}'", + "{{ 42 }}", + "{{ foo.bar }}", + "{{ 'baz' }}", + "{{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}'", + "{% set foo = bar(baz, bing='bong') %}", + "{% if a == 'ds' %}foo{% elif a == 'ts' %}bar{% elif a < 'ys' or (b != 'ds' and c >= 'ts') %}baz{% else %}bing{% endif %}", + "{% set my_string = my_string ~ stuff ~ ', ' ~ 1 %}", + "{{ context.do_some_action('param') }}", + "{% set big_ole_block %}foo{% endset %}", + "{% if not loop.last %}foo{% endif %}", + "{% for a, b in some_func(a=foo['bar'][0], b=c.d[5]).items() %}foo_{{ a }}_{{ b }}{% endfor %}", + "{{ column | replace(prefix, '') }}", + "{{ column | filter('a', foo='bar') }}", + "{% filter upper %}foo{% endfilter %}", + "{% filter foo(0, bar='baz') %}foo{% endfilter %}", + "{% if foo in ('bar', 'baz') %}bar{% endif %}", + "{% if foo not in ('bar', 'baz') %}bing{% endif %}", + "{% if (field.a if field.a else field.b) | lower not in ('c', 'd') %}foo{% endif %}", + "{% do foo.bar('baz') %}", + "{% set a = (col | lower + '_') + b %}", + "{{ foo[1:10] | lower }}", + "{{ foo[1:] }}", + "{{ foo[:1] }}", + "{% for col in all_columns if col.name in columns_to_compare and col.name in special_names %}{{ col }}{% endfor %}", + "{{ ' or ' if not loop.first else '' }}", + "{% set foo = ['a', 'b', c, d.e, f[0], g.h.i[0][1]] %}", + """{% set foo = "('%Y%m%d', partition_id)" %}""", + "{% set foo = (graph.nodes.values() | selectattr('name', 'equalto', model_name) | list)[0] %}", + "{% set foo.bar = baz.bing(database='foo') %}", + "{{ return(('some', 'tuple')) }}", + "{% call foo('bar', baz=True) %}bar{% endcall %}", + "{% call(user) dump_users(list_of_user) %}bar{% endcall %}", + "{% macro foo(a, b='default', c=None) %}{% endmacro %}", + # "{# some comment #}", #todo: comments get stripped entirely + # "foo\n{%- if bar -%} baz {% endif -%}", #todo: whitespace trim handling is a nice-to-have + _load_fixture("model_query_incremental.sql"), + _load_fixture("macro_dbt_incremental.sql"), + _load_fixture("jinja_nested_if.sql"), + ], +) +def test_generator_roundtrip(original_jinja: str) -> None: + registry = JinjaMacroRegistry() + env = registry.build_environment() + + ast = env.parse(original_jinja) + generated = JinjaGenerator().generate(ast) + + assert generated == original_jinja + + me = MacroExtractor() + # basically just test this doesn't throw an exception.
+ # The MacroExtractor uses SQLGlot's tokenizer and not Jinja's so these need to work when the converted project is loaded by the native loader + me.extract(generated) + + +def test_generator_sql_comment_macro(): + jinja_str = "-- before sql comment{% macro foo() %}-- inner sql comment{% endmacro %}" + + registry = JinjaMacroRegistry() + env = registry.build_environment() + + ast = env.parse(jinja_str) + generated = JinjaGenerator().generate(ast) + + assert ( + generated == "-- before sql comment\n{% macro foo() %}-- inner sql comment\n{% endmacro %}" + ) + + # check roundtripping an existing newline doesn't keep adding newlines + assert JinjaGenerator().generate(env.parse(generated)) == generated + + +@pytest.mark.parametrize("original_jinja", [_load_fixture("macro_func_with_params.sql")]) +def test_generator_roundtrip_ignore_whitespace(original_jinja: str) -> None: + """ + This makes the following assumptions: + - SQL isn't too sensitive about indentation / whitespace + - The Jinja AST doesn't capture enough information to perfectly replicate the input template with regard to whitespace handling + + So if, disregarding whitespace, the original input string matches the output of running the AST back through the generator, the test passes + """ + registry = JinjaMacroRegistry() + env = registry.build_environment() + + ast = env.parse(original_jinja) + + generated = JinjaGenerator().generate(ast) + + assert " ".join(original_jinja.split()) == " ".join(generated.split()) + + +def test_convert_jinja_query(sushi_dbt_context: Context) -> None: + model = sushi_dbt_context.models['"memory"."sushi"."customer_revenue_by_day"'] + assert isinstance(model, SqlModel) + + query = model.query + assert isinstance(query, d.JinjaQuery) + + result = convert_jinja_query(sushi_dbt_context, model, query) + + assert isinstance(result, exp.Query) + + assert ( + result.sql(dialect=model.dialect, pretty=True) + == """WITH order_total AS ( + SELECT + oi.order_id AS order_id, + SUM(oi.quantity * i.price) AS total, + oi.ds AS ds + FROM sushi_raw.order_items AS oi + LEFT JOIN sushi_raw.items AS i + ON oi.item_id = i.id AND oi.ds = i.ds + WHERE + oi.ds BETWEEN @start_ds AND @end_ds + GROUP BY + oi.order_id, + oi.ds +) +SELECT + CAST(o.customer_id AS INT) AS customer_id, /* Customer id */ + CAST(SUM(ot.total) AS DOUBLE) AS revenue, /* Revenue from orders made by this customer */ + CAST(o.ds AS TEXT) AS ds /* Date */ +FROM sushi_raw.orders AS o +LEFT JOIN order_total AS ot + ON o.id = ot.order_id AND o.ds = ot.ds +WHERE + o.ds BETWEEN @start_ds AND @end_ds +GROUP BY + o.customer_id, + o.ds""" + ) + + +def test_convert_jinja_query_exclude_transform(empty_dbt_context: Context) -> None: + ctx = empty_dbt_context + + (ctx.path / "models" / "model1.sql").write_text(""" + {{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + time_column='ds' + ) + }} + + select * from {{ ref('items') }} + {% if is_incremental() %} + where ds > (select max(ds) from {{ this }}) + {% endif %} + """) + + ctx.load() + + model = ctx.models['"memory"."project"."model1"'] + assert isinstance(model, SqlModel) + + query = model.query + assert isinstance(query, d.JinjaQuery) + + converted_query = convert_jinja_query( + ctx, + model, + query, + exclude=[jt.resolve_dbt_ref_to_model_name, jt.rewrite_dbt_ref_to_migrated_ref], + ) + sql = converted_query.sql() + + assert "{{ ref('items') }}" in sql + assert "{{ this }}" not in sql + assert "{% if is_incremental() %}" not in sql + assert "{% endif %}" not in sql + + +def
+
+
+def test_convert_jinja_query_self_referencing(empty_dbt_context: Context) -> None:
+    ctx = empty_dbt_context
+
+    (ctx.path / "models" / "model1.sql").write_text("""
+    {{
+        config(
+            materialized='incremental',
+            incremental_strategy='delete+insert',
+            time_column='ds'
+        )
+    }}
+
+    select * from {{ ref('items') }}
+    {% if is_incremental() %}
+    where ds > (select max(ds) from {{ this }})
+    {% endif %}
+    """)
+
+    ctx.load()
+
+    model = ctx.models['"memory"."project"."model1"']
+    assert model.columns_to_types_or_raise
+    assert (
+        not model.depends_on_self
+    )  # the DBT loader doesn't detect self-references within is_incremental blocks
+    assert isinstance(model, SqlModel)
+
+    query = model.query
+    assert isinstance(query, d.JinjaQuery)
+
+    converted_query = convert_jinja_query(ctx, model, query)
+    converted_model_definition = model.copy().render_definition()[0].sql()
+
+    # load from scratch to use the native loader and clear @cached_property caches
+    ctx.upsert_model(
+        load_sql_based_model(
+            expressions=[d.parse_one(converted_model_definition), converted_query],
+            default_catalog=ctx.default_catalog,
+        )
+    )
+    converted_model = ctx.models['"memory"."project"."model1"']
+    assert isinstance(converted_model, SqlModel)
+
+    assert "{% if is_incremental" not in converted_model.query.sql()
+    assert (
+        converted_model.depends_on_self
+    )  # Once the is_incremental blocks are removed, the model can be detected as self-referencing
+    assert columns_to_types_all_known(
+        converted_model.columns_to_types_or_raise
+    )  # columns to types must all be known for self-referencing models
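The columns_to_types_all_known gate matters because SQLMesh can only evaluate a self-referencing model when it can derive the full schema up front. A rough illustration of the predicate, assuming (as its name and the assertions here suggest) that any UNKNOWN column type makes it return False:

```python
from sqlglot import exp
from sqlmesh.utils import columns_to_types_all_known

# Fully typed schema: the predicate holds.
assert columns_to_types_all_known(
    {"id": exp.DataType.build("int"), "ds": exp.DataType.build("text")}
)

# A column whose type could not be inferred (e.g. selected from an external
# table with no registered schema) makes it fail.
assert not columns_to_types_all_known({"id": exp.DataType.build("unknown")})
```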
+ assert "{% if is_incremental" in converted_model.query.sql() + assert "{{ this }}" not in converted_model.query.sql() + assert not converted_model.depends_on_self + + assert not columns_to_types_all_known( + converted_model.columns_to_types_or_raise + ) # this is ok because the model is not self-referencing + + +def test_convert_jinja_query_migrated_ref(empty_dbt_context: Context) -> None: + ctx = empty_dbt_context + + (ctx.path / "models" / "model1.sql").write_text(""" + {{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + time_column='ds' + ) + }} + + {% macro ref_handler(relation) %} + {{ relation.name }} + {% endmacro %} + + select * from {{ ref_handler(ref("items")) }} + """) + + ctx.load() + + model = ctx.models['"memory"."project"."model1"'] + assert isinstance(model, SqlModel) + query = model.query + assert isinstance(query, d.JinjaQuery) + + converted_query = convert_jinja_query(ctx, model, query) + + assert ( + """select * from {{ ref_handler(__migrated_ref(database='memory', schema='project_raw', identifier='items', sqlmesh_model_name='project_raw.items')) }}""" + in converted_query.sql() + ) + + +def test_convert_jinja_query_post_statement(empty_dbt_context: Context) -> None: + ctx = empty_dbt_context + + (ctx.path / "models" / "model1.sql").write_text(""" + {{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + time_column='ds', + post_hook="create index foo_idx on {{ this }} (id)" + ) + }} + + select * from {{ ref("items") }} + """) + + ctx.load() + + model = ctx.models['"memory"."project"."model1"'] + assert isinstance(model, SqlModel) + + assert model.post_statements + post_statement = model.post_statements[0] + assert isinstance(post_statement, d.JinjaStatement) + + converted_post_statement = convert_jinja_query(ctx, model, post_statement) + + assert "CREATE INDEX foo_idx ON project.model1(id)" in converted_post_statement.sql( + dialect="duckdb" + ) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + """ + {% macro incremental_by_time(col, time_type) %} + {% if is_incremental() %} + WHERE + {{ col }} > (select max({{ col }}) from {{ this }}) + {% endif %} + {% if sqlmesh_incremental is defined %} + {% set dates = incremental_dates_by_time_type(time_type) %} + WHERE + {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' + {% endif %} + {% endmacro %} + """, + """ + {% macro incremental_by_time(col, time_type) %} + {% set dates = incremental_dates_by_time_type(time_type) %} + WHERE + {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' + {% endmacro %} + """, + ), + ( + """ + {% macro foo(iterations) %} + with base as ( + select * from {{ ref('customer_revenue_by_day') }} + ), + iter as ( + {% for i in range(0, iterations) %} + 'iter_{{ i }}' as iter_num_{{ i }} + {% if not loop.last %},{% endif %} + {% endfor %} + ) + select 1 + {% endmacro %}""", + """ + {% macro foo(iterations) %} + with base as ( + select * from sushi.customer_revenue_by_day + ), + iter as ( + {% for i in range(0, iterations) %} + 'iter_{{ i }}' as iter_num_{{ i }} + {% if not loop.last %},{% endif %} + {% endfor %} + ) + select 1 + {% endmacro %}""", + ), + ( + """{% macro expand_ref(model_name) %}{{ ref(model_name) }}{% endmacro %}""", + """{% macro expand_ref(model_name) %}{{ ref(model_name) }}{% endmacro %}""", + ), + ], +) +def test_convert_jinja_macro(input: str, expected: str, sushi_dbt_context: Context) -> None: + result = convert_jinja_macro(sushi_dbt_context, input.strip()) + + assert " ".join(result.split()) == " 
".join(expected.strip().split()) diff --git a/tests/dbt/converter/test_jinja_transforms.py b/tests/dbt/converter/test_jinja_transforms.py new file mode 100644 index 0000000000..c7d060ea40 --- /dev/null +++ b/tests/dbt/converter/test_jinja_transforms.py @@ -0,0 +1,453 @@ +import pytest +import typing as t +from sqlglot import parse_one +from sqlmesh.core.model import create_sql_model, create_external_model +from sqlmesh.dbt.converter.jinja import transform, JinjaGenerator +import sqlmesh.dbt.converter.jinja_transforms as jt +from sqlmesh.dbt.converter.common import JinjaTransform +from sqlmesh.utils.jinja import environment, Environment, ENVIRONMENT +from sqlmesh.core.context import Context +from sqlmesh.core.config import Config, ModelDefaultsConfig + + +def transform_str( + input: str, handler: JinjaTransform, environment: t.Optional[Environment] = None +) -> str: + environment = environment or ENVIRONMENT + ast = environment.parse(input) + return JinjaGenerator().generate(transform(ast, handler)) + + +@pytest.mark.parametrize( + "input,expected", + [ + ("select * from {{ ref('bar') }} as t", "select * from foo.bar as t"), + ("select * from {{ ref('bar', version=1) }} as t", "select * from foo.bar_v1 as t"), + ("select * from {{ ref('bar', v=1) }} as t", "select * from foo.bar_v1 as t"), + ( + "select * from {{ ref('unknown') }} as t", + "select * from __unresolved_ref__.unknown as t", + ), + ( + "{% macro foo() %}select * from {{ ref('bar') }}{% endmacro %}", + "{% macro foo() %}select * from foo.bar{% endmacro %}", + ), + # these shouldnt be transformed as the macro call might rely on some property of the Relation object returned by ref() + ("{{ dbt_utils.union_relations([ref('foo')]) }},", None), + ("select * from {% if some_macro(ref('bar')) %}foo{% endif %}", None), + ( + "select * from {% if some_macro(ref('bar')) %}{{ ref('bar') }}{% endif %}", + "select * from {% if some_macro(ref('bar')) %}foo.bar{% endif %}", + ), + ("{{ some_macro(ref('bar')) }}", None), + ("{{ some_macro(table=ref('bar')) }}", None), + ], +) +def test_resolve_dbt_ref_to_model_name(input: str, expected: t.Optional[str]) -> None: + expected = expected or input + + from dbt.adapters.base import BaseRelation + + # note: bigquery dialect chosen because its identifiers have backticks + # but internally SQLMesh stores model fqn with double quotes + config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) + ctx = Context(config=config) + ctx.default_catalog = "sqlmesh" + + assert ctx.default_catalog == "sqlmesh" + assert ctx.default_dialect == "bigquery" + + model = create_sql_model( + name="foo.bar", query=parse_one("select 1"), default_catalog=ctx.default_catalog + ) + model2 = create_sql_model( + name="foo.bar_v1", query=parse_one("select 1"), default_catalog=ctx.default_catalog + ) + ctx.upsert_model(model) + ctx.upsert_model(model2) + + assert '"sqlmesh"."foo"."bar"' in ctx.models + + def _resolve_ref(ref_name: str, version: t.Optional[int] = None) -> t.Optional[BaseRelation]: + if ref_name == "bar": + identifier = "bar" + if version: + identifier = f"bar_v{version}" + + relation = BaseRelation.create( + database="sqlmesh", schema="foo", identifier=identifier, quote_character="`" + ) + assert ( + relation.render() == "`sqlmesh`.`foo`.`bar`" + if not version + else f"`sqlmesh`.`foo`.`bar_v{version}`" + ) + return relation + return None + + jinja_env = environment() + jinja_env.globals["ref"] = _resolve_ref + + assert ( + transform_str( + input, + jt.resolve_dbt_ref_to_model_name(ctx.models, 
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        (
+            "select * from {{ ref('bar') }} as t",
+            "select * from {{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }} as t",
+        ),
+        (
+            "{% macro foo() %}select * from {{ ref('bar') }}{% endmacro %}",
+            "{% macro foo() %}select * from {{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }}{% endmacro %}",
+        ),
+        (
+            "{{ dbt_utils.union_relations([ref('bar')]) }}",
+            "{{ dbt_utils.union_relations([__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')]) }}",
+        ),
+        (
+            "select * from {% if some_macro(ref('bar')) %}foo{% endif %}",
+            "select * from {% if some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) %}foo{% endif %}",
+        ),
+        (
+            "select * from {% if some_macro(ref('bar')) %}{{ ref('bar') }}{% endif %}",
+            "select * from {% if some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) %}{{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }}{% endif %}",
+        ),
+        (
+            "{{ some_macro(ref('bar')) }}",
+            "{{ some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) }}",
+        ),
+        (
+            "{{ some_macro(table=ref('bar')) }}",
+            "{{ some_macro(table=__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) }}",
+        ),
+    ],
+)
+def test_rewrite_dbt_ref_to_migrated_ref(input: str, expected: t.Optional[str]) -> None:
+    expected = expected or input
+
+    from dbt.adapters.base import BaseRelation
+
+    # note: the bigquery dialect is chosen because its identifiers have backticks,
+    # but internally SQLMesh stores model FQNs with double quotes
+    config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery"))
+    ctx = Context(config=config)
+    ctx.default_catalog = "sqlmesh"
+
+    assert ctx.default_catalog == "sqlmesh"
+    assert ctx.default_dialect == "bigquery"
+
+    model = create_sql_model(
+        name="foo.bar", query=parse_one("select 1"), default_catalog=ctx.default_catalog
+    )
+    ctx.upsert_model(model)
+
+    assert '"sqlmesh"."foo"."bar"' in ctx.models
+
+    def _resolve_ref(ref_name: str) -> t.Optional[BaseRelation]:
+        if ref_name == "bar":
+            relation = BaseRelation.create(
+                database="sqlmesh", schema="foo", identifier="bar", quote_character="`"
+            )
+            assert relation.render() == "`sqlmesh`.`foo`.`bar`"
+            return relation
+        return None
+
+    jinja_env = environment()
+    jinja_env.globals["ref"] = _resolve_ref
+
+    assert (
+        transform_str(
+            input,
+            jt.rewrite_dbt_ref_to_migrated_ref(ctx.models, jinja_env, dialect=ctx.default_dialect),
+        )
+        == expected
+    )
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        ("select * from {{ source('upstream', 'foo') }} as t", "select * from upstream.foo as t"),
+        ("select * from {{ source('unknown', 'foo') }} as t", "select * from unknown.foo as t"),
+        (
+            "{% macro foo() %}select * from {{ source('upstream', 'foo') }}{% endmacro %}",
+            "{% macro foo() %}select * from upstream.foo{% endmacro %}",
+        ),
+        # These shouldn't be transformed, as the macro call might rely on some property of the Relation object returned by source()
+        ("select * from {% if some_macro(source('upstream', 'foo')) %}foo{% endif %}", None),
+        ("{{ dbt_utils.union_relations([source('upstream', 'foo')]) }},", None),
"select * from {% if some_macro(source('upstream', 'foo')) %}{{ source('upstream', 'foo') }}{% endif %}", + "select * from {% if some_macro(source('upstream', 'foo')) %}upstream.foo{% endif %}", + ), + ("{{ some_macro(source('upstream', 'foo')) }}", None), + ("{% set results = run_query('select foo from ' ~ source('schema', 'table')) %}", None), + ], +) +def test_resolve_dbt_source_to_model_name(input: str, expected: t.Optional[str]) -> None: + expected = expected or input + + from dbt.adapters.base import BaseRelation + + # note: bigquery dialect chosen because its identifiers have backticks + # but internally SQLMesh stores model fqn with double quotes + config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) + ctx = Context(config=config) + ctx.default_catalog = "sqlmesh" + + assert ctx.default_catalog == "sqlmesh" + assert ctx.default_dialect == "bigquery" + + model = create_external_model(name="upstream.foo", default_catalog=ctx.default_catalog) + ctx.upsert_model(model) + + assert '"sqlmesh"."upstream"."foo"' in ctx.models + + def _resolve_source(schema_name: str, table_name: str) -> t.Optional[BaseRelation]: + if schema_name == "upstream" and table_name == "foo": + relation = BaseRelation.create( + database="sqlmesh", schema="upstream", identifier="foo", quote_character="`" + ) + assert relation.render() == "`sqlmesh`.`upstream`.`foo`" + return relation + return None + + jinja_env = environment() + jinja_env.globals["source"] = _resolve_source + + assert ( + transform_str( + input, + jt.resolve_dbt_source_to_model_name(ctx.models, jinja_env, dialect=ctx.default_dialect), + ) + == expected + ) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + "select * from {{ source('upstream', 'foo') }} as t", + "select * from {{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }} as t", + ), + ( + "select * from {{ source('unknown', 'foo') }} as t", + "select * from {{ source('unknown', 'foo') }} as t", + ), + ( + "{% macro foo() %}select * from {{ source('upstream', 'foo') }}{% endmacro %}", + "{% macro foo() %}select * from {{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }}{% endmacro %}", + ), + ( + "select * from {% if some_macro(source('upstream', 'foo')) %}foo{% endif %}", + "select * from {% if some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}foo{% endif %}", + ), + ( + "{{ dbt_utils.union_relations([source('upstream', 'foo')]) }},", + "{{ dbt_utils.union_relations([__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')]) }},", + ), + ( + "select * from {% if some_macro(source('upstream', 'foo')) %}{{ source('upstream', 'foo') }}{% endif %}", + "select * from {% if some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}{{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }}{% endif %}", + ), + ( + "{{ some_macro(source('upstream', 'foo')) }}", + "{{ some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) }}", + ), + ( + "{% set results = run_query('select foo from ' ~ source('upstream', 'foo')) %}", + "{% set results = run_query('select foo from ' ~ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}", + ), + ], +) +def test_rewrite_dbt_source_to_migrated_source(input: str, expected: t.Optional[str]) -> None: + expected = expected or input + + from dbt.adapters.base import BaseRelation + + # note: bigquery dialect chosen because its 
+    # note: the bigquery dialect is chosen because its identifiers have backticks,
+    # but internally SQLMesh stores model FQNs with double quotes
+    config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery"))
+    ctx = Context(config=config)
+    ctx.default_catalog = "sqlmesh"
+
+    assert ctx.default_catalog == "sqlmesh"
+    assert ctx.default_dialect == "bigquery"
+
+    model = create_external_model(name="upstream.foo", default_catalog=ctx.default_catalog)
+    ctx.upsert_model(model)
+
+    assert '"sqlmesh"."upstream"."foo"' in ctx.models
+
+    def _resolve_source(schema_name: str, table_name: str) -> t.Optional[BaseRelation]:
+        if schema_name == "upstream" and table_name == "foo":
+            relation = BaseRelation.create(
+                database="sqlmesh", schema="upstream", identifier="foo", quote_character="`"
+            )
+            assert relation.render() == "`sqlmesh`.`upstream`.`foo`"
+            return relation
+        return None
+
+    jinja_env = environment()
+    jinja_env.globals["source"] = _resolve_source
+
+    assert (
+        transform_str(
+            input,
+            jt.rewrite_dbt_source_to_migrated_source(
+                ctx.models, jinja_env, dialect=ctx.default_dialect
+            ),
+        )
+        == expected
+    )
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        ("select * from {{ this }}", "select * from foo.bar"),
+        ("{% if foo(this) %}bar{% endif %}", None),
+        ("select * from {{ this.identifier }}", None),
+    ],
+)
+def test_resolve_dbt_this_to_model_name(input: str, expected: t.Optional[str]):
+    expected = expected or input
+    assert transform_str(input, jt.resolve_dbt_this_to_model_name("foo.bar")) == expected
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        # sqlmesh_incremental present, so the is_incremental() block is removed
+        (
+            """
+            select * from foo where
+            {% if is_incremental() %}ds > (select max(ds) from {{ this }}){% endif %}
+            {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %}
+            """,
+            """
+            select * from foo
+            where
+            {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %}
+            """,
+        ),
+        # sqlmesh_incremental is NOT present; is_incremental() blocks are untouched
+        (
+            """
+            select * from foo
+            where
+            {% if is_incremental() %}ds > (select max(ds) from {{ this }}){% endif %}
+            """,
+            """
+            select * from foo
+            where
+            {% if is_incremental() %}ds > (select max(ds) from {{ this }}){% endif %}
+            """,
+        ),
+    ],
+)
+def test_deduplicate_incremental_checks(input: str, expected: str) -> None:
+    assert " ".join(transform_str(input, jt.deduplicate_incremental_checks()).split()) == " ".join(
+        expected.split()
+    )
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        # is_incremental() removed
+        (
+            "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% endif %}",
+            "select * from foo where ds >= (select max(ds) from {{ this }} )",
+        ),
+        # sqlmesh_incremental removed
+        (
+            "select * from foo where {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %}",
+            "select * from foo where ds BETWEEN {{ start_ds }} and {{ end_ds }}",
+        ),
+        # else branch untouched
+        (
+            "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% else %}ds is not null{% endif %}",
+            "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% else %}ds is not null{% endif %}",
+        ),
+    ],
+)
+def test_unpack_incremental_checks(input: str, expected: str) -> None:
+    assert " ".join(transform_str(input, jt.unpack_incremental_checks()).split()) == " ".join(
+        expected.split()
+    )
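The variable rewrite tested next can leave a macro placeholder trapped inside a string literal (note the `'{{ start_date }}'` case, which becomes the quoted literal `'@start_ds'`). The sqlglot-level unwrap_macros_in_string_literals pass, tested at the end of this module, appears to be the companion cleanup that lifts such placeholders back out of the literal; a sketch of that second pass, with both outcomes taken from the test cases below:

```python
from sqlglot import parse_one
import sqlmesh.dbt.converter.jinja_transforms as jt

expr = parse_one("select * from foo where ds between '@start_dt' and '@end_dt'")
print(expr.transform(jt.unwrap_macros_in_string_literals()).sql())
# SELECT * FROM foo WHERE ds BETWEEN @start_dt AND @end_dt

# Quoted strings that merely start with '@' but are not macros are left alone:
print(parse_one("select '@unrelated'").transform(jt.unwrap_macros_in_string_literals()).sql())
# SELECT '@unrelated'
```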
}}", "@start_ds"), + ( + "select id, ds from foo where ds between {{ start_ts }} and {{ end_ts }}", + "select id, ds from foo where ds between @start_ts and @end_ts", + ), + ("select {{ some_macro(start_ts) }}", None), + ("{{ start_date }}", "@start_date"), + ("'{{ start_date }}'", "'@start_ds'"), # date inside string literal should remain a string + ], +) +def test_rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax( + input: str, expected: t.Optional[str] +) -> None: + expected = expected or input + assert ( + transform_str(input, jt.rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax()) + == expected + ) + + +@pytest.mark.parametrize( + "input,expected,package", + [ + ("{{ var('foo') }}", "{{ var('foo') }}", None), + ("{{ var('foo') }}", "{{ var('foo', __dbt_package='test') }}", "test"), + ( + "{{ var('foo', 'default') }}", + "{{ var('foo', 'default', __dbt_package='test') }}", + "test", + ), + ( + "{% if 'col_name' in var('history_columns') %}bar{% endif %}", + "{% if 'col_name' in var('history_columns', __dbt_package='test') %}bar{% endif %}", + "test", + ), + ], +) +def test_append_dbt_package_kwarg_to_var_calls( + input: str, expected: str, package: t.Optional[str] +) -> None: + assert ( + transform_str(input, jt.append_dbt_package_kwarg_to_var_calls(package_name=package)) + == expected + ) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + "select * from foo where ds between '@start_dt' and '@end_dt'", + "SELECT * FROM foo WHERE ds BETWEEN @start_dt AND @end_dt", + ), + ( + "select * from foo where bar <> '@unrelated'", + "SELECT * FROM foo WHERE bar <> '@unrelated'", + ), + ], +) +def test_unwrap_macros_in_string_literals(input: str, expected: str) -> None: + assert parse_one(input).transform(jt.unwrap_macros_in_string_literals()).sql() == expected diff --git a/tests/utils/test_jinja.py b/tests/utils/test_jinja.py index 5eb00aeb3c..3660adaa95 100644 --- a/tests/utils/test_jinja.py +++ b/tests/utils/test_jinja.py @@ -9,6 +9,8 @@ MacroReturnVal, call_name, nodes, + extract_macro_references_and_variables, + extract_dbt_adapter_dispatch_targets, ) @@ -175,6 +177,54 @@ def test_macro_registry_trim(): assert not trimmed_registry_for_package_b.root_macros +def test_macro_registry_trim_keeps_dbt_adapter_dispatch(): + registry = JinjaMacroRegistry() + extractor = MacroExtractor() + + registry.add_macros( + extractor.extract( + """ + {% macro foo(col) %} + {{ adapter.dispatch('foo', 'test_package') }} + {% endmacro %} + + {% macro default__foo(col) %} + foo_{{ col }} + {% endmacro %} + + {% macro unrelated() %}foo{% endmacro %} + """, + dialect="duckdb", + ), + package="test_package", + ) + + assert sorted(list(registry.packages["test_package"].keys())) == [ + "default__foo", + "foo", + "unrelated", + ] + assert sorted(str(r) for r in registry.packages["test_package"]["foo"].depends_on) == [ + "adapter.dispatch", + "test_package.default__foo", + "test_package.duckdb__foo", + ] + + query_str = """ + select * from {{ test_package.foo('bar') }} + """ + + references, _ = extract_macro_references_and_variables(query_str, dbt_target_name="test") + references_list = list(references) + assert len(references_list) == 1 + assert str(references_list[0]) == "test_package.foo" + + trimmed_registry = registry.trim(references) + + # duckdb__foo is missing from this list because it's not actually defined as a macro + assert sorted(list(trimmed_registry.packages["test_package"].keys())) == ["default__foo", "foo"] + + def test_macro_return(): macros = "{% macro test_return() %}{{ 
 def test_macro_return():
     macros = "{% macro test_return() %}{{ macro_return([1, 2, 3]) }}{% endmacro %}"
@@ -302,3 +352,31 @@ def test_dbt_adapter_macro_scope():
     rendered = registry.build_environment().from_string("{{ spark__macro_a() }}").render()
 
     assert rendered.strip() == "macro_a"
+
+
+def test_extract_dbt_adapter_dispatch_targets():
+    assert extract_dbt_adapter_dispatch_targets("""
+    {% macro my_macro(arg1, arg2) -%}
+        {{ return(adapter.dispatch('my_macro')(arg1, arg2)) }}
+    {% endmacro %}
+    """) == [("my_macro", None)]
+
+    assert extract_dbt_adapter_dispatch_targets("""
+    {% macro my_macro(arg1, arg2) -%}
+        {{ return(adapter.dispatch('my_macro', 'foo')(arg1, arg2)) }}
+    {% endmacro %}
+    """) == [("my_macro", "foo")]
+
+    assert extract_dbt_adapter_dispatch_targets("""{{ adapter.dispatch('my_macro') }}""") == [
+        ("my_macro", None)
+    ]
+
+    assert extract_dbt_adapter_dispatch_targets("""
+    {% macro foo() %}
+    {{ adapter.dispatch('my_macro') }}
+    {{ some_other_call() }}
+    {{ return(adapter.dispatch('other_macro', 'other_package')) }}
+    {% endmacro %}
+    """) == [("my_macro", None), ("other_macro", "other_package")]
+
+    assert extract_dbt_adapter_dispatch_targets("no jinja") == []
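The dispatch extraction tested above is what lets JinjaMacroRegistry.trim() keep dialect-specific implementations (default__foo, duckdb__foo) reachable even though adapter.dispatch() only resolves them at render time. A small usage sketch, with the return shape taken from the assertions above:

```python
from sqlmesh.utils.jinja import extract_dbt_adapter_dispatch_targets

source = """
{% macro my_macro(arg1, arg2) -%}
    {{ return(adapter.dispatch('my_macro', 'foo')(arg1, arg2)) }}
{% endmacro %}
"""

# Each hit is a (macro_name, package_or_None) tuple; the registry can then treat
# <dialect>__macro_name variants as implicit dependencies when trimming.
assert extract_dbt_adapter_dispatch_targets(source) == [("my_macro", "foo")]
```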