Skip to content

Commit 9475cde

Browse files
committed
PR Feedback 2, various fixes and improvements
1 parent 6e08b44 commit 9475cde

5 files changed

Lines changed: 60 additions & 47 deletions

File tree

sqlmesh/core/console.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def show_environment_difference_summary(
221221
self,
222222
context_diff: ContextDiff,
223223
no_diff: bool = True,
224+
environment: t.Optional[str] = None,
224225
) -> None:
225226
"""Displays a summary of differences for the environment."""
226227

@@ -647,6 +648,7 @@ def show_environment_difference_summary(
647648
self,
648649
context_diff: ContextDiff,
649650
no_diff: bool = True,
651+
environment: t.Optional[str] = None,
650652
) -> None:
651653
pass
652654

@@ -1526,18 +1528,21 @@ def show_environment_difference_summary(
15261528
self,
15271529
context_diff: ContextDiff,
15281530
no_diff: bool = True,
1531+
environment: t.Optional[str] = None,
15291532
) -> None:
15301533
"""Shows a summary of the environment differences.
15311534
15321535
Args:
15331536
context_diff: The context diff to use to print the summary
15341537
no_diff: Hide the actual environment statement differences.
1538+
environment: The initial target environment
15351539
"""
15361540
if context_diff.is_new_environment:
1541+
new_environment = environment or context_diff.environment
15371542
msg = (
1538-
f"\n`{context_diff.environment}` environment will be initialized"
1543+
f"\n`{new_environment}` environment will be initialized"
15391544
if not context_diff.create_from_env_exists
1540-
else f"\nNew environment `{context_diff.environment}` will be created from `{context_diff.create_from}`"
1545+
else f"\nNew environment `{new_environment}` will be created from `{context_diff.create_from}`"
15411546
)
15421547
self._print(Tree(f"[bold]{msg}\n"))
15431548
if not context_diff.has_snapshot_changes:
@@ -1788,6 +1793,7 @@ def _prompt_categorize(
17881793
self.show_environment_difference_summary(
17891794
plan.context_diff,
17901795
no_diff=no_diff,
1796+
environment=plan_builder.environment_naming_info.name,
17911797
)
17921798

17931799
if plan.context_diff.has_changes:
@@ -2898,18 +2904,21 @@ def show_environment_difference_summary(
28982904
self,
28992905
context_diff: ContextDiff,
29002906
no_diff: bool = True,
2907+
environment: t.Optional[str] = None,
29012908
) -> None:
29022909
"""Shows a summary of the environment differences.
29032910
29042911
Args:
29052912
context_diff: The context diff to use to print the summary.
29062913
no_diff: Hide the actual environment statements differences.
2914+
environment: The initial target environment
29072915
"""
29082916
if context_diff.is_new_environment:
2917+
new_environment = environment or context_diff.environment
29092918
msg = (
2910-
f"\n**`{context_diff.environment}` environment will be initialized**"
2919+
f"\n**`{new_environment}` environment will be initialized**"
29112920
if not context_diff.create_from_env_exists
2912-
else f"\n**New environment `{context_diff.environment}` will be created from `{context_diff.create_from}`**"
2921+
else f"\n**New environment `{new_environment}` will be created from `{context_diff.create_from}`**"
29132922
)
29142923
self._print(msg)
29152924
if not context_diff.has_snapshot_changes:
@@ -3495,6 +3504,7 @@ def show_environment_difference_summary(
34953504
self,
34963505
context_diff: ContextDiff,
34973506
no_diff: bool = True,
3507+
environment: t.Optional[str] = None,
34983508
) -> None:
34993509
self._write("Environment Difference Summary:")
35003510

sqlmesh/core/context.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,13 +1477,14 @@ def plan_builder(
14771477

14781478
snapshots = self._snapshots(models_override)
14791479
context_diff = self._context_diff(
1480-
environment=environment,
1480+
environment or c.PROD,
14811481
snapshots=snapshots,
14821482
create_from=create_from,
14831483
force_no_diff=restate_models is not None
14841484
or (backfill_models is not None and not backfill_models),
14851485
ensure_finalized_snapshots=self.config.plan.use_finalized_state,
14861486
diff_rendered=diff_rendered,
1487+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
14871488
)
14881489
modified_model_names = {
14891490
*context_diff.modified_snapshots,
@@ -1517,6 +1518,7 @@ def plan_builder(
15171518

15181519
return self.PLAN_BUILDER_TYPE(
15191520
context_diff=context_diff,
1521+
environment=environment or c.PROD,
15201522
start=start,
15211523
end=end,
15221524
execution_time=execution_time,
@@ -1628,6 +1630,7 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> b
16281630
self.console.show_environment_difference_summary(
16291631
context_diff,
16301632
no_diff=not detailed,
1633+
environment=environment,
16311634
)
16321635
if context_diff.has_changes:
16331636
self.console.show_model_difference_summary(
@@ -2607,14 +2610,14 @@ def _context_diff(
26072610
force_no_diff: bool = False,
26082611
ensure_finalized_snapshots: bool = False,
26092612
diff_rendered: bool = False,
2613+
always_compare_against_prod: bool = False,
26102614
) -> ContextDiff:
26112615
environment = Environment.sanitize_name(environment)
2612-
26132616
if force_no_diff:
26142617
return ContextDiff.create_no_diff(environment, self.state_reader)
26152618

26162619
return ContextDiff.create(
2617-
environment=environment or c.PROD,
2620+
environment,
26182621
snapshots=snapshots or self.snapshots,
26192622
create_from=create_from or c.PROD,
26202623
state_reader=self.state_reader,
@@ -2625,7 +2628,7 @@ def _context_diff(
26252628
environment_statements=self._environment_statements,
26262629
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26272630
infer_python_dependencies=self.config.infer_python_dependencies,
2628-
always_compare_against_prod=self.config.plan.always_compare_against_prod,
2631+
always_compare_against_prod=always_compare_against_prod,
26292632
)
26302633

26312634
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ class ContextDiff(PydanticModel):
8989
"""Environment statements."""
9090
diff_rendered: bool = False
9191
"""Whether the diff should compare raw vs rendered models"""
92-
initial_environment: str = ""
93-
"""The initial target environment (e.g 'dev'), if the plan option `always_compare_to_prod` is set"""
9492

9593
@classmethod
9694
def create(
@@ -106,7 +104,6 @@ def create(
106104
environment_statements: t.Optional[t.List[EnvironmentStatements]] = [],
107105
gateway_managed_virtual_layer: bool = False,
108106
infer_python_dependencies: bool = True,
109-
initial_environment: t.Optional[str] = None,
110107
always_compare_against_prod: bool = False,
111108
) -> ContextDiff:
112109
"""Create a ContextDiff object.
@@ -133,33 +130,34 @@ def create(
133130
The ContextDiff object.
134131
"""
135132
initial_environment = environment.lower()
136-
137-
environment = _get_target_environment(
138-
environment, state_reader, always_compare_against_prod
139-
)
140-
141-
env = state_reader.get_environment(environment)
142-
initial_env = (
143-
env
144-
if initial_environment == environment
145-
else state_reader.get_environment(initial_environment)
146-
)
133+
initial_env = state_reader.get_environment(initial_environment)
147134

148135
create_from_env_exists = False
149-
if env is None or env.expired:
150-
env = state_reader.get_environment(create_from.lower())
136+
if initial_env is None or initial_env.expired:
137+
initial_env = state_reader.get_environment(create_from.lower())
151138

152-
if not env and create_from != c.PROD:
139+
if not initial_env and create_from != c.PROD:
153140
get_console().log_warning(
154141
f"The environment name '{create_from}' was passed to the `plan` command's `--create-from` argument, but '{create_from}' does not exist. Initializing new environment '{environment}' from scratch."
155142
)
156143

157144
is_new_environment = True
158-
create_from_env_exists = env is not None
145+
create_from_env_exists = initial_env is not None
159146
previously_promoted_snapshot_ids = set()
160147
else:
161148
is_new_environment = False
162-
previously_promoted_snapshot_ids = {s.snapshot_id for s in env.promoted_snapshots}
149+
previously_promoted_snapshot_ids = {
150+
s.snapshot_id for s in initial_env.promoted_snapshots
151+
}
152+
153+
# Find the proper environment to diff against, this might be different than the "initial" (i.e user provided) environment
154+
# e.g it will default to prod if the plan option `always_compare_against_prod` is set.
155+
environment = _get_diff_environment(environment, state_reader, always_compare_against_prod)
156+
env = (
157+
initial_env
158+
if (initial_environment == environment)
159+
else state_reader.get_environment(environment)
160+
)
163161

164162
environment_snapshot_infos = []
165163
if env:
@@ -237,7 +235,6 @@ def create(
237235

238236
return ContextDiff(
239237
environment=environment,
240-
initial_environment=initial_environment,
241238
is_new_environment=is_new_environment,
242239
is_unfinalized_environment=bool(env and not env.finalized_ts),
243240
normalize_environment_name=is_new_environment or bool(env and env.normalize_name),
@@ -279,9 +276,8 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
279276

280277
snapshots = state_reader.get_snapshots(env.snapshots)
281278

282-
environment = env.name
283279
return ContextDiff(
284-
environment=environment,
280+
environment=env.name,
285281
is_new_environment=False,
286282
is_unfinalized_environment=False,
287283
normalize_environment_name=env.normalize_name,
@@ -300,7 +296,6 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
300296
previous_environment_statements=[],
301297
previous_gateway_managed_virtual_layer=env.gateway_managed,
302298
gateway_managed_virtual_layer=env.gateway_managed,
303-
initial_environment=environment,
304299
)
305300

306301
@property
@@ -499,7 +494,7 @@ def text_diff(self, name: str) -> str:
499494
return ""
500495

501496

502-
def _get_target_environment(
497+
def _get_diff_environment(
503498
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
504499
) -> str:
505500
if always_compare_against_prod:

sqlmesh/core/plan/builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
from functools import cached_property
88

9-
9+
from sqlmesh.core import constants as c
1010
from sqlmesh.core.console import PlanBuilderConsole, get_console
1111
from sqlmesh.core.config import (
1212
AutoCategorizationMode,
@@ -115,6 +115,7 @@ def __init__(
115115
interval_end_per_model: t.Optional[t.Dict[str, int]] = None,
116116
console: t.Optional[PlanBuilderConsole] = None,
117117
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None,
118+
environment: str = c.PROD,
118119
):
119120
self._context_diff = context_diff
120121
self._no_gaps = no_gaps
@@ -156,7 +157,7 @@ def __init__(
156157
self.override_end = end is not None
157158
self.environment_naming_info = EnvironmentNamingInfo.from_environment_catalog_mapping(
158159
environment_catalog_mapping or {},
159-
name=self._context_diff.initial_environment,
160+
name=environment,
160161
suffix_target=environment_suffix_target,
161162
normalize_name=self._context_diff.normalize_environment_name,
162163
gateway_managed=self._context_diff.gateway_managed_virtual_layer,

tests/core/test_integration.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6224,6 +6224,11 @@ def plan_with_output(ctx: Context, environment: str):
62246224

62256225
return output
62266226

6227+
def assert_environments(ctx: Context, input_env: str, promote_env: str, diff_env: str):
6228+
plan_builder = ctx.plan_builder(input_env)
6229+
assert plan_builder.environment_naming_info.name == promote_env
6230+
assert plan_builder.build().context_diff.environment == diff_env
6231+
62276232
models_dir = tmp_path / "models"
62286233

62296234
logger = logging.getLogger("sqlmesh.core.state_sync.db.facade")
@@ -6239,44 +6244,40 @@ def plan_with_output(ctx: Context, environment: str):
62396244
output = plan_with_output(ctx, "dev")
62406245

62416246
assert """`dev` environment will be initialized""" in output.stdout
6247+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="dev")
62426248

62436249
# Case 2: Prod does not exist, so dev is updated
62446250
create_temp_file(
62456251
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 5 AS col"
62466252
)
62476253

6248-
plan = ctx.plan_builder("dev").build()
6249-
6250-
assert plan.context_diff.initial_environment == "dev"
6251-
assert plan.context_diff.environment == "dev"
6252-
62536254
output = plan_with_output(ctx, "dev")
62546255

6256+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="dev")
62556257
assert "Differences from the `dev` environment" in output.stdout
62566258

62576259
# Case 3: Prod is initialized, so plan comparisons moving forward should be against prod
62586260
output = plan_with_output(ctx, "prod")
6259-
62606261
assert "`prod` environment will be initialized" in output.stdout
62616262

6262-
# Case 4: Dev is updated with a breaking change, so plan comparisons moving forward should be against prod
6263+
assert_environments(ctx, input_env="prod", promote_env="prod", diff_env="prod")
6264+
6265+
# Case 4: Dev is updated with a breaking change. Prod exists now so plan comparisons moving forward should be against prod
62636266
create_temp_file(
62646267
tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 10 AS col"
62656268
)
62666269
ctx.load()
62676270

62686271
plan = ctx.plan_builder("dev").build()
62696272

6270-
assert plan.context_diff.initial_environment == "dev"
6271-
assert plan.context_diff.environment == "prod"
6273+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="prod")
62726274

62736275
assert (
62746276
next(iter(plan.context_diff.snapshots.values())).change_category
62756277
== SnapshotChangeCategory.BREAKING
62766278
)
62776279

62786280
output = plan_with_output(ctx, "dev")
6279-
62806281
assert "Differences from the `prod` environment" in output.stdout
62816282

62826283
# Case 5: Dev is updated with a metadata change, but comparison against prod shows both the previous and the current changes
@@ -6290,16 +6291,14 @@ def plan_with_output(ctx: Context, environment: str):
62906291

62916292
plan = ctx.plan_builder("dev").build()
62926293

6293-
assert plan.context_diff.initial_environment == "dev"
6294-
assert plan.context_diff.environment == "prod"
6294+
assert_environments(ctx, input_env="dev", promote_env="dev", diff_env="prod")
62956295

62966296
assert (
62976297
next(iter(plan.context_diff.snapshots.values())).change_category
62986298
== SnapshotChangeCategory.BREAKING
62996299
)
63006300

63016301
output = plan_with_output(ctx, "dev")
6302-
63036302
assert "Differences from the `prod` environment" in output.stdout
63046303

63056304
assert (
@@ -6313,3 +6312,8 @@ def plan_with_output(ctx: Context, environment: str):
63136312
+ 10 AS col"""
63146313
in output.stdout
63156314
)
6315+
6316+
# Case 6: Check that we can still run Context::diff() against any environment
6317+
for environment in ["dev", "prod"]:
6318+
context_diff = ctx._context_diff(environment)
6319+
assert context_diff.environment == environment

0 commit comments

Comments
 (0)