Skip to content
5 changes: 5 additions & 0 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,11 @@ def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any
is_flag=True,
help="Wait for the environment to be deleted before returning. If not specified, the environment will be deleted asynchronously by the janitor process. This option requires a connection to the data warehouse.",
)
@click.option(
"--cleanup-snapshots",
is_flag=True,
help="After invalidating, immediately delete physical snapshot tables that are exclusively owned by this environment (not referenced by any other environment). Cleanup runs synchronously regardless of --sync.",
)
@click.pass_context
@error_handler
@cli_analytics
Expand Down
41 changes: 39 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
Snapshot,
SnapshotEvaluator,
SnapshotFingerprint,
SnapshotId,
missing_intervals,
to_table_mapping,
)
Expand All @@ -108,7 +109,11 @@
StateReader,
StateSync,
)
from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots
from sqlmesh.core.janitor import (
cleanup_expired_views,
delete_expired_snapshots,
delete_snapshots_for_environment,
)
from sqlmesh.core.table_diff import TableDiff
from sqlmesh.core.test import (
ModelTextTestResult,
Expand Down Expand Up @@ -1847,18 +1852,50 @@ def apply(
)

@python_api_analytics
def invalidate_environment(self, name: str, sync: bool = False) -> None:
def invalidate_environment(
self, name: str, sync: bool = False, cleanup_snapshots: bool = False
) -> None:
"""Invalidates the target environment by setting its expiration timestamp to now.

Args:
name: The name of the environment to invalidate.
sync: If True, the call blocks until the environment is deleted. Otherwise, the environment will
be deleted asynchronously by the janitor process.
cleanup_snapshots: If True, immediately deletes physical snapshot tables that are exclusively
owned by this environment (not referenced by any other environment). Cleanup runs
synchronously regardless of --sync.
"""
name = Environment.sanitize_name(name)
sync = sync or cleanup_snapshots

target_snapshot_ids: t.Set[SnapshotId] = set()
if cleanup_snapshots:
# Capture snapshot IDs before invalidation so we can scope the cleanup afterwards.
env = self.state_sync.get_environment(name)
if env is None:
logger.warning("Environment '%s' does not exist; skipping snapshot cleanup.", name)
return
target_snapshot_ids = {s.snapshot_id for s in env.snapshots}

self.state_sync.invalidate_environment(name)

if sync:
self._cleanup_environments(name=name)
if cleanup_snapshots and target_snapshot_ids:
failures = delete_snapshots_for_environment(
self.state_sync,
self.snapshot_evaluator,
target_snapshot_ids,
console=self.console,
)
if failures:
summary = "\n".join(failures)
if self.config.janitor.warn_on_delete_failure:
self.console.log_warning(
f"Snapshot cleanup completed with failures:\n{summary}"
)
else:
raise SQLMeshError(f"Snapshot cleanup completed with failures:\n{summary}")
self.console.log_success(f"Environment '{name}' deleted.")
else:
self.console.log_success(f"Environment '{name}' invalidated.")
Expand Down
71 changes: 70 additions & 1 deletion sqlmesh/core/janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlmesh.core.console import Console
from sqlmesh.core.dialect import schema_
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import SnapshotEvaluator
from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotId
from sqlmesh.core.state_sync import StateSync
from sqlmesh.core.state_sync.common import (
logger,
Expand Down Expand Up @@ -193,3 +193,72 @@ def delete_expired_snapshots(
failures.append(message)
logger.info("Cleaned up %s expired snapshots", num_expired_snapshots)
return failures


def delete_snapshots_for_environment(
state_sync: StateSync,
snapshot_evaluator: SnapshotEvaluator,
target_snapshot_ids: t.Collection[SnapshotId],
*,
force_delete: bool = False,
console: t.Optional[Console] = None,
) -> t.List[str]:
"""Delete snapshots that are exclusively owned by a specific (now-deleted) environment.

This performs a scoped cleanup: only the provided snapshot IDs are considered for deletion,
and only those that are not referenced by any remaining active environment will be removed.

Args:
state_sync: StateSync instance to query and delete snapshot state from.
snapshot_evaluator: SnapshotEvaluator instance to clean up physical tables.
target_snapshot_ids: The snapshot IDs to consider for deletion (typically from the
environment that was just invalidated/deleted).
force_delete: If True, delete snapshot state records even when physical table cleanup fails.
console: Optional console for reporting progress.

Returns:
List of failure messages encountered during cleanup.
"""
if not target_snapshot_ids:
return []

failures: t.List[str] = []
batch = state_sync.get_expired_snapshots(
ignore_ttl=True,
batch_range=ExpiredBatchRange.all_batch_range(),
target_snapshot_ids=target_snapshot_ids,
)
if batch is None:
return failures

logger.info(
"Cleaning up %s snapshots exclusively owned by invalidated environment",
len(batch.expired_snapshot_ids),
)

cleanup_succeeded = True
if batch.cleanup_tasks:
try:
snapshot_evaluator.cleanup(
target_snapshots=batch.cleanup_tasks,
on_complete=console.update_cleanup_progress if console else None,
)
except Exception as failed_drops:
message = f"Failed to clean up: {failed_drops}"
logger.warning(message)
failures.append(message)
cleanup_succeeded = False

if cleanup_succeeded or force_delete:
try:
state_sync.delete_snapshots(batch.expired_snapshot_ids)
logger.info(
"Cleaned up %s snapshots from invalidated environment",
len(batch.expired_snapshot_ids),
)
except Exception as e:
message = f"Failed to delete snapshot state records: {e}"
logger.warning(message)
failures.append(message)

return failures
6 changes: 6 additions & 0 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,16 @@ def get_expired_snapshots(
batch_range: ExpiredBatchRange,
current_ts: t.Optional[int] = None,
ignore_ttl: bool = False,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
"""Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier).

Args:
current_ts: Timestamp used to evaluate expiration.
ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
batch_range: The range of the batch to fetch.
target_snapshot_ids: If provided, only consider snapshots with these IDs. Useful for
scoped cleanup after environment invalidation.

Returns:
A batch describing expired snapshots or None if no snapshots are pending cleanup.
Expand Down Expand Up @@ -368,6 +371,7 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
"""Removes expired snapshots.

Expand All @@ -379,6 +383,8 @@ def delete_expired_snapshots(
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
all snapshots that are not referenced in any environment
current_ts: Timestamp used to evaluate expiration.
target_snapshot_ids: If provided, only delete snapshots with these IDs. Useful for
scoped cleanup after environment invalidation.
"""

@abc.abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/state_sync/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
self.snapshot_cache.clear()
self.state_sync.delete_expired_snapshots(
batch_range=batch_range,
ignore_ttl=ignore_ttl,
current_ts=current_ts,
target_snapshot_ids=target_snapshot_ids,
)

def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,15 @@ def get_expired_snapshots(
batch_range: ExpiredBatchRange,
current_ts: t.Optional[int] = None,
ignore_ttl: bool = False,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
current_ts = current_ts or now_timestamp()
return self.snapshot_state.get_expired_snapshots(
environments=self.environment_state.get_environments(),
current_ts=current_ts,
ignore_ttl=ignore_ttl,
batch_range=batch_range,
target_snapshot_ids=target_snapshot_ids,
)

def get_expired_environments(
Expand All @@ -287,11 +289,13 @@ def delete_expired_snapshots(
batch_range: ExpiredBatchRange,
ignore_ttl: bool = False,
current_ts: t.Optional[int] = None,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> None:
batch = self.get_expired_snapshots(
ignore_ttl=ignore_ttl,
current_ts=current_ts,
batch_range=batch_range,
target_snapshot_ids=target_snapshot_ids,
)
if batch and batch.expired_snapshot_ids:
self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids)
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def get_expired_snapshots(
current_ts: int,
ignore_ttl: bool,
batch_range: ExpiredBatchRange,
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
) -> t.Optional[ExpiredSnapshotBatch]:
expired_query = exp.select("name", "identifier", "version", "updated_ts").from_(
self.snapshots_table
Expand All @@ -180,6 +181,16 @@ def get_expired_snapshots(
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
)

if target_snapshot_ids is not None:
target_conditions = list(
snapshot_id_filter(
self.engine_adapter,
target_snapshot_ids,
batch_size=self.SNAPSHOT_BATCH_SIZE,
)
)
expired_query = expired_query.where(exp.or_(*target_conditions))

expired_query = expired_query.where(batch_range.where_filter)

promoted_snapshot_ids = {
Expand Down
70 changes: 70 additions & 0 deletions tests/core/integration/test_aux_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,76 @@ def test_invalidating_environment(sushi_context: Context):
assert start_schemas - schemas_after_janitor == {"sushi__dev"}


def test_invalidate_environment_cleanup_snapshots_scoped(tmp_path: Path):
"""Test that --cleanup-snapshots only deletes snapshots exclusively owned by the invalidated env."""
models_dir = tmp_path / "models"
models_dir.mkdir()
(models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")
(models_dir / "model2.sql").write_text("MODEL(name test.model2, kind FULL); SELECT 2 AS col")

ctx = Context(
paths=[tmp_path],
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
)

# Apply both models to prod and dev.
ctx.plan("prod", no_prompts=True, auto_apply=True)
ctx.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True)

prod_env = ctx.state_sync.get_environment("prod")
dev_env = ctx.state_sync.get_environment("dev")
assert prod_env is not None
assert dev_env is not None

prod_snapshot_ids = {s.snapshot_id for s in prod_env.snapshots}
dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots}

# In a virtual environment, dev shares snapshots with prod.
# Shared snapshots must NOT be deleted when invalidating dev with --cleanup-snapshots.
shared_snapshot_ids = prod_snapshot_ids & dev_snapshot_ids

ctx.invalidate_environment("dev", cleanup_snapshots=True)

# The dev environment record should be gone.
assert ctx.state_sync.get_environment("dev") is None

# Shared snapshots (also in prod) must still exist.
remaining_snapshots = ctx.state_sync.get_snapshots(list(shared_snapshot_ids))
assert set(remaining_snapshots.keys()) == shared_snapshot_ids

# Prod environment should be unaffected.
assert ctx.state_sync.get_environment("prod") is not None


def test_invalidate_environment_cleanup_snapshots_exclusive(tmp_path: Path):
"""Test that --cleanup-snapshots deletes snapshots exclusively owned by the invalidated env."""
models_dir = tmp_path / "models"
models_dir.mkdir()
(models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")

ctx = Context(
paths=[tmp_path],
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
)

# Apply model1 to dev only (not prod). These snapshots will be exclusively owned by dev.
ctx.plan("dev", no_prompts=True, auto_apply=True)

dev_env = ctx.state_sync.get_environment("dev")
assert dev_env is not None
dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots}
assert dev_snapshot_ids

ctx.invalidate_environment("dev", cleanup_snapshots=True)

# The dev environment record should be gone.
assert ctx.state_sync.get_environment("dev") is None

# All dev-exclusive snapshots should have been deleted.
remaining_snapshots = ctx.state_sync.get_snapshots(list(dev_snapshot_ids))
assert not remaining_snapshots


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
Expand Down
Loading
Loading