Skip to content

Commit 75a357e

Browse files
committed
Move get_target_environment to ContextDiff
1 parent 62d1f7e commit 75a357e

2 files changed

Lines changed: 30 additions & 25 deletions

File tree

sqlmesh/core/context.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2593,25 +2593,6 @@ def _snapshots(
25932593

25942594
return {name: stored_snapshots.get(s.snapshot_id, s) for name, s in snapshots.items()}
25952595

2596-
def _get_target_environment(self, environment: t.Optional[str] = None) -> t.Tuple[str, str]:
2597-
environment = environment or self.config.default_target_environment
2598-
environment = Environment.sanitize_name(environment)
2599-
2600-
initial_environment = environment
2601-
2602-
if self.config.plan.always_compare_against_prod:
2603-
prod = self.state_reader.get_environment(c.PROD)
2604-
if prod:
2605-
logger.warning(
2606-
f"Comparing against production environment instead of {environment}. Note that this may lead to "
2607-
"additional backfills as accumulated changes are still pushed to the target environment."
2608-
)
2609-
environment = c.PROD
2610-
else:
2611-
environment = environment or c.PROD
2612-
2613-
return environment.lower(), initial_environment.lower()
2614-
26152596
def _context_diff(
26162597
self,
26172598
environment: str,
@@ -2621,13 +2602,13 @@ def _context_diff(
26212602
ensure_finalized_snapshots: bool = False,
26222603
diff_rendered: bool = False,
26232604
) -> ContextDiff:
2624-
target_environment, initial_environment = self._get_target_environment(environment)
2605+
environment = Environment.sanitize_name(environment)
26252606

26262607
if force_no_diff:
26272608
return ContextDiff.create_no_diff(environment, self.state_reader)
26282609

26292610
return ContextDiff.create(
2630-
environment=target_environment,
2611+
environment=environment,
26312612
snapshots=snapshots or self.snapshots,
26322613
create_from=create_from or c.PROD,
26332614
state_reader=self.state_reader,
@@ -2638,7 +2619,7 @@ def _context_diff(
26382619
environment_statements=self._environment_statements,
26392620
gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer,
26402621
infer_python_dependencies=self.config.infer_python_dependencies,
2641-
initial_environment=initial_environment,
2622+
always_compare_against_prod=self.config.plan.always_compare_against_prod,
26422623
)
26432624

26442625
def _destroy(self) -> None:

sqlmesh/core/context_diff.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import sys
1616
import typing as t
17+
import logging
18+
1719
from difflib import ndiff, unified_diff
1820
from functools import cached_property
1921
from sqlmesh.core import constants as c
@@ -38,6 +40,8 @@
3840

3941
IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
4042

43+
logger = logging.getLogger(__name__)
44+
4145

4246
class ContextDiff(PydanticModel):
4347
"""ContextDiff is an object representing the difference between two environments.
@@ -106,6 +110,7 @@ def create(
106110
gateway_managed_virtual_layer: bool = False,
107111
infer_python_dependencies: bool = True,
108112
initial_environment: t.Optional[str] = None,
113+
always_compare_against_prod: bool = False,
109114
) -> ContextDiff:
110115
"""Create a ContextDiff object.
111116
@@ -130,10 +135,12 @@ def create(
130135
Returns:
131136
The ContextDiff object.
132137
"""
133-
environment = environment.lower()
134-
env = state_reader.get_environment(environment)
138+
initial_environment = environment
139+
environment = _get_target_environment(
140+
environment, state_reader, always_compare_against_prod
141+
)
135142

136-
initial_environment = initial_environment or environment
143+
env = state_reader.get_environment(environment)
137144
initial_env = (
138145
env
139146
if initial_environment == environment
@@ -492,6 +499,23 @@ def text_diff(self, name: str) -> str:
492499
return ""
493500

494501

502+
def _get_target_environment(
503+
environment: str, state_reader: StateReader, always_compare_against_prod: bool = False
504+
) -> str:
505+
if always_compare_against_prod:
506+
prod = state_reader.get_environment(c.PROD)
507+
if prod:
508+
logger.warning(
509+
f"Comparing against production environment instead of {environment}. Note that this may lead to "
510+
"additional backfills as accumulated changes are still pushed to the target environment."
511+
)
512+
environment = c.PROD
513+
else:
514+
environment = environment or c.PROD
515+
516+
return environment.lower()
517+
518+
495519
def _build_requirements(
496520
provided_requirements: t.Dict[str, str],
497521
excluded_requirements: t.Set[str],

0 commit comments

Comments
 (0)