@@ -89,8 +89,6 @@ class ContextDiff(PydanticModel):
8989 """Environment statements."""
9090 diff_rendered : bool = False
9191 """Whether the diff should compare raw vs rendered models"""
92- initial_environment : str = ""
93- """The initial target environment (e.g 'dev'), if the plan option `always_compare_to_prod` is set"""
9492
9593 @classmethod
9694 def create (
@@ -106,7 +104,6 @@ def create(
106104 environment_statements : t .Optional [t .List [EnvironmentStatements ]] = [],
107105 gateway_managed_virtual_layer : bool = False ,
108106 infer_python_dependencies : bool = True ,
109- initial_environment : t .Optional [str ] = None ,
110107 always_compare_against_prod : bool = False ,
111108 ) -> ContextDiff :
112109 """Create a ContextDiff object.
@@ -133,33 +130,34 @@ def create(
133130 The ContextDiff object.
134131 """
135132 initial_environment = environment .lower ()
136-
137- environment = _get_target_environment (
138- environment , state_reader , always_compare_against_prod
139- )
140-
141- env = state_reader .get_environment (environment )
142- initial_env = (
143- env
144- if initial_environment == environment
145- else state_reader .get_environment (initial_environment )
146- )
133+ initial_env = state_reader .get_environment (initial_environment )
147134
148135 create_from_env_exists = False
149- if env is None or env .expired :
150- env = state_reader .get_environment (create_from .lower ())
136+ if initial_env is None or initial_env .expired :
137+ initial_env = state_reader .get_environment (create_from .lower ())
151138
152- if not env and create_from != c .PROD :
139+ if not initial_env and create_from != c .PROD :
153140 get_console ().log_warning (
154141 f"The environment name '{ create_from } ' was passed to the `plan` command's `--create-from` argument, but '{ create_from } ' does not exist. Initializing new environment '{ environment } ' from scratch."
155142 )
156143
157144 is_new_environment = True
158- create_from_env_exists = env is not None
145+ create_from_env_exists = initial_env is not None
159146 previously_promoted_snapshot_ids = set ()
160147 else :
161148 is_new_environment = False
162- previously_promoted_snapshot_ids = {s .snapshot_id for s in env .promoted_snapshots }
149+ previously_promoted_snapshot_ids = {
150+ s .snapshot_id for s in initial_env .promoted_snapshots
151+ }
152+
153+ # Find the proper environment to diff against, this might be different than the "initial" (i.e user provided) environment
154+ # e.g it will default to prod if the plan option `always_compare_against_prod` is set.
155+ environment = _get_diff_environment (environment , state_reader , always_compare_against_prod )
156+ env = (
157+ initial_env
158+ if (initial_environment == environment )
159+ else state_reader .get_environment (environment )
160+ )
163161
164162 environment_snapshot_infos = []
165163 if env :
@@ -237,7 +235,6 @@ def create(
237235
238236 return ContextDiff (
239237 environment = environment ,
240- initial_environment = initial_environment ,
241238 is_new_environment = is_new_environment ,
242239 is_unfinalized_environment = bool (env and not env .finalized_ts ),
243240 normalize_environment_name = is_new_environment or bool (env and env .normalize_name ),
@@ -279,9 +276,8 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
279276
280277 snapshots = state_reader .get_snapshots (env .snapshots )
281278
282- environment = env .name
283279 return ContextDiff (
284- environment = environment ,
280+ environment = env . name ,
285281 is_new_environment = False ,
286282 is_unfinalized_environment = False ,
287283 normalize_environment_name = env .normalize_name ,
@@ -300,7 +296,6 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD
300296 previous_environment_statements = [],
301297 previous_gateway_managed_virtual_layer = env .gateway_managed ,
302298 gateway_managed_virtual_layer = env .gateway_managed ,
303- initial_environment = environment ,
304299 )
305300
306301 @property
@@ -499,7 +494,7 @@ def text_diff(self, name: str) -> str:
499494 return ""
500495
501496
502- def _get_target_environment (
497+ def _get_diff_environment (
503498 environment : str , state_reader : StateReader , always_compare_against_prod : bool = False
504499) -> str :
505500 if always_compare_against_prod :
0 commit comments