diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 7d17b3b863..1899470617 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2698,16 +2698,22 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None: def _cleanup_environments(self, current_ts: t.Optional[int] = None) -> None: current_ts = current_ts or now_timestamp() - expired_environments = self.state_sync.get_expired_environments(current_ts=current_ts) - - cleanup_expired_views( - default_adapter=self.engine_adapter, - engine_adapters=self.engine_adapters, - environments=expired_environments, - warn_on_delete_failure=self.config.janitor.warn_on_delete_failure, - console=self.console, + expired_environments_summaries = self.state_sync.get_expired_environments( + current_ts=current_ts ) + for expired_env_summary in expired_environments_summaries: + expired_env = self.state_reader.get_environment(expired_env_summary.name) + + if expired_env: + cleanup_expired_views( + default_adapter=self.engine_adapter, + engine_adapters=self.engine_adapters, + environments=[expired_env], + warn_on_delete_failure=self.config.janitor.warn_on_delete_failure, + console=self.console, + ) + self.state_sync.delete_expired_environments(current_ts=current_ts) def _try_connection(self, connection_name: str, validator: t.Callable[[], None]) -> None: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4f46ccf9b8..4a4e31854f 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -304,12 +304,12 @@ def get_expired_snapshots( """ @abc.abstractmethod - def get_expired_environments(self, current_ts: int) -> t.List[Environment]: + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: """Returns the expired environments. Expired environments are environments that have exceeded their time-to-live value. Returns: - The list of environments to remove, the filter to remove environments. + The list of environment summaries to remove. """ @@ -418,7 +418,7 @@ def finalize(self, environment: Environment) -> None: @abc.abstractmethod def delete_expired_environments( self, current_ts: t.Optional[int] = None - ) -> t.List[Environment]: + ) -> t.List[EnvironmentSummary]: """Removes expired environments. Expired environments are environments that have exceeded their time-to-live value. diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index dcf915a91c..b06d6160cc 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -165,27 +165,24 @@ def finalize(self, environment: Environment) -> None: where=environment_filter, ) - def get_expired_environments(self, current_ts: int) -> t.List[Environment]: + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: """Returns the expired environments. Expired environments are environments that have exceeded their time-to-live value. Returns: - The list of environments to remove, the filter to remove environments. + The list of environment summaries to remove. """ - rows = fetchall( - self.engine_adapter, - self._environments_query( - where=self._create_expiration_filter_expr(current_ts), - lock_for_update=True, - ), - ) - expired_environments = [self._environment_from_row(r) for r in rows] - return expired_environments + environment_summaries = self.get_environments_summary() + return [ + env_summary + for env_summary in environment_summaries + if env_summary.expiration_ts is not None and env_summary.expiration_ts <= current_ts + ] def delete_expired_environments( self, current_ts: t.Optional[int] = None - ) -> t.List[Environment]: + ) -> t.List[EnvironmentSummary]: """Deletes expired environments. Returns: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 7e48418317..2a27c5fd92 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -274,7 +274,7 @@ def get_expired_snapshots( self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl ) - def get_expired_environments(self, current_ts: int) -> t.List[Environment]: + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: return self.environment_state.get_expired_environments(current_ts=current_ts) @transactional() @@ -294,7 +294,7 @@ def delete_expired_snapshots( @transactional() def delete_expired_environments( self, current_ts: t.Optional[int] = None - ) -> t.List[Environment]: + ) -> t.List[EnvironmentSummary]: current_ts = current_ts or now_timestamp() return self.environment_state.delete_expired_environments(current_ts=current_ts) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 173264f9b5..dd68b5c515 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -1115,7 +1115,7 @@ def test_delete_expired_environments(state_sync: EngineAdapterStateSync, make_sn assert state_sync.get_environment_statements(env_a.name) == environment_statements deleted_environments = state_sync.delete_expired_environments() - assert deleted_environments == [env_a] + assert deleted_environments == [env_a.summary] assert state_sync.get_environment(env_a.name) is None assert state_sync.get_environment(env_b.name) == env_b diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 3d02d32e7e..276dd38afc 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -859,9 +859,9 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: adapter_mock.dialect = "duckdb" state_sync_mock = mocker.MagicMock() - state_sync_mock.get_expired_environments.return_value = [ + environments = [ Environment( - name="test_environment", + name="test_environment1", suffix_target=EnvironmentSuffixTarget.TABLE, snapshots=[x.table_info for x in sushi_context.snapshots.values()], start_at="2022-01-01", @@ -870,7 +870,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: previous_plan_id="test_plan_id", ), Environment( - name="test_environment", + name="test_environment2", suffix_target=EnvironmentSuffixTarget.SCHEMA, snapshots=[x.table_info for x in sushi_context.snapshots.values()], start_at="2022-01-01", @@ -880,6 +880,11 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: ), ] + state_sync_mock.get_expired_environments.return_value = [env.summary for env in environments] + state_sync_mock.get_environment = lambda name: next( + env for env in environments if env.name == name + ) + sushi_context._engine_adapter = adapter_mock sushi_context.engine_adapters = {sushi_context.config.default_gateway: adapter_mock} sushi_context._state_sync = state_sync_mock @@ -891,7 +896,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: adapter_mock.drop_schema.assert_has_calls( [ call( - schema_("sushi__test_environment", "memory"), + schema_("sushi__test_environment2", "memory"), cascade=True, ignore_if_not_exists=True, ), @@ -903,7 +908,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: adapter_mock.drop_view.assert_has_calls( [ call( - "memory.sushi.waiter_as_customer_by_day__test_environment", + "memory.sushi.waiter_as_customer_by_day__test_environment1", ignore_if_not_exists=True, ), ]