Skip to content

Commit dc285df

Browse files
committed
Feat: Add plan option to always compare against prod
1 parent 2f8e3b7 commit dc285df

5 files changed

Lines changed: 155 additions & 7 deletions

File tree

sqlmesh/core/config/plan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class PlanConfig(BaseConfig):
2020
auto_apply: Whether to automatically apply the new plan after creation.
2121
use_finalized_state: Whether to compare against the latest finalized environment state, or to use
2222
whatever state the target environment is currently in.
23+
always_compare_against_prod: Whether to always compare against production when planning, even if the target environment exists.
2324
"""
2425

2526
forward_only: bool = False
@@ -30,3 +31,4 @@ class PlanConfig(BaseConfig):
3031
no_prompts: bool = True
3132
auto_apply: bool = False
3233
use_finalized_state: bool = False
34+
always_compare_against_prod: bool = False

sqlmesh/core/context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,7 +1477,7 @@ def plan_builder(
14771477

14781478
snapshots = self._snapshots(models_override)
14791479
context_diff = self._context_diff(
1480-
environment or c.PROD,
1480+
environment=environment,
14811481
snapshots=snapshots,
14821482
create_from=create_from,
14831483
force_no_diff=restate_models is not None
@@ -2609,11 +2609,12 @@ def _context_diff(
26092609
diff_rendered: bool = False,
26102610
) -> ContextDiff:
26112611
environment = Environment.sanitize_name(environment)
2612+
26122613
if force_no_diff:
26132614
return ContextDiff.create_no_diff(environment, self.state_reader)
26142615

26152616
return ContextDiff.create(
2616-
environment,
2617+
environment=environment,
26172618
snapshots=snapshots or self.snapshots,
26182619
create_from=create_from or c.PROD,
26192620
state_reader=self.state_reader,
@@ -2624,6 +2625,7 @@ def _context_diff(
26242625
environment_statements=self._environment_statements,
26252626
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26262627
infer_python_dependencies=self.config.infer_python_dependencies,
2628+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
26272629
)
26282630

26292631
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import sys
1616
import typing as t
17+
import logging
18+
1719
from difflib import ndiff, unified_diff
1820
from functools import cached_property
1921
from sqlmesh.core import constants as c
@@ -38,6 +40,8 @@
3840

3941
IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
4042

43+
logger = logging.getLogger(__name__)
44+
4145

4246
class ContextDiff(PydanticModel):
4347
"""ContextDiff is an object representing the difference between two environments.
@@ -88,6 +92,8 @@ class ContextDiff(PydanticModel):
8892
"""Environment statements."""
8993
diff_rendered: bool = False
9094
"""Whether the diff should compare raw vs rendered models"""
95+
initial_environment: str = ""
96+
"""The initial target environment (e.g 'dev'), if the plan option `always_compare_to_prod` is set"""
9197

9298
@classmethod
9399
def create(
@@ -103,6 +109,8 @@ def create(
103109
environment_statements: t.Optional[t.List[EnvironmentStatements]] = [],
104110
gateway_managed_virtual_layer: bool = False,
105111
infer_python_dependencies: bool = True,
112+
initial_environment: t.Optional[str] = None,
113+
always_compare_against_prod: bool = False,
106114
) -> ContextDiff:
107115
"""Create a ContextDiff object.
108116
@@ -127,8 +135,17 @@ def create(
127135
Returns:
128136
The ContextDiff object.
129137
"""
130-
environment = environment.lower()
138+
initial_environment = environment
139+
environment = _get_target_environment(
140+
environment, state_reader, always_compare_against_prod
141+
)
142+
131143
env = state_reader.get_environment(environment)
144+
initial_env = (
145+
env
146+
if initial_environment == environment
147+
else state_reader.get_environment(initial_environment)
148+
)
132149

133150
create_from_env_exists = False
134151
if env is None or env.expired:
@@ -222,6 +239,7 @@ def create(
222239

223240
return ContextDiff(
224241
environment=environment,
242+
initial_environment=initial_environment,
225243
is_new_environment=is_new_environment,
226244
is_unfinalized_environment=bool(env and not env.finalized_ts),
227245
normalize_environment_name=is_new_environment or bool(env and env.normalize_name),
@@ -232,7 +250,9 @@ def create(
232250
modified_snapshots=modified_snapshots,
233251
snapshots=merged_snapshots,
234252
new_snapshots=new_snapshots,
235-
previous_plan_id=env.plan_id if env and not is_new_environment else None,
253+
previous_plan_id=initial_env.plan_id
254+
if initial_env and not is_new_environment
255+
else None,
236256
previously_promoted_snapshot_ids=previously_promoted_snapshot_ids,
237257
previous_finalized_snapshots=env.previous_finalized_snapshots if env else None,
238258
previous_requirements=env.requirements if env else {},
@@ -261,8 +281,9 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
261281

262282
snapshots = state_reader.get_snapshots(env.snapshots)
263283

284+
environment = env.name
264285
return ContextDiff(
265-
environment=env.name,
286+
environment=environment,
266287
is_new_environment=False,
267288
is_unfinalized_environment=False,
268289
normalize_environment_name=env.normalize_name,
@@ -281,6 +302,7 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
281302
previous_environment_statements=[],
282303
previous_gateway_managed_virtual_layer=env.gateway_managed,
283304
gateway_managed_virtual_layer=env.gateway_managed,
305+
initial_environment=environment,
284306
)
285307

286308
@property
@@ -479,6 +501,23 @@ def text_diff(self, name: str) -> str:
479501
return ""
480502

481503

504+
def _get_target_environment(
505+
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
506+
) -> str:
507+
if always_compare_against_prod:
508+
prod = state_reader.get_environment(c.PROD)
509+
if prod:
510+
logger.warning(
511+
f"Comparing against production environment instead of {environment}. Note that this may lead to "
512+
"additional backfills as accumulated changes are still pushed to the target environment."
513+
)
514+
environment = c.PROD
515+
else:
516+
environment = environment or c.PROD
517+
518+
return environment.lower()
519+
520+
482521
def _build_requirements(
483522
provided_requirements: t.Dict[str, str],
484523
excluded_requirements: t.Set[str],

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __init__(
156156
self.override_end = end is not None
157157
self.environment_naming_info = EnvironmentNamingInfo.from_environment_catalog_mapping(
158158
environment_catalog_mapping or {},
159-
name=self._context_diff.environment,
159+
name=self._context_diff.initial_environment,
160160
suffix_target=environment_suffix_target,
161161
normalize_name=self._context_diff.normalize_environment_name,
162162
gateway_managed=self._context_diff.gateway_managed_virtual_layer,

tests/core/test_integration.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from datetime import timedelta
77
from unittest import mock
88
from unittest.mock import patch
9-
9+
import logging
1010
import os
1111
import numpy as np
1212
import pandas as pd
@@ -36,6 +36,7 @@
3636
from sqlmesh.core.console import Console, get_console
3737
from sqlmesh.core.context import Context
3838
from sqlmesh.core.config.categorizer import CategorizerConfig
39+
from sqlmesh.core.config.plan import PlanConfig
3940
from sqlmesh.core.engine_adapter import EngineAdapter
4041
from sqlmesh.core.environment import EnvironmentNamingInfo
4142
from sqlmesh.core.macros import macro
@@ -6208,3 +6209,107 @@ def test_render_path_instead_of_model(tmp_path: Path):
62086209

62096210
# Case 3: Render the model successfully
62106211
assert ctx.render("test_model").sql() == 'SELECT 1 AS "col"'
6212+
6213+
6214+
@use_terminal_console
6215+
def test_plan_always_compare_against_prod(mocker: MockerFixture, tmp_path: Path):
6216+
def plan_with_output(ctx: Context, environment: str):
6217+
with patch.object(logger, "info") as mock_logger:
6218+
with capture_output() as output:
6219+
ctx.load()
6220+
ctx.plan(environment, no_prompts=True, auto_apply=True)
6221+
6222+
# Facade logs info "Promoting environment {environment}"
6223+
assert mock_logger.call_args[0][1] == environment
6224+
6225+
return output
6226+
6227+
models_dir = tmp_path / "models"
6228+
6229+
logger = logging.getLogger("sqlmesh.core.state_sync.db.facade")
6230+
6231+
create_temp_file(
6232+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
6233+
)
6234+
6235+
config = Config(plan=PlanConfig(always_compare_against_prod=True))
6236+
ctx = Context(paths=[tmp_path], config=config)
6237+
6238+
# Case 1: Neither prod nor dev exists, so dev is initialized
6239+
output = plan_with_output(ctx, "dev")
6240+
6241+
assert """`dev` environment will be initialized""" in output.stdout
6242+
6243+
# Case 2: Prod does not exist, so dev is updated
6244+
create_temp_file(
6245+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 5 AS col"
6246+
)
6247+
6248+
plan = ctx.plan_builder("dev").build()
6249+
6250+
assert plan.context_diff.initial_environment == "dev"
6251+
assert plan.context_diff.environment == "dev"
6252+
6253+
output = plan_with_output(ctx, "dev")
6254+
6255+
assert "Differences from the `dev` environment" in output.stdout
6256+
6257+
# Case 3: Prod is initialized, so plan comparisons moving forward should be against prod
6258+
output = plan_with_output(ctx, "prod")
6259+
6260+
assert "`prod` environment will be initialized" in output.stdout
6261+
6262+
# Case 4: Dev is updated with a breaking change, so plan comparisons moving forward should be against prod
6263+
create_temp_file(
6264+
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 10 AS col"
6265+
)
6266+
ctx.load()
6267+
6268+
plan = ctx.plan_builder("dev").build()
6269+
6270+
assert plan.context_diff.initial_environment == "dev"
6271+
assert plan.context_diff.environment == "prod"
6272+
6273+
assert (
6274+
next(iter(plan.context_diff.snapshots.values())).change_category
6275+
== SnapshotChangeCategory.BREAKING
6276+
)
6277+
6278+
output = plan_with_output(ctx, "dev")
6279+
6280+
assert "Differences from the `prod` environment" in output.stdout
6281+
6282+
# Case 5: Dev is updated with a metadata change, but comparison against prod shows both the previous and the current changes
6283+
# so it's still classified as a breaking change
6284+
create_temp_file(
6285+
tmp_path,
6286+
models_dir / "a.sql",
6287+
"MODEL (name test.a, kind FULL, owner 'test'); SELECT 10 AS col",
6288+
)
6289+
ctx.load()
6290+
6291+
plan = ctx.plan_builder("dev").build()
6292+
6293+
assert plan.context_diff.initial_environment == "dev"
6294+
assert plan.context_diff.environment == "prod"
6295+
6296+
assert (
6297+
next(iter(plan.context_diff.snapshots.values())).change_category
6298+
== SnapshotChangeCategory.BREAKING
6299+
)
6300+
6301+
output = plan_with_output(ctx, "dev")
6302+
6303+
assert "Differences from the `prod` environment" in output.stdout
6304+
6305+
assert (
6306+
"""MODEL (
6307+
name test.a,
6308+
+ owner test,
6309+
kind FULL
6310+
)
6311+
SELECT
6312+
- 5 AS col
6313+
+ 10 AS col"""
6314+
in output.stdout
6315+
)

0 commit comments

Comments
 (0)