Skip to content

Commit 3b75e3c

Browse files
committed
PR Feedback 2, various fixes and improvements
1 parent 15a0b81 commit 3b75e3c

5 files changed

Lines changed: 60 additions & 47 deletions

File tree

sqlmesh/core/console.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def show_environment_difference_summary(
219219
self,
220220
context_diff: ContextDiff,
221221
no_diff: bool = True,
222+
environment: t.Optional[str] = None,
222223
) -> None:
223224
"""Displays a summary of differences for the environment."""
224225

@@ -645,6 +646,7 @@ def show_environment_difference_summary(
645646
self,
646647
context_diff: ContextDiff,
647648
no_diff: bool = True,
649+
environment: t.Optional[str] = None,
648650
) -> None:
649651
pass
650652

@@ -1524,18 +1526,21 @@ def show_environment_difference_summary(
15241526
self,
15251527
context_diff: ContextDiff,
15261528
no_diff: bool = True,
1529+
environment: t.Optional[str] = None,
15271530
) -> None:
15281531
"""Shows a summary of the environment differences.
15291532
15301533
Args:
15311534
context_diff: The context diff to use to print the summary
15321535
no_diff: Hide the actual environment statement differences.
1536+
environment: The initial target environment
15331537
"""
15341538
if context_diff.is_new_environment:
1539+
new_environment = environment or context_diff.environment
15351540
msg = (
1536-
f"\n`{context_diff.environment}` environment will be initialized"
1541+
f"\n`{new_environment}` environment will be initialized"
15371542
if not context_diff.create_from_env_exists
1538-
else f"\nNew environment `{context_diff.environment}` will be created from `{context_diff.create_from}`"
1543+
else f"\nNew environment `{new_environment}` will be created from `{context_diff.create_from}`"
15391544
)
15401545
self._print(Tree(f"[bold]{msg}\n"))
15411546
if not context_diff.has_snapshot_changes:
@@ -1786,6 +1791,7 @@ def _prompt_categorize(
17861791
self.show_environment_difference_summary(
17871792
plan.context_diff,
17881793
no_diff=no_diff,
1794+
environment=plan_builder.environment_naming_info.name,
17891795
)
17901796

17911797
if plan.context_diff.has_changes:
@@ -2898,18 +2904,21 @@ def show_environment_difference_summary(
28982904
self,
28992905
context_diff: ContextDiff,
29002906
no_diff: bool = True,
2907+
environment: t.Optional[str] = None,
29012908
) -> None:
29022909
"""Shows a summary of the environment differences.
29032910
29042911
Args:
29052912
context_diff: The context diff to use to print the summary.
29062913
no_diff: Hide the actual environment statements differences.
2914+
environment: The initial target environment
29072915
"""
29082916
if context_diff.is_new_environment:
2917+
new_environment = environment or context_diff.environment
29092918
msg = (
2910-
f"\n**`{context_diff.environment}` environment will be initialized**"
2919+
f"\n**`{new_environment}` environment will be initialized**"
29112920
if not context_diff.create_from_env_exists
2912-
else f"\n**New environment `{context_diff.environment}` will be created from `{context_diff.create_from}`**"
2921+
else f"\n**New environment `{new_environment}` will be created from `{context_diff.create_from}`**"
29132922
)
29142923
self._print(msg)
29152924
if not context_diff.has_snapshot_changes:
@@ -3501,6 +3510,7 @@ def show_environment_difference_summary(
35013510
self,
35023511
context_diff: ContextDiff,
35033512
no_diff: bool = True,
3513+
environment: t.Optional[str] = None,
35043514
) -> None:
35053515
self._write("Environment Difference Summary:")
35063516

sqlmesh/core/context.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,13 +1480,14 @@ def plan_builder(
14801480

14811481
snapshots = self._snapshots(models_override)
14821482
context_diff = self._context_diff(
1483-
environment=environment,
1483+
environment or c.PROD,
14841484
snapshots=snapshots,
14851485
create_from=create_from,
14861486
force_no_diff=restate_models is not None
14871487
or (backfill_models is not None and not backfill_models),
14881488
ensure_finalized_snapshots=self.config.plan.use_finalized_state,
14891489
diff_rendered=diff_rendered,
1490+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
14901491
)
14911492
modified_model_names = {
14921493
*context_diff.modified_snapshots,
@@ -1520,6 +1521,7 @@ def plan_builder(
15201521

15211522
return self.PLAN_BUILDER_TYPE(
15221523
context_diff=context_diff,
1524+
environment=environment or c.PROD,
15231525
start=start,
15241526
end=end,
15251527
execution_time=execution_time,
@@ -1642,6 +1644,7 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> b
16421644
self.console.show_environment_difference_summary(
16431645
context_diff,
16441646
no_diff=not detailed,
1647+
environment=environment,
16451648
)
16461649
if context_diff.has_changes:
16471650
self.console.show_model_difference_summary(
@@ -2628,14 +2631,14 @@ def _context_diff(
26282631
force_no_diff: bool = False,
26292632
ensure_finalized_snapshots: bool = False,
26302633
diff_rendered: bool = False,
2634+
always_compare_against_prod: bool = False,
26312635
) -> ContextDiff:
26322636
environment = Environment.sanitize_name(environment)
2633-
26342637
if force_no_diff:
26352638
return ContextDiff.create_no_diff(environment, self.state_reader)
26362639

26372640
return ContextDiff.create(
2638-
environment=environment or c.PROD,
2641+
environment,
26392642
snapshots=snapshots or self.snapshots,
26402643
create_from=create_from or c.PROD,
26412644
state_reader=self.state_reader,
@@ -2646,7 +2649,7 @@ def _context_diff(
26462649
environment_statements=self._environment_statements,
26472650
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26482651
infer_python_dependencies=self.config.infer_python_dependencies,
2649-
always_compare_against_prod=self.config.plan.always_compare_against_prod,
2652+
always_compare_against_prod=always_compare_against_prod,
26502653
)
26512654

26522655
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ class ContextDiff(PydanticModel):
8989
"""Environment statements."""
9090
diff_rendered: bool = False
9191
"""Whether the diff should compare raw vs rendered models"""
92-
initial_environment: str = ""
93-
"""The initial target environment (e.g 'dev'), if the plan option `always_compare_to_prod` is set"""
9492

9593
@classmethod
9694
def create(
@@ -106,7 +104,6 @@ def create(
106104
environment_statements: t.Optional[t.List[EnvironmentStatements]] = [],
107105
gateway_managed_virtual_layer: bool = False,
108106
infer_python_dependencies: bool = True,
109-
initial_environment: t.Optional[str] = None,
110107
always_compare_against_prod: bool = False,
111108
) -> ContextDiff:
112109
"""Create a ContextDiff object.
@@ -133,33 +130,34 @@ def create(
133130
The ContextDiff object.
134131
"""
135132
initial_environment = environment.lower()
136-
137-
environment = _get_target_environment(
138-
environment, state_reader, always_compare_against_prod
139-
)
140-
141-
env = state_reader.get_environment(environment)
142-
initial_env = (
143-
env
144-
if initial_environment == environment
145-
else state_reader.get_environment(initial_environment)
146-
)
133+
initial_env = state_reader.get_environment(initial_environment)
147134

148135
create_from_env_exists = False
149-
if env is None or env.expired:
150-
env = state_reader.get_environment(create_from.lower())
136+
if initial_env is None or initial_env.expired:
137+
initial_env = state_reader.get_environment(create_from.lower())
151138

152-
if not env and create_from != c.PROD:
139+
if not initial_env and create_from != c.PROD:
153140
get_console().log_warning(
154141
f"The environment name '{create_from}' was passed to the `plan` command's `--create-from` argument, but '{create_from}' does not exist. Initializing new environment '{environment}' from scratch."
155142
)
156143

157144
is_new_environment = True
158-
create_from_env_exists = env is not None
145+
create_from_env_exists = initial_env is not None
159146
previously_promoted_snapshot_ids = set()
160147
else:
161148
is_new_environment = False
162-
previously_promoted_snapshot_ids = {s.snapshot_id for s in env.promoted_snapshots}
149+
previously_promoted_snapshot_ids = {
150+
s.snapshot_id for s in initial_env.promoted_snapshots
151+
}
152+
153+
# Find the proper environment to diff against, this might be different than the "initial" (i.e user provided) environment
154+
# e.g it will default to prod if the plan option `always_compare_against_prod` is set.
155+
environment = _get_diff_environment(environment, state_reader, always_compare_against_prod)
156+
env = (
157+
initial_env
158+
if (initial_environment == environment)
159+
else state_reader.get_environment(environment)
160+
)
163161

164162
environment_snapshot_infos = []
165163
if env:
@@ -237,7 +235,6 @@ def create(
237235

238236
return ContextDiff(
239237
environment=environment,
240-
initial_environment=initial_environment,
241238
is_new_environment=is_new_environment,
242239
is_unfinalized_environment=bool(env and not env.finalized_ts),
243240
normalize_environment_name=is_new_environment or bool(env and env.normalize_name),
@@ -279,9 +276,8 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
279276

280277
snapshots = state_reader.get_snapshots(env.snapshots)
281278

282-
environment = env.name
283279
return ContextDiff(
284-
environment=environment,
280+
environment=env.name,
285281
is_new_environment=False,
286282
is_unfinalized_environment=False,
287283
normalize_environment_name=env.normalize_name,
@@ -300,7 +296,6 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
300296
previous_environment_statements=[],
301297
previous_gateway_managed_virtual_layer=env.gateway_managed,
302298
gateway_managed_virtual_layer=env.gateway_managed,
303-
initial_environment=environment,
304299
)
305300

306301
@property
@@ -499,7 +494,7 @@ def text_diff(self, name: str) -> str:
499494
return ""
500495

501496

502-
def _get_target_environment(
497+
def _get_diff_environment(
503498
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
504499
) -> str:
505500
if always_compare_against_prod:

sqlmesh/core/plan/builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
from functools import cached_property
88

9-
9+
from sqlmesh.core import constants as c
1010
from sqlmesh.core.console import PlanBuilderConsole, get_console
1111
from sqlmesh.core.config import (
1212
AutoCategorizationMode,
@@ -117,6 +117,7 @@ def __init__(
117117
interval_end_per_model: t.Optional[t.Dict[str, int]] = None,
118118
console: t.Optional[PlanBuilderConsole] = None,
119119
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None,
120+
environment: str = c.PROD,
120121
):
121122
self._context_diff = context_diff
122123
self._no_gaps = no_gaps
@@ -159,7 +160,7 @@ def __init__(
159160
self.override_end = end is not None
160161
self.environment_naming_info = EnvironmentNamingInfo.from_environment_catalog_mapping(
161162
environment_catalog_mapping or {},
162-
name=self._context_diff.initial_environment,
163+
name=environment,
163164
suffix_target=environment_suffix_target,
164165
normalize_name=self._context_diff.normalize_environment_name,
165166
gateway_managed=self._context_diff.gateway_managed_virtual_layer,

tests/core/test_integration.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6268,6 +6268,11 @@ def plan_with_output(ctx: Context, environment: str):
62686268

62696269
return output
62706270

6271+
def assert_environments(ctx: Context, input_env: str, promote_env: str, diff_env: str):
6272+
plan_builder = ctx.plan_builder(input_env)
6273+
assert plan_builder.environment_naming_info.name == promote_env
6274+
assert plan_builder.build().context_diff.environment == diff_env
6275+
62716276
models_dir = tmp_path / "models"
62726277

62736278
logger = logging.getLogger("sqlmesh.core.state_sync.db.facade")
@@ -6283,44 +6288,40 @@ def plan_with_output(ctx: Context, environment: str):
62836288
output = plan_with_output(ctx, "dev")
62846289

62856290
assert """`dev` environment will be initialized""" in output.stdout
6291+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="dev")
62866292

62876293
# Case 2: Prod does not exist, so dev is updated
62886294
create_temp_file(
62896295
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 5 AS col"
62906296
)
62916297

6292-
plan = ctx.plan_builder("dev").build()
6293-
6294-
assert plan.context_diff.initial_environment == "dev"
6295-
assert plan.context_diff.environment == "dev"
6296-
62976298
output = plan_with_output(ctx, "dev")
62986299

6300+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="dev")
62996301
assert "Differences from the `dev` environment" in output.stdout
63006302

63016303
# Case 3: Prod is initialized, so plan comparisons moving forward should be against prod
63026304
output = plan_with_output(ctx, "prod")
6303-
63046305
assert "`prod` environment will be initialized" in output.stdout
63056306

6306-
# Case 4: Dev is updated with a breaking change, so plan comparisons moving forward should be against prod
6307+
assert_environments(ctx, input_env="prod", promote_env="prod", diff_env="prod")
6308+
6309+
# Case 4: Dev is updated with a breaking change. Prod exists now so plan comparisons moving forward should be against prod
63076310
create_temp_file(
63086311
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 10 AS col"
63096312
)
63106313
ctx.load()
63116314

63126315
plan = ctx.plan_builder("dev").build()
63136316

6314-
assert plan.context_diff.initial_environment == "dev"
6315-
assert plan.context_diff.environment == "prod"
6317+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="prod")
63166318

63176319
assert (
63186320
next(iter(plan.context_diff.snapshots.values())).change_category
63196321
== SnapshotChangeCategory.BREAKING
63206322
)
63216323

63226324
output = plan_with_output(ctx, "dev")
6323-
63246325
assert "Differences from the `prod` environment" in output.stdout
63256326

63266327
# Case 5: Dev is updated with a metadata change, but comparison against prod shows both the previous and the current changes
@@ -6334,16 +6335,14 @@ def plan_with_output(ctx: Context, environment: str):
63346335

63356336
plan = ctx.plan_builder("dev").build()
63366337

6337-
assert plan.context_diff.initial_environment == "dev"
6338-
assert plan.context_diff.environment == "prod"
6338+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="prod")
63396339

63406340
assert (
63416341
next(iter(plan.context_diff.snapshots.values())).change_category
63426342
== SnapshotChangeCategory.BREAKING
63436343
)
63446344

63456345
output = plan_with_output(ctx, "dev")
6346-
63476346
assert "Differences from the `prod` environment" in output.stdout
63486347

63496348
assert (
@@ -6357,3 +6356,8 @@ def plan_with_output(ctx: Context, environment: str):
63576356
+ 10 AS col"""
63586357
in output.stdout
63596358
)
6359+
6360+
# Case 6: Check that we can still run Context::diff() against any environment
6361+
for environment in ["dev", "prod"]:
6362+
context_diff = ctx._context_diff(environment)
6363+
assert context_diff.environment == environment

0 commit comments

Comments
 (0)