Skip to content

Commit b7d7668

Browse files
committed
Feat: Add plan option to always compare against prod
1 parent 906ed7a commit b7d7668

5 files changed

Lines changed: 155 additions & 7 deletions

File tree

sqlmesh/core/config/plan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class PlanConfig(BaseConfig):
2020
auto_apply: Whether to automatically apply the new plan after creation.
2121
use_finalized_state: Whether to compare against the latest finalized environment state, or to use
2222
whatever state the target environment is currently in.
23+
always_compare_against_prod: Whether to always compare against production when planning, even if the target environment exists.
2324
"""
2425

2526
forward_only: bool = False
@@ -30,3 +31,4 @@ class PlanConfig(BaseConfig):
3031
no_prompts: bool = True
3132
auto_apply: bool = False
3233
use_finalized_state: bool = False
34+
always_compare_against_prod: bool = False

sqlmesh/core/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ def plan_builder(
14801480

14811481
snapshots = self._snapshots(models_override)
14821482
context_diff = self._context_diff(
1483-
environment or c.PROD,
1483+
environment=environment,
14841484
snapshots=snapshots,
14851485
create_from=create_from,
14861486
force_no_diff=restate_models is not None
@@ -2630,11 +2630,12 @@ def _context_diff(
26302630
diff_rendered: bool = False,
26312631
) -> ContextDiff:
26322632
environment = Environment.sanitize_name(environment)
2633+
26332634
if force_no_diff:
26342635
return ContextDiff.create_no_diff(environment, self.state_reader)
26352636

26362637
return ContextDiff.create(
2637-
environment,
2638+
environment=environment,
26382639
snapshots=snapshots or self.snapshots,
26392640
create_from=create_from or c.PROD,
26402641
state_reader=self.state_reader,
@@ -2645,6 +2646,7 @@ def _context_diff(
26452646
environment_statements=self._environment_statements,
26462647
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26472648
infer_python_dependencies=self.config.infer_python_dependencies,
2649+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
26482650
)
26492651

26502652
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import sys
1616
import typing as t
17+
import logging
18+
1719
from difflib import ndiff, unified_diff
1820
from functools import cached_property
1921
from sqlmesh.core import constants as c
@@ -38,6 +40,8 @@
3840

3941
IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
4042

43+
logger = logging.getLogger(__name__)
44+
4145

4246
class ContextDiff(PydanticModel):
4347
"""ContextDiff is an object representing the difference between two environments.
@@ -88,6 +92,8 @@ class ContextDiff(PydanticModel):
8892
"""Environment statements."""
8993
diff_rendered: bool = False
9094
"""Whether the diff should compare raw vs rendered models"""
95+
initial_environment: str = ""
96+
"""The initial target environment (e.g 'dev'), if the plan option `always_compare_to_prod` is set"""
9197

9298
@classmethod
9399
def create(
@@ -103,6 +109,8 @@ def create(
103109
environment_statements: t.Optional[t.List[EnvironmentStatements]] = [],
104110
gateway_managed_virtual_layer: bool = False,
105111
infer_python_dependencies: bool = True,
112+
initial_environment: t.Optional[str] = None,
113+
always_compare_against_prod: bool = False,
106114
) -> ContextDiff:
107115
"""Create a ContextDiff object.
108116
@@ -127,8 +135,17 @@ def create(
127135
Returns:
128136
The ContextDiff object.
129137
"""
130-
environment = environment.lower()
138+
initial_environment = environment
139+
environment = _get_target_environment(
140+
environment, state_reader, always_compare_against_prod
141+
)
142+
131143
env = state_reader.get_environment(environment)
144+
initial_env = (
145+
env
146+
if initial_environment == environment
147+
else state_reader.get_environment(initial_environment)
148+
)
132149

133150
create_from_env_exists = False
134151
if env is None or env.expired:
@@ -222,6 +239,7 @@ def create(
222239

223240
return ContextDiff(
224241
environment=environment,
242+
initial_environment=initial_environment,
225243
is_new_environment=is_new_environment,
226244
is_unfinalized_environment=bool(env and not env.finalized_ts),
227245
normalize_environment_name=is_new_environment or bool(env and env.normalize_name),
@@ -232,7 +250,9 @@ def create(
232250
modified_snapshots=modified_snapshots,
233251
snapshots=merged_snapshots,
234252
new_snapshots=new_snapshots,
235-
previous_plan_id=env.plan_id if env and not is_new_environment else None,
253+
previous_plan_id=initial_env.plan_id
254+
if initial_env and not is_new_environment
255+
else None,
236256
previously_promoted_snapshot_ids=previously_promoted_snapshot_ids,
237257
previous_finalized_snapshots=env.previous_finalized_snapshots if env else None,
238258
previous_requirements=env.requirements if env else {},
@@ -261,8 +281,9 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
261281

262282
snapshots = state_reader.get_snapshots(env.snapshots)
263283

284+
environment = env.name
264285
return ContextDiff(
265-
environment=env.name,
286+
environment=environment,
266287
is_new_environment=False,
267288
is_unfinalized_environment=False,
268289
normalize_environment_name=env.normalize_name,
@@ -281,6 +302,7 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
281302
previous_environment_statements=[],
282303
previous_gateway_managed_virtual_layer=env.gateway_managed,
283304
gateway_managed_virtual_layer=env.gateway_managed,
305+
initial_environment=environment,
284306
)
285307

286308
@property
@@ -479,6 +501,23 @@ def text_diff(self, name: str) -> str:
479501
return ""
480502

481503

504+
def _get_target_environment(
505+
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
506+
) -> str:
507+
if always_compare_against_prod:
508+
prod = state_reader.get_environment(c.PROD)
509+
if prod:
510+
logger.warning(
511+
f"Comparing against production environment instead of {environment}. Note that this may lead to "
512+
"additional backfills as accumulated changes are still pushed to the target environment."
513+
)
514+
environment = c.PROD
515+
else:
516+
environment = environment or c.PROD
517+
518+
return environment.lower()
519+
520+
482521
def _build_requirements(
483522
provided_requirements: t.Dict[str, str],
484523
excluded_requirements: t.Set[str],

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(
159159
self.override_end = end is not None
160160
self.environment_naming_info = EnvironmentNamingInfo.from_environment_catalog_mapping(
161161
environment_catalog_mapping or {},
162-
name=self._context_diff.environment,
162+
name=self._context_diff.initial_environment,
163163
suffix_target=environment_suffix_target,
164164
normalize_name=self._context_diff.normalize_environment_name,
165165
gateway_managed=self._context_diff.gateway_managed_virtual_layer,

tests/core/test_integration.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import timedelta
77
from unittest import mock
88
from unittest.mock import patch
9-
9+
import logging
1010
import os
1111
import numpy as np # noqa: TID253
1212
import pandas as pd # noqa: TID253
@@ -37,6 +37,7 @@
3737
from sqlmesh.core.console import Console, get_console
3838
from sqlmesh.core.context import Context
3939
from sqlmesh.core.config.categorizer import CategorizerConfig
40+
from sqlmesh.core.config.plan import PlanConfig
4041
from sqlmesh.core.engine_adapter import EngineAdapter
4142
from sqlmesh.core.environment import EnvironmentNamingInfo
4243
from sqlmesh.core.macros import macro
@@ -6252,3 +6253,107 @@ def test_render_path_instead_of_model(tmp_path: Path):
62526253

62536254
# Case 3: Render the model successfully
62546255
assert ctx.render("test_model").sql() == 'SELECT 1 AS "col"'
6256+
6257+
6258+
@use_terminal_console
6259+
def test_plan_always_compare_against_prod(mocker: MockerFixture, tmp_path: Path):
6260+
def plan_with_output(ctx: Context, environment: str):
6261+
with patch.object(logger, "info") as mock_logger:
6262+
with capture_output() as output:
6263+
ctx.load()
6264+
ctx.plan(environment, no_prompts=True, auto_apply=True)
6265+
6266+
# Facade logs info "Promoting environment {environment}"
6267+
assert mock_logger.call_args[0][1] == environment
6268+
6269+
return output
6270+
6271+
models_dir = tmp_path / "models"
6272+
6273+
logger = logging.getLogger("sqlmesh.core.state_sync.db.facade")
6274+
6275+
create_temp_file(
6276+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
6277+
)
6278+
6279+
config = Config(plan=PlanConfig(always_compare_against_prod=True))
6280+
ctx = Context(paths=[tmp_path], config=config)
6281+
6282+
# Case 1: Neither prod nor dev exists, so dev is initialized
6283+
output = plan_with_output(ctx, "dev")
6284+
6285+
assert """`dev` environment will be initialized""" in output.stdout
6286+
6287+
# Case 2: Prod does not exist, so dev is updated
6288+
create_temp_file(
6289+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 5 AS col"
6290+
)
6291+
6292+
plan = ctx.plan_builder("dev").build()
6293+
6294+
assert plan.context_diff.initial_environment == "dev"
6295+
assert plan.context_diff.environment == "dev"
6296+
6297+
output = plan_with_output(ctx, "dev")
6298+
6299+
assert "Differences from the `dev` environment" in output.stdout
6300+
6301+
# Case 3: Prod is initialized, so plan comparisons moving forward should be against prod
6302+
output = plan_with_output(ctx, "prod")
6303+
6304+
assert "`prod` environment will be initialized" in output.stdout
6305+
6306+
# Case 4: Dev is updated with a breaking change, so plan comparisons moving forward should be against prod
6307+
create_temp_file(
6308+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 10 AS col"
6309+
)
6310+
ctx.load()
6311+
6312+
plan = ctx.plan_builder("dev").build()
6313+
6314+
assert plan.context_diff.initial_environment == "dev"
6315+
assert plan.context_diff.environment == "prod"
6316+
6317+
assert (
6318+
next(iter(plan.context_diff.snapshots.values())).change_category
6319+
== SnapshotChangeCategory.BREAKING
6320+
)
6321+
6322+
output = plan_with_output(ctx, "dev")
6323+
6324+
assert "Differences from the `prod` environment" in output.stdout
6325+
6326+
# Case 5: Dev is updated with a metadata change, but comparison against prod shows both the previous and the current changes
6327+
# so it's still classified as a breaking change
6328+
create_temp_file(
6329+
tmp_path,
6330+
models_dir / "a.sql",
6331+
"MODEL (name test.a, kind FULL, owner 'test'); SELECT 10 AS col",
6332+
)
6333+
ctx.load()
6334+
6335+
plan = ctx.plan_builder("dev").build()
6336+
6337+
assert plan.context_diff.initial_environment == "dev"
6338+
assert plan.context_diff.environment == "prod"
6339+
6340+
assert (
6341+
next(iter(plan.context_diff.snapshots.values())).change_category
6342+
== SnapshotChangeCategory.BREAKING
6343+
)
6344+
6345+
output = plan_with_output(ctx, "dev")
6346+
6347+
assert "Differences from the `prod` environment" in output.stdout
6348+
6349+
assert (
6350+
"""MODEL (
6351+
name test.a,
6352+
+ owner test,
6353+
kind FULL
6354+
)
6355+
SELECT
6356+
- 5 AS col
6357+
+ 10 AS col"""
6358+
in output.stdout
6359+
)

0 commit comments

Comments
 (0)