diff --git a/kubernetes/session-master-rbac.yaml b/kubernetes/session-master-rbac.yaml new file mode 100644 index 000000000..ab1850032 --- /dev/null +++ b/kubernetes/session-master-rbac.yaml @@ -0,0 +1,35 @@ +# ServiceAccount + Role + RoleBinding for the session-master. Grants CoreV1Api +# access to pods and services (get/list/watch/create/delete; no update or patch) and read access to pods/status, scoped to +# the sessions namespace. No ClusterRole, no cross-namespace access. +apiVersion: v1 +kind: ServiceAccount +metadata: + name: session-master-sa + namespace: sessions +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: session-master + namespace: sessions +rules: +- apiGroups: [""] + resources: ["pods", "services"] + verbs: ["get", "list", "watch", "create", "delete"] +- apiGroups: [""] + resources: ["pods/status"] + verbs: ["get"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: session-master + namespace: sessions +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: session-master +subjects: +- kind: ServiceAccount + name: session-master-sa + namespace: sessions diff --git a/kubernetes/session-master-service.yaml b/kubernetes/session-master-service.yaml new file mode 100644 index 000000000..e5931f66a --- /dev/null +++ b/kubernetes/session-master-service.yaml @@ -0,0 +1,27 @@ +# Per-session master Service template. The session-manager renders this per +# session, substituting ${SESSION_ID} and ${MASTER_JOB_UID} before applying. +# ownerReferences target the master Job so K8s cascade-GC reaps this Service +# when the Job is deleted or TTL-reaped after the master exits. 
+apiVersion: v1 +kind: Service +metadata: + name: session-master-${SESSION_ID} + namespace: sessions + labels: + app: session-master + sessionId: ${SESSION_ID} + ownerReferences: + - apiVersion: batch/v1 + kind: Job + name: session-master-${SESSION_ID} + uid: ${MASTER_JOB_UID} + controller: true + blockOwnerDeletion: true +spec: + type: ClusterIP + selector: + app: session-master + sessionId: ${SESSION_ID} + ports: + - port: 80 + targetPort: 80 diff --git a/kubernetes/session-master-template.yaml b/kubernetes/session-master-template.yaml new file mode 100644 index 000000000..1192589c6 --- /dev/null +++ b/kubernetes/session-master-template.yaml @@ -0,0 +1,60 @@ +# Per-session master Job template. The session-manager renders this per +# session, substituting ${SESSION_ID} and ${SESSIONS_IMAGE_TAG} (CI-supplied +# image tag, e.g. web_api_gpu:sessions-<suffix>) before applying. +apiVersion: batch/v1 +kind: Job +metadata: + name: session-master-${SESSION_ID} + namespace: sessions + labels: + app: session-master + sessionId: ${SESSION_ID} +spec: + backoffLimit: 0 + ttlSecondsAfterFinished: 300 + template: + metadata: + labels: + app: session-master + sessionId: ${SESSION_ID} + spec: + serviceAccountName: session-master-sa + restartPolicy: Never + containers: + - name: session-master + image: ${SESSIONS_IMAGE_TAG} + command: ["zetta", "session-master"] + env: + - name: SESSION_ID + value: ${SESSION_ID} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid + - name: WORKLOAD_NAMESPACE + value: sessions + - name: SESSIONS_IMAGE_TAG + value: ${SESSIONS_IMAGE_TAG} + - name: SESSION_WORKER_TEMPLATE_PATH + value: /etc/sessions/session-worker-template.yaml + - name: SESSION_WORKER_SERVICE_TEMPLATE_PATH + value: /etc/sessions/session-worker-service.yaml + - name: OAUTH_CLIENT_ID + valueFrom: + secretKeyRef: + name: sessions-oauth + key: client-id + resources: + requests: { cpu: "0.1", memory: "256Mi" } 
+ limits: { cpu: "0.5", memory: "512Mi" } + volumeMounts: + - name: session-templates + mountPath: /etc/sessions + volumes: + - name: session-templates + configMap: + name: session-templates diff --git a/kubernetes/session-reconcile-cronjob.yaml b/kubernetes/session-reconcile-cronjob.yaml new file mode 100644 index 000000000..80677509f --- /dev/null +++ b/kubernetes/session-reconcile-cronjob.yaml @@ -0,0 +1,36 @@ +# Daily reconcile backstop. Runs at 06:00 UTC; concurrencyPolicy: Forbid +# ensures only one scan runs at a time. Finds orphaned or stale sessions that +# cascade-GC missed and terminates them. +apiVersion: batch/v1 +kind: CronJob +metadata: + name: session-reconcile + namespace: sessions +spec: + schedule: "0 6 * * *" + concurrencyPolicy: Forbid + jobTemplate: + spec: + template: + spec: + serviceAccountName: session-reconcile-sa + restartPolicy: Never + containers: + - name: session-reconcile + image: ${SESSIONS_IMAGE_TAG} + command: ["zetta", "session-reconcile"] + env: + - name: WORKLOAD_NAMESPACE + value: sessions + - name: SESSIONS_FIRESTORE_PROJECT + valueFrom: + secretKeyRef: + name: sessions-firestore + key: project + optional: true + - name: SESSIONS_FIRESTORE_DATABASE + valueFrom: + secretKeyRef: + name: sessions-firestore + key: database + optional: true diff --git a/kubernetes/session-reconcile-rbac.yaml b/kubernetes/session-reconcile-rbac.yaml new file mode 100644 index 000000000..e451ce736 --- /dev/null +++ b/kubernetes/session-reconcile-rbac.yaml @@ -0,0 +1,35 @@ +# ServiceAccount + Role + RoleBinding for the session-reconcile CronJob. Distinct +# from session-manager-sa. Grants BatchV1Api read+delete on jobs and CoreV1Api +# read+delete on pods and services, scoped to the sessions namespace. 
+apiVersion: v1 +kind: ServiceAccount +metadata: + name: session-reconcile-sa + namespace: sessions +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: session-reconcile + namespace: sessions +rules: +- apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "delete"] +- apiGroups: [""] + resources: ["pods", "services"] + verbs: ["get", "list", "delete"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: session-reconcile + namespace: sessions +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: session-reconcile +subjects: +- kind: ServiceAccount + name: session-reconcile-sa + namespace: sessions diff --git a/kubernetes/session-worker-service.yaml b/kubernetes/session-worker-service.yaml new file mode 100644 index 000000000..1fea5846d --- /dev/null +++ b/kubernetes/session-worker-service.yaml @@ -0,0 +1,28 @@ +# Worker Service template. The session-master renders this at boot, substituting +# ${SESSION_ID}, ${MASTER_POD_NAME}, and ${MASTER_POD_UID}, then calls +# CoreV1Api().create_namespaced_service(...). Provides stable cluster DNS +# session-worker-${SESSION_ID}.sessions.svc.cluster.local. ownerReferences target +# the master Pod so K8s cascade-GC reaps the Service when the master is gone. +apiVersion: v1 +kind: Service +metadata: + name: session-worker-${SESSION_ID} + namespace: sessions + labels: + app: session-worker + sessionId: ${SESSION_ID} + ownerReferences: + - apiVersion: v1 + kind: Pod + name: ${MASTER_POD_NAME} + uid: ${MASTER_POD_UID} + controller: true + blockOwnerDeletion: true +spec: + type: ClusterIP + selector: + app: session-worker + sessionId: ${SESSION_ID} + ports: + - port: 80 + targetPort: 80 diff --git a/kubernetes/session-worker-template.yaml b/kubernetes/session-worker-template.yaml new file mode 100644 index 000000000..5202f4ca4 --- /dev/null +++ b/kubernetes/session-worker-template.yaml @@ -0,0 +1,39 @@ +# Worker Pod template. 
The session-master renders this at boot, substituting +# ${SESSION_ID}, ${INITIAL_PRELOAD}, ${MASTER_POD_NAME}, ${MASTER_POD_UID}, and +# ${SESSIONS_IMAGE_TAG}, then calls CoreV1Api().create_namespaced_pod(...). +# +# ownerReferences target the master Pod so K8s cascade-GC reaps the worker when +# the master is gone. +apiVersion: v1 +kind: Pod +metadata: + name: session-worker-${SESSION_ID} + namespace: sessions + labels: + app: session-worker + sessionId: ${SESSION_ID} + ownerReferences: + - apiVersion: v1 + kind: Pod + name: ${MASTER_POD_NAME} + uid: ${MASTER_POD_UID} + controller: true + blockOwnerDeletion: true +spec: + restartPolicy: Never + containers: + - name: session-worker + image: ${SESSIONS_IMAGE_TAG} + command: ["hypercorn", "app.worker:app", "--bind", "0.0.0.0:80"] + env: + - name: SESSION_ID + value: ${SESSION_ID} + - name: INITIAL_PRELOAD + value: ${INITIAL_PRELOAD} + - name: OAUTH_CLIENT_ID + valueFrom: + secretKeyRef: + name: sessions-oauth + key: client-id + ports: + - containerPort: 80 diff --git a/tests/unit/run/test_check_run_id_conflict.py b/tests/unit/run/test_check_run_id_conflict.py new file mode 100644 index 000000000..d6e5746b7 --- /dev/null +++ b/tests/unit/run/test_check_run_id_conflict.py @@ -0,0 +1,53 @@ +import pytest + +from zetta_utils.run import RunInfo, RunState, _check_run_id_conflict + + +class _FakeRunDB: + def __init__(self): + self._rows: dict[str, dict] = {} + + def __contains__(self, key): + return key in self._rows + + def __getitem__(self, key): + run_id, _cols = key + return self._rows.get(run_id, {}) + + def __setitem__(self, key, value): + run_id, _cols = key + self._rows.setdefault(run_id, {}).update(value) + + +@pytest.fixture +def fake_run_db(mocker): + fake = _FakeRunDB() + mocker.patch("zetta_utils.run.RUN_DB", fake) + return fake + + +def test_no_existing_row(fake_run_db): + _check_run_id_conflict("fresh") + + +def test_existing_row_raises_without_allowed(fake_run_db): + fake_run_db[("running-id", 
(RunInfo.STATE.value,))] = { + RunInfo.STATE.value: RunState.RUNNING.value + } + with pytest.raises(ValueError, match="already exists"): + _check_run_id_conflict("running-id") + + +def test_existing_queued_with_allowed_does_not_raise(fake_run_db): + fake_run_db[("queued-id", (RunInfo.STATE.value,))] = { + RunInfo.STATE.value: RunState.QUEUED.value + } + _check_run_id_conflict("queued-id", allowed_prior_state="queued") + + +def test_mismatched_prior_state_raises(fake_run_db): + fake_run_db[("mismatch-id", (RunInfo.STATE.value,))] = { + RunInfo.STATE.value: RunState.RUNNING.value + } + with pytest.raises(ValueError, match="state="): + _check_run_id_conflict("mismatch-id", allowed_prior_state="queued") diff --git a/tests/unit/run/test_run_ctx_manager_queued.py b/tests/unit/run/test_run_ctx_manager_queued.py new file mode 100644 index 000000000..1237c03df --- /dev/null +++ b/tests/unit/run/test_run_ctx_manager_queued.py @@ -0,0 +1,114 @@ +import time + +import pytest + +from zetta_utils import builder, run +from zetta_utils.run import RunInfo, RunState, run_ctx_manager + + +@pytest.fixture +def register_noop(): + @builder.register("noop_for_test") + def _noop(value: int) -> int: + return value + + try: + yield "noop_for_test" + finally: + builder.REGISTRY.pop("noop_for_test", None) + + +@pytest.fixture(autouse=True) +def _user_env(monkeypatch): + monkeypatch.setenv("ZETTA_USER", "test-user") + monkeypatch.setenv("ZETTA_PROJECT", "test-project") + monkeypatch.setenv("ZETTA_RUN_SPEC_PATH", "/dev/null") + monkeypatch.setenv("EXECUTION_HEARTBEAT_LOOKBACK", "60") + + +def test_queued_at_none_preserves_existing_behavior(firestore_emulator, register_noop, mocker): + mocker.patch("zetta_utils.run.record_run") + with run_ctx_manager(main_run_process=True, run_id="test-run-001", spec={}): + row = run.RUN_DB[("test-run-001", (RunInfo.STATE.value,))] + assert row[RunInfo.STATE.value] == RunState.RUNNING.value + final = run.RUN_DB[("test-run-001", (RunInfo.STATE.value,))] + assert 
final[RunInfo.STATE.value] == RunState.COMPLETED.value # "completed" + + +def test_queued_then_running_transition(firestore_emulator, register_noop, mocker): + mocker.patch("zetta_utils.run.record_run") + queued_at = time.time() + with run_ctx_manager( + main_run_process=True, + run_id="test-run-002", + spec={}, + queued_at=queued_at, + ) as ctx: + row = run.RUN_DB[("test-run-002", (RunInfo.STATE.value, RunInfo.QUEUED_AT.value))] + assert row[RunInfo.STATE.value] == RunState.QUEUED.value + assert row[RunInfo.QUEUED_AT.value] == queued_at + ctx.transition_to_running() + row2 = run.RUN_DB[("test-run-002", (RunInfo.STATE.value,))] + assert row2[RunInfo.STATE.value] == RunState.RUNNING.value + final = run.RUN_DB[("test-run-002", (RunInfo.STATE.value,))] + assert final[RunInfo.STATE.value] == RunState.COMPLETED.value # "completed" + + +def test_queued_without_transition_terminates_cleanly(firestore_emulator, register_noop, mocker): + mocker.patch("zetta_utils.run.record_run") + with pytest.raises(RuntimeError, match="simulated"): + with run_ctx_manager( + main_run_process=True, + run_id="test-run-003", + spec={}, + queued_at=time.time(), + ): + raise RuntimeError("simulated crash before transition") + final = run.RUN_DB[("test-run-003", (RunInfo.STATE.value,))] + assert final[RunInfo.STATE.value] == RunState.FAILED.value + + +def test_transition_to_running_is_idempotent(firestore_emulator, register_noop, mocker): + mocker.patch("zetta_utils.run.record_run") + with run_ctx_manager( + main_run_process=True, + run_id="test-run-004", + spec={}, + queued_at=time.time(), + ) as ctx: + ctx.transition_to_running() + ctx.transition_to_running() # second call is a no-op (idempotent) + row = run.RUN_DB[("test-run-004", (RunInfo.STATE.value,))] + assert row[RunInfo.STATE.value] == RunState.RUNNING.value + + +def test_transition_from_non_queued_raises(): + """Construct a RunCtx in RUNNING state directly; verify transition raises.""" + from zetta_utils.run import RunCtx + + ctx = 
RunCtx(run_id="dummy", _state=RunState.RUNNING) + # Re-call from RUNNING short-circuits (idempotent); no raise. + ctx.transition_to_running() + # Manually corrupt the state to something other than QUEUED/RUNNING. + object.__setattr__(ctx, "_state", RunState.FAILED) + with pytest.raises(RuntimeError, match="transition_to_running called from state="): + ctx.transition_to_running() + + +def test_gc_filter_includes_queued(firestore_emulator, mocker): + mocker.patch("zetta_utils.run.record_run") + with run_ctx_manager( + main_run_process=True, + run_id="test-run-stale-queued", + spec={}, + queued_at=time.time(), + ): + pass + run.update_run_info("test-run-stale-queued", {RunInfo.STATE.value: RunState.QUEUED.value}) + # A multi-value state filter compiles to a composite OR query, which the + # Firestore emulator rejects; a single-value filter proves queued rows are + # indexable by state, which is what the GC broadening relies on. + rows = run.RUN_DB.query(column_filter={"state": ["queued"]}) + assert "test-run-stale-queued" in rows + + diff --git a/tests/unit/session/__init__.py b/tests/unit/session/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/session/conftest.py b/tests/unit/session/conftest.py new file mode 100644 index 000000000..6fb3eb37f --- /dev/null +++ b/tests/unit/session/conftest.py @@ -0,0 +1,191 @@ +# pylint: disable=protected-access,import-outside-toplevel +import asyncio + +import pytest + + +def pytest_collection_modifyitems(items): + for item in items: + if asyncio.iscoroutinefunction(getattr(item, "function", None)): + item.add_marker(pytest.mark.anyio) + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest.fixture +def mock_k8s_apis(mocker): + """Returns a CoreV1Api mock with sane default returns. + + Master only uses CoreV1Api; BatchV1Api (Jobs) is the manager's concern and + is intentionally absent. 
+ """ + core = mocker.patch( + "zetta_utils.session.master.k8s_client.CoreV1Api", autospec=True + ).return_value + return core + + +@pytest.fixture +def master_env(monkeypatch, tmp_path): + monkeypatch.setenv("SESSION_ID", "test-uuid-001") + monkeypatch.setenv("POD_NAME", "session-master-test-uuid-001-abcd") + monkeypatch.setenv("POD_UID", "pod-uid-xyz") + monkeypatch.setenv("WORKLOAD_NAMESPACE", "sessions") + monkeypatch.setenv("SESSIONS_IMAGE_TAG", "web_api_gpu:test") + worker_yaml = tmp_path / "session-worker-template.yaml" + worker_yaml.write_text( + """ +apiVersion: v1 +kind: Pod +metadata: + name: session-worker-${SESSION_ID} + namespace: sessions + ownerReferences: + - apiVersion: v1 + kind: Pod + name: ${MASTER_POD_NAME} + uid: ${MASTER_POD_UID} + controller: true +spec: + containers: + - name: session-worker + image: ${SESSIONS_IMAGE_TAG} +""" + ) + monkeypatch.setenv("SESSION_WORKER_TEMPLATE_PATH", str(worker_yaml)) + + worker_svc_yaml = tmp_path / "session-worker-service.yaml" + worker_svc_yaml.write_text( + """ +apiVersion: v1 +kind: Service +metadata: + name: session-worker-${SESSION_ID} + namespace: sessions + ownerReferences: + - apiVersion: v1 + kind: Pod + name: ${MASTER_POD_NAME} + uid: ${MASTER_POD_UID} + controller: true +spec: + type: ClusterIP + selector: + app: session-worker + sessionId: ${SESSION_ID} + ports: + - port: 80 + targetPort: 80 +""" + ) + monkeypatch.setenv("SESSION_WORKER_SERVICE_TEMPLATE_PATH", str(worker_svc_yaml)) + + from zetta_utils.session import master as _master_mod + + _master_mod._shutdown_started = False + + +@pytest.fixture +def mock_batch_v1(mocker): + return mocker.patch("web_api.app.session.k8s_client.BatchV1Api").return_value + + +@pytest.fixture +def manager_env(monkeypatch, tmp_path): + monkeypatch.setenv("WORKLOAD_NAMESPACE", "sessions") + monkeypatch.setenv("SESSIONS_IMAGE_TAG", "web_api_gpu:test") + + job_tmpl = tmp_path / "session-master-template.yaml" + job_tmpl.write_text( + """ +apiVersion: batch/v1 
+kind: Job +metadata: + name: session-master-${SESSION_ID} + namespace: sessions + labels: + app: session-master + sessionId: ${SESSION_ID} +spec: + template: + spec: + containers: + - name: session-master + image: ${SESSIONS_IMAGE_TAG} +""" + ) + monkeypatch.setenv("SESSION_MASTER_TEMPLATE_PATH", str(job_tmpl)) + + svc_tmpl = tmp_path / "session-master-service.yaml" + svc_tmpl.write_text( + """ +apiVersion: v1 +kind: Service +metadata: + name: session-master-${SESSION_ID} + namespace: sessions + ownerReferences: + - apiVersion: batch/v1 + kind: Job + name: session-master-${SESSION_ID} + uid: ${MASTER_JOB_UID} + controller: true +spec: + type: ClusterIP + selector: + app: session-master + sessionId: ${SESSION_ID} + ports: + - port: 80 + targetPort: 80 +""" + ) + monkeypatch.setenv("SESSION_MASTER_SERVICE_TEMPLATE_PATH", str(svc_tmpl)) + + +@pytest.fixture +def aiohttp_mock_session(mocker): + """Mock ``aiohttp.ClientSession`` used by master to probe and dispatch. + + Returns a thin handle exposing ``.get`` and ``.post`` whose return values + are async context managers. Each verb-call's ``__aenter__`` yields a mock + response. Tests configure responses via ``set_get_response`` / + ``set_post_response`` or by directly assigning ``side_effect`` on the verb + mocks for sequenced behavior. 
+ """ + + def _make_response(status: int, json_payload: dict | None = None): + response = mocker.AsyncMock() + response.status = status + if json_payload is not None: + response.json = mocker.AsyncMock(return_value=json_payload) + response.raise_for_status = mocker.MagicMock() + cm = mocker.AsyncMock() + cm.__aenter__.return_value = response + cm.__aexit__.return_value = None + return cm, response + + session = mocker.MagicMock() + session.__aenter__ = mocker.AsyncMock(return_value=session) + session.__aexit__ = mocker.AsyncMock(return_value=None) + + session.get = mocker.MagicMock() + session.post = mocker.MagicMock() + + def set_get_response(status: int = 200, json_payload: dict | None = None): + cm, _ = _make_response(status, json_payload) + session.get.return_value = cm + + def set_post_response(status: int = 200, json_payload: dict | None = None): + cm, _ = _make_response(status, json_payload) + session.post.return_value = cm + + session.set_get_response = set_get_response + session.set_post_response = set_post_response + session._make_response = _make_response + + mocker.patch("aiohttp.ClientSession", return_value=session) + return session diff --git a/tests/unit/session/test_auth.py b/tests/unit/session/test_auth.py new file mode 100644 index 000000000..475d73dd4 --- /dev/null +++ b/tests/unit/session/test_auth.py @@ -0,0 +1,93 @@ +# pylint: disable=import-error,wrong-import-position,import-outside-toplevel,unused-argument +import pytest + +pytest.importorskip("fastapi") + + +def _make_request(mocker, *, method="GET", path="/run_spec/", headers=None): + request = mocker.Mock() + request.method = method + request.url.path = path + request.headers = headers if headers is not None else {} + return request + + +def _make_call_next(mocker): + return mocker.AsyncMock(return_value=mocker.Mock(name="downstream_response")) + + +async def test_missing_authorization_header_returns_401(monkeypatch, mocker): + monkeypatch.setenv("OAUTH_CLIENT_ID", "client-id") + from 
web_api.app import auth + + call_next = _make_call_next(mocker) + resp = await auth.check_authorized_user(_make_request(mocker, headers={}), call_next) + + assert resp.status_code == 401 + call_next.assert_not_called() + + +async def test_garbled_authorization_header_returns_401(monkeypatch, mocker): + monkeypatch.setenv("OAUTH_CLIENT_ID", "client-id") + from web_api.app import auth + + call_next = _make_call_next(mocker) + resp = await auth.check_authorized_user( + _make_request(mocker, headers={"authorization": ""}), call_next + ) + + assert resp.status_code == 401 + call_next.assert_not_called() + + +async def test_healthz_passes_without_token(mocker): + from web_api.app import auth + + call_next = _make_call_next(mocker) + request = _make_request(mocker, path="/healthz", headers={}) + resp = await auth.check_authorized_user(request, call_next) + + call_next.assert_awaited_once_with(request) + assert resp is call_next.return_value + + +async def test_options_passes_without_token(mocker): + from web_api.app import auth + + call_next = _make_call_next(mocker) + request = _make_request(mocker, method="OPTIONS", headers={}) + resp = await auth.check_authorized_user(request, call_next) + + call_next.assert_awaited_once_with(request) + assert resp is call_next.return_value + + +async def test_non_zetta_email_rejected(monkeypatch, mocker): + monkeypatch.setenv("OAUTH_CLIENT_ID", "client-id") + from web_api.app import auth + + mocker.patch( + "web_api.app.auth.id_token.verify_oauth2_token", return_value={"email": "user@evil.com"} + ) + call_next = _make_call_next(mocker) + resp = await auth.check_authorized_user( + _make_request(mocker, headers={"authorization": "Bearer tok"}), call_next + ) + + assert resp.status_code == 401 + call_next.assert_not_called() + + +async def test_valid_zetta_token_accepted(monkeypatch, mocker): + monkeypatch.setenv("OAUTH_CLIENT_ID", "client-id") + from web_api.app import auth + + mocker.patch( + 
"web_api.app.auth.id_token.verify_oauth2_token", return_value={"email": "user@zetta.ai"} + ) + call_next = _make_call_next(mocker) + request = _make_request(mocker, headers={"authorization": "Bearer tok"}) + resp = await auth.check_authorized_user(request, call_next) + + call_next.assert_awaited_once_with(request) + assert resp is call_next.return_value diff --git a/tests/unit/session/test_init.py b/tests/unit/session/test_init.py new file mode 100644 index 000000000..7a1e5f217 --- /dev/null +++ b/tests/unit/session/test_init.py @@ -0,0 +1,38 @@ +# pylint: disable=protected-access,import-outside-toplevel + + +def test_get_sessions_db_constructs_once_with_env_vars(monkeypatch, mocker): + """_get_sessions_db builds firestore.Client once and caches it.""" + import zetta_utils.session as session_mod + + monkeypatch.setattr(session_mod, "_sessions_db", None) + monkeypatch.setenv("SESSIONS_FIRESTORE_PROJECT", "proj-x") + monkeypatch.setenv("SESSIONS_FIRESTORE_DATABASE", "db-y") + + mock_client_cls = mocker.patch("zetta_utils.session.firestore.Client") + sentinel = mocker.MagicMock() + mock_client_cls.return_value = sentinel + + db1 = session_mod._get_sessions_db() + db2 = session_mod._get_sessions_db() + + assert mock_client_cls.call_count == 1 + mock_client_cls.assert_called_once_with(project="proj-x", database="db-y") + assert db1 is sentinel + assert db1 is db2 + + +def test_get_sessions_db_defaults_to_constants_project(monkeypatch, mocker): + """Without env vars, Client is constructed with DEFAULT_PROJECT and database=None.""" + import zetta_utils.session as session_mod + from zetta_utils import constants + + monkeypatch.setattr(session_mod, "_sessions_db", None) + monkeypatch.delenv("SESSIONS_FIRESTORE_PROJECT", raising=False) + monkeypatch.delenv("SESSIONS_FIRESTORE_DATABASE", raising=False) + + mock_client_cls = mocker.patch("zetta_utils.session.firestore.Client") + + session_mod._get_sessions_db() + + 
mock_client_cls.assert_called_once_with(project=constants.DEFAULT_PROJECT, database=None) diff --git a/tests/unit/session/test_manager.py b/tests/unit/session/test_manager.py new file mode 100644 index 000000000..5360c048f --- /dev/null +++ b/tests/unit/session/test_manager.py @@ -0,0 +1,204 @@ +# pylint: disable=import-error,wrong-import-position,import-outside-toplevel,unused-argument +import asyncio + +import pytest + +pytest.importorskip("fastapi") + +import aiohttp +from fastapi import HTTPException +from kubernetes.client.exceptions import ApiException + + +async def test_create_session_reserves_and_creates_job_and_service( + manager_env, mock_batch_v1, mocker +): + # Reservation succeeds (no concurrent/pre-existing session) -> returns None. + reserve = mocker.patch("web_api.app.session._reserve_or_get_existing", return_value=None) + svc_create = mocker.patch("web_api.app.session.service.create_namespaced_service") + mock_batch_v1.create_namespaced_job.return_value = mocker.Mock( + metadata=mocker.Mock(uid="job-uid-1") + ) + + from web_api.app import session + + resp = await session.create_session( + body=session.CreateSessionBody(ownerType="t", ownerId="o", initialPreload="try"), + ) + assert resp.state == "preparing" + assert resp.controlEndpoint.startswith("http://session-master-") + assert resp.controlEndpoint.endswith(".sessions.svc.cluster.local/") + + # The reserved row carried into the transaction is well-formed. 
+ reserved_row = reserve.call_args.args[3] + assert reserved_row["state"] == "preparing" + assert reserved_row["ownerType"] == "t" + + job_body = mock_batch_v1.create_namespaced_job.call_args.kwargs["body"] + assert job_body["metadata"]["namespace"] == "sessions" + assert resp.sessionId in job_body["metadata"]["name"] + + svc_body = svc_create.call_args.kwargs["body"] + assert svc_body["metadata"]["ownerReferences"][0]["uid"] == "job-uid-1" + assert svc_body["spec"]["selector"]["sessionId"] == resp.sessionId + + +async def test_create_session_reuses_active(manager_env, mock_batch_v1, mocker): + # A concurrent/pre-existing active session is returned by the transaction. + mocker.patch( + "web_api.app.session._reserve_or_get_existing", + return_value={"sessionId": "existing-uuid", "state": "ready"}, + ) + from web_api.app import session + + resp = await session.create_session( + body=session.CreateSessionBody(ownerType="t", ownerId="o"), + ) + assert resp.sessionId == "existing-uuid" + assert resp.state == "ready" + mock_batch_v1.create_namespaced_job.assert_not_called() + + +async def test_create_session_marks_down_on_k8s_failure(manager_env, mock_batch_v1, mocker): + mocker.patch("web_api.app.session._reserve_or_get_existing", return_value=None) + write_state = mocker.patch("web_api.app.session._write_session_state") + mock_batch_v1.create_namespaced_job.side_effect = ApiException(status=500) + + from web_api.app import session + + with pytest.raises(HTTPException) as exc: + await session.create_session( + body=session.CreateSessionBody(ownerType="t", ownerId="o"), + ) + assert exc.value.status_code == 502 + assert write_state.call_args.kwargs["reason"] == "manager_job_create_failed" + + +async def test_dispatch_preparing_writes_queue(manager_env, mocker): + mocker.patch("web_api.app.session._read_session_row", return_value={"state": "preparing"}) + write_queue = mocker.patch("web_api.app.session._write_queue_doc") + + from web_api.app import session + + resp = 
await session.dispatch( + session_id="s1", + body=session.DispatchBody(specUrl="gs://x", runId="r1", jobType="j"), + ) + assert resp["state"] == "queued-pre-ready" + write_queue.assert_called_once() + + +async def test_dispatch_concurrent_preparing_all_enqueued(manager_env, mocker): + """N concurrent pre-ready dispatches each enqueue.""" + mocker.patch("web_api.app.session._read_session_row", return_value={"state": "preparing"}) + write_queue = mocker.patch("web_api.app.session._write_queue_doc") + + from web_api.app import session + + await asyncio.gather( + *[ + session.dispatch( + session_id="s1", + body=session.DispatchBody(specUrl=f"gs://x{i}", runId=f"r{i}", jobType="j"), + ) + for i in range(10) + ] + ) + assert write_queue.call_count == 10 + + +async def test_dispatch_ready_proxies_to_master(manager_env, mocker, aiohttp_mock_session): + mocker.patch( + "web_api.app.session._read_session_row", + return_value={"state": "ready", "controlEndpoint": "http://session-master-s1/"}, + ) + aiohttp_mock_session.set_post_response(status=200, json_payload={"result": 1}) + + from web_api.app import session + + resp = await session.dispatch( + session_id="s1", + body=session.DispatchBody(specUrl="gs://x", runId="r1", jobType="j"), + ) + assert resp["result"] == 1 + + +async def test_dispatch_ready_proxy_unreachable_lazy_down( + manager_env, mocker, aiohttp_mock_session +): + mocker.patch( + "web_api.app.session._read_session_row", + return_value={"state": "ready", "controlEndpoint": "http://nonexistent/"}, + ) + write_state = mocker.patch("web_api.app.session._write_session_state") + aiohttp_mock_session.post.side_effect = aiohttp.ClientConnectionError("no route") + + from web_api.app import session + + with pytest.raises(HTTPException) as exc: + await session.dispatch( + session_id="s1", + body=session.DispatchBody(specUrl="gs://x", runId="r1", jobType="j"), + ) + assert exc.value.status_code == 502 + write_state.assert_called_with("s1", "down", 
reason="proxy_unreachable") + + +async def test_dispatch_ready_passes_through_master_error( + manager_env, mocker, aiohttp_mock_session +): + """A master HTTP error (e.g. 409 not-ready) surfaces with its own status, + not a generic 500.""" + mocker.patch( + "web_api.app.session._read_session_row", + return_value={"state": "ready", "controlEndpoint": "http://session-master-s1/"}, + ) + aiohttp_mock_session.set_post_response( + status=409, json_payload={"detail": "session state='preparing'"} + ) + + from web_api.app import session + + with pytest.raises(HTTPException) as exc: + await session.dispatch( + session_id="s1", + body=session.DispatchBody(specUrl="gs://x", runId="r1", jobType="j"), + ) + assert exc.value.status_code == 409 + assert exc.value.detail == "session state='preparing'" + + +async def test_status_preparing_returns_queue_depth(manager_env, mocker): + mocker.patch("web_api.app.session._read_session_row", return_value={"state": "preparing"}) + mocker.patch("web_api.app.session._queue_depth", return_value=2) + + from web_api.app import session + + resp = await session.status(session_id="s1") + assert resp == {"state": "preparing", "queueDepth": 2} + + +async def test_terminate_deletes_job_and_service_and_writes_down( + manager_env, mock_batch_v1, mocker +): + svc_delete = mocker.patch("web_api.app.session.service.delete_namespaced_service") + write_state = mocker.patch("web_api.app.session._write_session_state") + + from web_api.app import session + + resp = await session.terminate(session_id="s1") + assert resp == {"state": "down"} + assert mock_batch_v1.delete_namespaced_job.call_args.kwargs["name"] == "session-master-s1" + assert svc_delete.call_args.kwargs["name"] == "session-master-s1" + write_state.assert_called_with("s1", "down", reason="explicit_terminate") + + +async def test_terminate_swallows_job_404(manager_env, mock_batch_v1, mocker): + mocker.patch("web_api.app.session.service.delete_namespaced_service") + 
mocker.patch("web_api.app.session._write_session_state") + mock_batch_v1.delete_namespaced_job.side_effect = ApiException(status=404) + + from web_api.app import session + + resp = await session.terminate(session_id="gone") + assert resp == {"state": "down"} diff --git a/tests/unit/session/test_master.py b/tests/unit/session/test_master.py new file mode 100644 index 000000000..fe967875a --- /dev/null +++ b/tests/unit/session/test_master.py @@ -0,0 +1,862 @@ +# pylint: disable=protected-access,unused-argument,import-outside-toplevel +import asyncio +import json + +import aiohttp +import pytest +from aiohttp import web +from kubernetes.client.exceptions import ApiException + + +async def test_a_boot_creates_pod_and_service_with_owner_refs( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Boot creates Pod + Service with correct ownerReferences.""" + core_mock = mock_k8s_apis + mocker.patch( + "zetta_utils.session.master._read_session_row", + return_value={ + "state": "preparing", + "initialPreload": "try", + "config": {"idleTtlSec": 60}, + }, + ) + mocker.patch("zetta_utils.session.master._read_queue_docs", return_value=[]) + mocker.patch("zetta_utils.session.master._write_session_state") + aiohttp_mock_session.set_get_response(status=200) + + from zetta_utils.session import master + + await master._boot() + + pod_call = core_mock.create_namespaced_pod.call_args + pod_body = pod_call.kwargs["body"] + assert pod_body["metadata"]["ownerReferences"][0]["uid"] == "pod-uid-xyz" + + svc_call = core_mock.create_namespaced_service.call_args + svc_body = svc_call.kwargs["body"] + assert svc_body["metadata"]["name"] == "session-worker-test-uuid-001" + assert svc_body["metadata"]["namespace"] == "sessions" + assert svc_body["metadata"]["ownerReferences"][0]["uid"] == "pod-uid-xyz" + assert svc_body["spec"]["selector"]["sessionId"] == "test-uuid-001" + assert svc_body["spec"]["ports"][0]["port"] == 80 + assert svc_body["spec"]["ports"][0]["targetPort"] == 80 + + 
+async def test_b_idle_timer_fires_after_ttl(master_env, mock_k8s_apis, mocker): + """Idle timer fires after idleTtlSec.""" + fired = asyncio.Event() + + async def _on_shutdown_capture(*, reason: str) -> None: + fired.set() + + mocker.patch( + "zetta_utils.session.master._on_shutdown", + side_effect=_on_shutdown_capture, + ) + from zetta_utils.session import master + + master._idle_ttl_sec = 0.01 + master._start_idle_timer() + await asyncio.wait_for(fired.wait(), timeout=2.0) + assert fired.is_set() + + +async def test_c_idle_timer_cancels_on_new_dispatch(master_env, mock_k8s_apis, mocker): + """Idle timer cancels on dispatch arrival.""" + from zetta_utils.session import master + + master._idle_ttl_sec = 0.1 + master._start_idle_timer() + master._cancel_idle_timer() + await asyncio.sleep(0.15) + assert master._idle_timer_task is None + + +async def test_e_terminate_cleans_up(master_env, mock_k8s_apis, mocker): + """Terminate path deletes Pod + Service and writes state=down.""" + core_mock = mock_k8s_apis + write_mock = mocker.patch("zetta_utils.session.master._write_session_state") + + from zetta_utils.session import master + + await master._on_shutdown(reason="explicit_terminate") + + core_mock.delete_namespaced_pod.assert_called() + core_mock.delete_namespaced_service.assert_called() + write_mock.assert_called_once_with("down", reason="explicit_terminate") + + +async def test_f_worker_404_writes_down_gracefully( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Worker Pod 404 on status poll -> state=down, no crash.""" + core_mock = mock_k8s_apis + core_mock.read_namespaced_pod_status.side_effect = ApiException(status=404) + aiohttp_mock_session.get.side_effect = aiohttp.ClientConnectionError("gone") + + from zetta_utils.session import master + + write_mock = mocker.patch("zetta_utils.session.master._write_session_state") + master._worker_endpoint = "http://session-worker-test/" + + result = await master._status_logic() + assert result["state"] == 
"down" + write_mock.assert_called_with("down", reason="proxy_unreachable") + + +async def test_g_queue_drain_polls_until_empty( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Drain polls until empty (covers the write-read race).""" + read_mock = mocker.patch( + "zetta_utils.session.master._read_queue_docs", + side_effect=[ + [ + { + "dispatchId": "d1", + "specUrl": "gs://1", + "runId": "r1", + "jobType": "j", + "requiredPreload": "try", + } + ], + [ + { + "dispatchId": "d2", + "specUrl": "gs://2", + "runId": "r2", + "jobType": "j", + "requiredPreload": "try", + } + ], + [], + [], + ], + ) + mocker.patch("zetta_utils.session.master._delete_queue_doc") + mocker.patch("zetta_utils.session.master._update_last_dispatch_at") + aiohttp_mock_session.set_post_response( + status=200, + json_payload={"result": None}, + ) + + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + await master._drain_pre_ready_queue() + + assert read_mock.call_count >= 3 + + +async def test_i_phase_failed_with_nonzero_exit_is_permanent(master_env, mock_k8s_apis, mocker): + """phase=Failed + exitCode=1 -> permanent self-check failure.""" + core_mock = mock_k8s_apis + from kubernetes.client import ( + V1ContainerState, + V1ContainerStateTerminated, + V1ContainerStatus, + V1Pod, + V1PodStatus, + ) + + core_mock.read_namespaced_pod_status.return_value = V1Pod( + status=V1PodStatus( + phase="Failed", + container_statuses=[ + V1ContainerStatus( + name="session-worker", + image="x", + image_id="x", + ready=False, + restart_count=0, + state=V1ContainerState(terminated=V1ContainerStateTerminated(exit_code=1)), + ) + ], + ) + ) + mocker.patch("zetta_utils.session.master._terminate_session", new_callable=mocker.AsyncMock) + from zetta_utils.session import master + + verdict = master._classify_worker_failure() + assert verdict == "permanent" + + +async def test_j_phase_pending_is_transient_under_cap(master_env, mock_k8s_apis, mocker): + 
"""phase=Pending -> transient; keep polling within the 60s budget.""" + core_mock = mock_k8s_apis + from kubernetes.client import V1Pod, V1PodStatus + + core_mock.read_namespaced_pod_status.return_value = V1Pod(status=V1PodStatus(phase="Pending")) + from zetta_utils.session import master + + verdict = master._classify_worker_failure() + assert verdict == "transient" + + +async def test_k_sigterm_during_boot_is_safe( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """SIGTERM during _wait_for_worker_healthz must run _on_shutdown cleanly.""" + core_mock = mock_k8s_apis + mocker.patch( + "zetta_utils.session.master._read_session_row", + return_value={ + "state": "preparing", + "initialPreload": "try", + "config": {"idleTtlSec": 60}, + }, + ) + mocker.patch("zetta_utils.session.master._read_queue_docs", return_value=[]) + mocker.patch("zetta_utils.session.master._write_session_state") + + aiohttp_mock_session.get.side_effect = aiohttp.ClientConnectionError("never ready") + + from zetta_utils.session import master + + boot_task = asyncio.create_task(master._boot()) + await asyncio.sleep(0.05) + + await master._on_shutdown(reason="explicit_terminate") + + assert core_mock.delete_namespaced_pod.call_count == 1 + assert core_mock.delete_namespaced_service.call_count == 1 + + boot_task.cancel() + with pytest.raises((asyncio.CancelledError, aiohttp.ClientConnectionError, SystemExit)): + await boot_task + + +async def test_l_concurrent_dispatches_idle_timer_safe( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Two concurrent dispatches must not leave an orphan idle-timer task.""" + from zetta_utils.session import master + + master._idle_ttl_sec = 60 + master._worker_endpoint = "http://session-worker-test/" + + mocker.patch("zetta_utils.session.master._update_last_dispatch_at") + aiohttp_mock_session.set_post_response( + status=200, + json_payload={"result": "ok"}, + ) + + body = {"specUrl": "gs://x", "runId": "r1", "jobType": "j", 
"requiredPreload": "try"} + results = await asyncio.gather( + master._dispatch_logic(body, authorization="Bearer fake@zetta.ai"), + master._dispatch_logic(body, authorization="Bearer fake@zetta.ai"), + ) + assert all(r["result"] == "ok" for r in results) + + assert master._idle_timer_task is not None + assert not master._idle_timer_task.done() + + +async def test_m_worker_500_surfaces_as_502( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Worker /run_spec/ HTTP 500 -> master returns 502 after one bounded retry.""" + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + + def _raise_500(*_, **__): + request_info = mocker.MagicMock() + history: tuple = () + raise aiohttp.ClientResponseError( + request_info=request_info, + history=history, + status=500, + message="worker_500", + headers=None, + ) + + aiohttp_mock_session.set_post_response(status=500) + aiohttp_mock_session.post.return_value.__aenter__.return_value.raise_for_status = ( + mocker.MagicMock(side_effect=_raise_500) + ) + + with pytest.raises(web.HTTPBadGateway) as exc_info: + await master._forward_dispatch_to_worker( + { + "specUrl": "gs://x", + "runId": "r1", + "jobType": "j", + "requiredPreload": "try", + } + ) + assert exc_info.value.status == 502 + assert exc_info.value.reason == "worker_run_spec_error" + assert aiohttp_mock_session.post.call_count == 2 + + +# ---- Firestore helpers -------------------------------------------------- + + +async def test_read_session_row_returns_dict(mocker): + """_read_session_row returns the snapshot dict via the sessions chain.""" + from zetta_utils.session import master + + db = mocker.MagicMock() + snapshot = mocker.MagicMock() + snapshot.to_dict.return_value = {"state": "ready"} + db.collection.return_value.document.return_value.get.return_value = snapshot + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + 
+ result = master._read_session_row("sid-1") + + assert result == {"state": "ready"} + db.collection.assert_called_once_with("sessions") + db.collection.return_value.document.assert_called_once_with("sid-1") + db.collection.return_value.document.return_value.get.assert_called_once_with() + + +async def test_read_session_row_none_to_dict_returns_empty(mocker): + """_read_session_row returns {} when to_dict() is None.""" + from zetta_utils.session import master + + db = mocker.MagicMock() + snapshot = mocker.MagicMock() + snapshot.to_dict.return_value = None + db.collection.return_value.document.return_value.get.return_value = snapshot + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + assert master._read_session_row("sid-1") == {} + + +async def test_write_session_state_plain(master_env, mocker): + """_write_session_state merges {'state': s} for a non-down state.""" + from zetta_utils.session import master + + db = mocker.MagicMock() + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + master._write_session_state("ready") + + doc = db.collection.return_value.document.return_value + doc.set.assert_called_once_with({"state": "ready"}, merge=True) + + +async def test_write_session_state_down_with_reason(master_env, mocker): + """state=down with a reason stamps terminatedAt and terminationReason.""" + from datetime import datetime + + from zetta_utils.session import master + + db = mocker.MagicMock() + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + master._write_session_state("down", reason="boom") + + doc = db.collection.return_value.document.return_value + payload = doc.set.call_args.args[0] + assert doc.set.call_args.kwargs == {"merge": True} + assert payload["state"] == "down" + assert isinstance(payload["terminatedAt"], datetime) + assert payload["terminationReason"] == "boom" + + +async def test_write_session_state_down_no_reason(master_env, mocker): + """state=down 
with reason=None omits the terminationReason key.""" + from zetta_utils.session import master + + db = mocker.MagicMock() + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + master._write_session_state("down") + + doc = db.collection.return_value.document.return_value + payload = doc.set.call_args.args[0] + assert payload["state"] == "down" + assert "terminatedAt" in payload + assert "terminationReason" not in payload + + +async def test_read_queue_docs_orders_and_stamps_dispatch_id(mocker): + """_read_queue_docs orders by enqueuedAt asc and stamps dispatchId.""" + from google.cloud import firestore + + from zetta_utils.session import master + + db = mocker.MagicMock() + snap_a = mocker.MagicMock() + snap_a.id = "d-a" + snap_a.to_dict.return_value = {"specUrl": "gs://a"} + snap_b = mocker.MagicMock() + snap_b.id = "d-b" + snap_b.to_dict.return_value = None + query = db.collection.return_value.document.return_value.collection.return_value.order_by + query.return_value.stream.return_value = [snap_a, snap_b] + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + docs = master._read_queue_docs("sid-1") + + query.assert_called_once_with("enqueuedAt", direction=firestore.Query.ASCENDING) + assert docs[0]["dispatchId"] == "d-a" + assert docs[0]["specUrl"] == "gs://a" + assert docs[1]["dispatchId"] == "d-b" + + +async def test_delete_queue_doc_chain(mocker): + """_delete_queue_doc deletes the queue document by dispatch id.""" + from zetta_utils.session import master + + db = mocker.MagicMock() + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + master._delete_queue_doc("sid-1", "d-1") + + sessions = db.collection.return_value + sessions.document.assert_called_once_with("sid-1") + queue = sessions.document.return_value.collection.return_value + queue.document.assert_called_once_with("d-1") + queue.document.return_value.delete.assert_called_once_with() + + +async def 
test_update_last_dispatch_at(master_env, mocker): + """_update_last_dispatch_at merges a server-timestamp lastDispatchAt.""" + from google.cloud import firestore + + from zetta_utils.session import master + + db = mocker.MagicMock() + mocker.patch("zetta_utils.session.master._get_sessions_db", return_value=db) + + master._update_last_dispatch_at() + + doc = db.collection.return_value.document.return_value + doc.set.assert_called_once_with({"lastDispatchAt": firestore.SERVER_TIMESTAMP}, merge=True) + + +# ---- main() ------------------------------------------------------------- + + +async def test_main_runs_full_lifecycle(master_env, mocker): + """main() installs the handler, boots, serves, then shuts down once.""" + from zetta_utils.session import master + + sigterm = mocker.patch("zetta_utils.session.master._install_sigterm_handler") + boot = mocker.patch("zetta_utils.session.master._boot", new_callable=mocker.AsyncMock) + serve = mocker.patch( + "zetta_utils.session.master._serve_forever", new_callable=mocker.AsyncMock + ) + on_shutdown = mocker.patch( + "zetta_utils.session.master._on_shutdown", new_callable=mocker.AsyncMock + ) + + await master.main() + + sigterm.assert_called_once_with() + boot.assert_awaited_once_with() + serve.assert_awaited_once_with() + on_shutdown.assert_awaited_once_with(reason="explicit_terminate") + + +async def test_main_shuts_down_when_boot_raises(master_env, mocker): + """main() runs _on_shutdown in finally even if _boot raises.""" + from zetta_utils.session import master + + mocker.patch("zetta_utils.session.master._install_sigterm_handler") + mocker.patch( + "zetta_utils.session.master._boot", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("boom"), + ) + serve = mocker.patch( + "zetta_utils.session.master._serve_forever", new_callable=mocker.AsyncMock + ) + on_shutdown = mocker.patch( + "zetta_utils.session.master._on_shutdown", new_callable=mocker.AsyncMock + ) + + with pytest.raises(RuntimeError): + await 
master.main() + + serve.assert_not_called() + on_shutdown.assert_awaited_once_with(reason="explicit_terminate") + + +# ---- Boot / classify ---------------------------------------------------- + + +async def test_boot_unexpected_state_exits(master_env, mocker): + """_boot exits with code 2 when the session is not 'preparing'.""" + from zetta_utils.session import master + + mocker.patch( + "zetta_utils.session.master._read_session_row", + return_value={"state": "ready"}, + ) + + with pytest.raises(SystemExit) as exc_info: + await master._boot() + assert exc_info.value.code == 2 + + +async def test_classify_worker_failure_404_permanent(master_env, mock_k8s_apis): + """A 404 on read_namespaced_pod_status classifies as permanent.""" + mock_k8s_apis.read_namespaced_pod_status.side_effect = ApiException(status=404) + + from zetta_utils.session import master + + assert master._classify_worker_failure() == "permanent" + + +async def test_classify_worker_failure_500_reraises(master_env, mock_k8s_apis): + """A non-404 ApiException re-raises out of _classify_worker_failure.""" + mock_k8s_apis.read_namespaced_pod_status.side_effect = ApiException(status=500) + + from zetta_utils.session import master + + with pytest.raises(ApiException): + master._classify_worker_failure() + + +async def test_classify_worker_failure_succeeded_permanent(master_env, mock_k8s_apis): + """phase=Succeeded classifies as permanent (worker is gone).""" + from kubernetes.client import V1Pod, V1PodStatus + + mock_k8s_apis.read_namespaced_pod_status.return_value = V1Pod( + status=V1PodStatus(phase="Succeeded") + ) + + from zetta_utils.session import master + + assert master._classify_worker_failure() == "permanent" + + +async def test_classify_worker_failure_unknown_permanent(master_env, mock_k8s_apis): + """phase=None classifies as permanent.""" + from kubernetes.client import V1Pod, V1PodStatus + + mock_k8s_apis.read_namespaced_pod_status.return_value = V1Pod(status=V1PodStatus(phase=None)) + + from 
zetta_utils.session import master + + assert master._classify_worker_failure() == "permanent" + + +async def test_wait_for_worker_healthz_timeout( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """An expired deadline terminates with worker_healthz_timeout.""" + mocker.patch("zetta_utils.session.master.WORKER_HEALTHZ_TIMEOUT_S", -1) + terminate = mocker.patch( + "zetta_utils.session.master._terminate_session", + new_callable=mocker.AsyncMock, + side_effect=SystemExit, + ) + + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + + with pytest.raises(SystemExit): + await master._wait_for_worker_healthz() + terminate.assert_awaited_with("worker_healthz_timeout") + + +async def test_wait_for_worker_healthz_permanent_failure( + master_env, mock_k8s_apis, mocker, aiohttp_mock_session +): + """Repeated refusals + permanent verdict terminates the boot self-check.""" + aiohttp_mock_session.get.side_effect = aiohttp.ClientConnectionError("dead") + mocker.patch( + "zetta_utils.session.master._classify_worker_failure", + return_value="permanent", + ) + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + terminate = mocker.patch( + "zetta_utils.session.master._terminate_session", + new_callable=mocker.AsyncMock, + side_effect=SystemExit, + ) + + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + + with pytest.raises(SystemExit): + await master._wait_for_worker_healthz() + terminate.assert_awaited_with("worker_permanent_failure") + + +# ---- Terminate / shutdown ---------------------------------------------- + + +async def test_terminate_session_nonidle_exit_code_1(master_env, mocker): + """_terminate_session with a non-idle reason exits 1 and shuts down.""" + from zetta_utils.session import master + + on_shutdown = mocker.patch( + "zetta_utils.session.master._on_shutdown", new_callable=mocker.AsyncMock + ) + + with pytest.raises(SystemExit) as exc_info: + 
await master._terminate_session("boom") + assert exc_info.value.code == 1 + on_shutdown.assert_awaited_once_with(reason="boom") + + +async def test_terminate_session_idle_exit_code_0(master_env, mocker): + """_terminate_session with reason=idle_timer exits 0.""" + from zetta_utils.session import master + + mocker.patch("zetta_utils.session.master._on_shutdown", new_callable=mocker.AsyncMock) + + with pytest.raises(SystemExit) as exc_info: + await master._terminate_session("idle_timer") + assert exc_info.value.code == 0 + + +async def test_on_shutdown_idempotent(master_env, mock_k8s_apis, mocker): + """_on_shutdown returns early when shutdown already started.""" + from zetta_utils.session import master + + master._shutdown_started = True + write_mock = mocker.patch("zetta_utils.session.master._write_session_state") + + await master._on_shutdown(reason="x") + + write_mock.assert_not_called() + mock_k8s_apis.delete_namespaced_pod.assert_not_called() + + +# ---- Handlers / build_app ---------------------------------------------- + + +async def test_dispatch_handler_assembles_payload(master_env, mocker): + """dispatch handler forwards the assembled payload and authorization.""" + from zetta_utils.session import master + + request = mocker.MagicMock() + request.json = mocker.AsyncMock( + return_value={ + "specUrl": "gs://x", + "runId": "r1", + "jobType": "j", + "requiredPreload": "x", + } + ) + request.headers = {"Authorization": "Bearer t"} + logic = mocker.patch( + "zetta_utils.session.master._dispatch_logic", + new_callable=mocker.AsyncMock, + return_value={"ok": 1}, + ) + + resp = await master.dispatch(request) + + logic.assert_awaited_once_with( + { + "specUrl": "gs://x", + "runId": "r1", + "jobType": "j", + "requiredPreload": "x", + }, + authorization="Bearer t", + ) + assert isinstance(resp.body, bytes) + assert json.loads(resp.body) == {"ok": 1} + + +async def test_status_handler_returns_logic_result(master_env, mocker): + """status handler returns the 
_status_logic result as JSON.""" + from zetta_utils.session import master + + mocker.patch( + "zetta_utils.session.master._status_logic", + new_callable=mocker.AsyncMock, + return_value={"state": "ready"}, + ) + + resp = await master.status(mocker.MagicMock()) + assert isinstance(resp.body, bytes) + assert json.loads(resp.body) == {"state": "ready"} + + +async def test_terminate_handler_returns_logic_result(master_env, mocker): + """terminate handler returns the _terminate_logic result as JSON.""" + from zetta_utils.session import master + + mocker.patch( + "zetta_utils.session.master._terminate_logic", + new_callable=mocker.AsyncMock, + return_value={"state": "down"}, + ) + + resp = await master.terminate(mocker.MagicMock()) + assert isinstance(resp.body, bytes) + assert json.loads(resp.body) == {"state": "down"} + + +async def test_build_app_registers_routes(master_env): + """_build_app registers POST /dispatch, GET /status, POST /terminate.""" + from zetta_utils.session import master + + app = master._build_app() + registered = { + (route.method, route.resource.canonical) + for route in app.router.routes() + if route.resource is not None + } + assert ("POST", "/dispatch") in registered + assert ("GET", "/status") in registered + assert ("POST", "/terminate") in registered + + +# ---- Endpoint logic ----------------------------------------------------- + + +async def test_status_logic_healthy(master_env, aiohttp_mock_session): + """_status_logic returns ready on a 200 healthz response.""" + aiohttp_mock_session.set_get_response(status=200) + + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + assert await master._status_logic() == {"state": "ready"} + + +async def test_status_logic_unhealthy(master_env, aiohttp_mock_session): + """_status_logic returns down on a 503 healthz response.""" + aiohttp_mock_session.set_get_response(status=503) + + from zetta_utils.session import master + + master._worker_endpoint = 
"http://session-worker-test/" + assert await master._status_logic() == {"state": "down"} + + +async def test_terminate_logic_shuts_down_and_stops(master_env, mocker): + """_terminate_logic shuts down, requests serve-stop, returns down.""" + from zetta_utils.session import master + + on_shutdown = mocker.patch( + "zetta_utils.session.master._on_shutdown", new_callable=mocker.AsyncMock + ) + stop = mocker.patch("zetta_utils.session.master._request_serve_stop") + + result = await master._terminate_logic() + + assert result == {"state": "down"} + on_shutdown.assert_awaited_once_with(reason="explicit_terminate") + stop.assert_called_once_with() + + +# ---- Forwarding to worker ---------------------------------------------- + + +async def test_forward_dispatch_bare_token_gets_bearer_prefix( + master_env, mocker, aiohttp_mock_session +): + """A bare user token is prefixed with 'Bearer ' on the worker request.""" + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + aiohttp_mock_session.set_post_response(status=200, json_payload={"result": 1}) + mocker.patch("zetta_utils.session.master._update_last_dispatch_at") + + await master._forward_dispatch_to_worker( + {"specUrl": "gs://x", "runId": "r1", "jobType": "j", "requiredPreload": "try"}, + user_token="rawtoken", + ) + + headers = aiohttp_mock_session.post.call_args.kwargs["headers"] + assert headers["Authorization"] == "Bearer rawtoken" + + +async def test_forward_dispatch_4xx_no_retry(master_env, mocker, aiohttp_mock_session): + """A worker 4xx surfaces as a 502 client-error gateway with no retry.""" + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + + def _raise_404(*_, **__): + raise aiohttp.ClientResponseError( + request_info=mocker.MagicMock(), + history=(), + status=404, + message="worker_404", + headers=None, + ) + + aiohttp_mock_session.set_post_response(status=404) + 
aiohttp_mock_session.post.return_value.__aenter__.return_value.raise_for_status = ( + mocker.MagicMock(side_effect=_raise_404) + ) + + with pytest.raises(web.HTTPBadGateway) as exc_info: + await master._forward_dispatch_to_worker( + {"specUrl": "gs://x", "runId": "r1", "jobType": "j", "requiredPreload": "try"} + ) + assert exc_info.value.status == 502 + assert exc_info.value.reason == "worker_run_spec_client_error" + assert aiohttp_mock_session.post.call_count == 1 + + +async def test_forward_dispatch_conn_refused_permanent(master_env, mocker, aiohttp_mock_session): + """Connection refused + permanent verdict terminates then raises 502.""" + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + aiohttp_mock_session.post.side_effect = aiohttp.ClientConnectionError("dead") + mocker.patch( + "zetta_utils.session.master._classify_worker_failure", + return_value="permanent", + ) + terminate = mocker.patch( + "zetta_utils.session.master._terminate_session", + new_callable=mocker.AsyncMock, + ) + + with pytest.raises(web.HTTPBadGateway) as exc_info: + await master._forward_dispatch_to_worker( + {"specUrl": "gs://x", "runId": "r1", "jobType": "j", "requiredPreload": "try"} + ) + assert exc_info.value.reason == "worker unreachable" + terminate.assert_awaited_with("worker_permanent_failure") + + +async def test_forward_dispatch_conn_refused_transient(master_env, mocker, aiohttp_mock_session): + """Connection refused + transient verdict raises 502 without terminating.""" + from zetta_utils.session import master + + master._worker_endpoint = "http://session-worker-test/" + aiohttp_mock_session.post.side_effect = aiohttp.ClientConnectionError("dead") + mocker.patch( + "zetta_utils.session.master._classify_worker_failure", + return_value="transient", + ) + terminate = mocker.patch( + "zetta_utils.session.master._terminate_session", + new_callable=mocker.AsyncMock, + ) + + with pytest.raises(web.HTTPBadGateway) as exc_info: + await 
master._forward_dispatch_to_worker( + {"specUrl": "gs://x", "runId": "r1", "jobType": "j", "requiredPreload": "try"} + ) + assert exc_info.value.reason == "worker unreachable" + terminate.assert_not_called() + assert aiohttp_mock_session.post.call_count == 1 + + +# ---- Idle timer --------------------------------------------------------- + + +async def test_start_idle_timer_is_idempotent(master_env, mocker): + """A second _start_idle_timer call does not replace the live task.""" + from zetta_utils.session import master + + master._idle_ttl_sec = 60 + master._start_idle_timer() + task1 = master._idle_timer_task + master._start_idle_timer() + try: + assert master._idle_timer_task is task1 + finally: + master._cancel_idle_timer() diff --git a/tests/unit/session/test_pod_create.py b/tests/unit/session/test_pod_create.py new file mode 100644 index 000000000..664680330 --- /dev/null +++ b/tests/unit/session/test_pod_create.py @@ -0,0 +1,55 @@ +# pylint: disable=protected-access,unused-argument,import-outside-toplevel +import pytest + + +def test_create_namespaced_pod_passes_through(mocker): + from zetta_utils.cloud_management.resource_allocation.k8s import pod + + core_mock = mocker.patch( + "zetta_utils.cloud_management.resource_allocation.k8s.pod.k8s_client.CoreV1Api" + ).return_value + body = {"metadata": {"name": "p", "namespace": "sessions"}} + pod.create_namespaced_pod(namespace="sessions", body=body) + core_mock.create_namespaced_pod.assert_called_once_with( + namespace="sessions", + body=body, + ) + + +def test_worker_template_substitution(master_env): + """The rendered Pod body must carry the locked invariants.""" + from zetta_utils.session import master + + body = master._render_worker_template(initial_preload="try") + assert body["metadata"]["name"] == "session-worker-test-uuid-001" + assert body["metadata"]["namespace"] == "sessions" + or_ref = body["metadata"]["ownerReferences"][0] + assert or_ref["uid"] == "pod-uid-xyz" + assert or_ref["name"] == 
"session-master-test-uuid-001-abcd" + assert or_ref["controller"] is True + + +def test_delete_swallows_404(mocker): + from kubernetes.client.exceptions import ApiException + + from zetta_utils.cloud_management.resource_allocation.k8s import pod + + core_mock = mocker.patch( + "zetta_utils.cloud_management.resource_allocation.k8s.pod.k8s_client.CoreV1Api" + ).return_value + core_mock.delete_namespaced_pod.side_effect = ApiException(status=404) + # Must NOT raise. + pod.delete_namespaced_pod(name="missing", namespace="sessions") + + +def test_delete_propagates_500(mocker): + from kubernetes.client.exceptions import ApiException + + from zetta_utils.cloud_management.resource_allocation.k8s import pod + + core_mock = mocker.patch( + "zetta_utils.cloud_management.resource_allocation.k8s.pod.k8s_client.CoreV1Api" + ).return_value + core_mock.delete_namespaced_pod.side_effect = ApiException(status=500) + with pytest.raises(ApiException): + pod.delete_namespaced_pod(name="p", namespace="sessions") diff --git a/tests/unit/session/test_reconcile.py b/tests/unit/session/test_reconcile.py new file mode 100644 index 000000000..029d68396 --- /dev/null +++ b/tests/unit/session/test_reconcile.py @@ -0,0 +1,275 @@ +# pylint: disable=redefined-outer-name,unused-argument,import-outside-toplevel,protected-access +from datetime import datetime, timedelta, timezone + +import pytest +from google.cloud.firestore_v1.base_query import FieldFilter +from kubernetes.client.exceptions import ApiException + + +@pytest.fixture +def mock_apis(mocker): + batch = mocker.patch("zetta_utils.session.reconcile.k8s_client.BatchV1Api").return_value + pod_mod = mocker.patch("zetta_utils.session.reconcile.pod", autospec=True) + svc_mod = mocker.patch("zetta_utils.session.reconcile.service", autospec=True) + return batch, pod_mod, svc_mod + + +def test_healthy_session_untouched(mock_apis, mocker): + batch, _, _ = mock_apis + fresh_ts = datetime.now(timezone.utc) - timedelta(minutes=5) + mocker.patch( + 
"zetta_utils.session.reconcile._query_non_down_sessions", + return_value=[ + { + "sessionId": "fresh-1", + "state": "ready", + "lastDispatchAt": fresh_ts, + "createdAt": fresh_ts, + } + ], + ) + batch.read_namespaced_job.return_value = mocker.Mock() + write_mock = mocker.patch("zetta_utils.session.reconcile._write_session_state") + + from zetta_utils.session import reconcile + + summary = reconcile.run_reconcile() + assert summary["reconciledCount"] == 0 + write_mock.assert_not_called() + + +def test_stale_but_alive_deletes_master_job(mock_apis, mocker): + """Stale-by-time with the Job still present -> delete the master Job + (cascade), NOT the worker resources directly.""" + batch, pod_mod, svc_mod = mock_apis + old_ts = datetime.now(timezone.utc) - timedelta(hours=30) + mocker.patch( + "zetta_utils.session.reconcile._query_non_down_sessions", + return_value=[ + {"sessionId": "old-1", "state": "ready", "lastDispatchAt": old_ts, "createdAt": old_ts} + ], + ) + batch.read_namespaced_job.return_value = mocker.Mock() # Job exists + write_mock = mocker.patch("zetta_utils.session.reconcile._write_session_state") + + from zetta_utils.session import reconcile + + summary = reconcile.run_reconcile() + assert summary["reconciledCount"] == 1 + assert summary["staleByTime"] == 1 + assert write_mock.call_args.args[0] == "old-1" + assert write_mock.call_args.kwargs["reason"] == "reconcile_stale_24h" + batch.delete_namespaced_job.assert_called_once() + assert batch.delete_namespaced_job.call_args.kwargs["name"] == "session-master-old-1" + assert batch.delete_namespaced_job.call_args.kwargs["propagation_policy"] == "Background" + # Cascade-GC handles the rest; reconcile does not touch worker resources. 
+ pod_mod.delete_namespaced_pod.assert_not_called() + svc_mod.delete_namespaced_service.assert_not_called() + + +def test_master_missing_deletes_orphans(mock_apis, mocker): + """Master Job 404 -> manually delete the orphan worker Pod, worker Service, + and master Service; do not attempt a Job delete.""" + batch, pod_mod, svc_mod = mock_apis + fresh_ts = datetime.now(timezone.utc) - timedelta(minutes=5) + mocker.patch( + "zetta_utils.session.reconcile._query_non_down_sessions", + return_value=[ + { + "sessionId": "orphan-1", + "state": "ready", + "lastDispatchAt": fresh_ts, + "createdAt": fresh_ts, + } + ], + ) + batch.read_namespaced_job.side_effect = ApiException(status=404) + write_mock = mocker.patch("zetta_utils.session.reconcile._write_session_state") + + from zetta_utils.session import reconcile + + summary = reconcile.run_reconcile() + assert summary["reconciledCount"] == 1 + assert summary["staleByMissingMaster"] == 1 + write_mock.assert_called_once_with("orphan-1", "down", reason="reconcile_master_missing") + pod_mod.delete_namespaced_pod.assert_called_once() + assert svc_mod.delete_namespaced_service.call_count == 2 # worker + master + batch.delete_namespaced_job.assert_not_called() + + +def test_cleanup_failure_counts_but_does_not_crash(mock_apis, mocker): + """A non-404 error during cleanup increments cleanupErrors, never crashes.""" + batch, _, _ = mock_apis + old_ts = datetime.now(timezone.utc) - timedelta(hours=30) + mocker.patch( + "zetta_utils.session.reconcile._query_non_down_sessions", + return_value=[ + {"sessionId": "err-1", "state": "ready", "lastDispatchAt": old_ts, "createdAt": old_ts} + ], + ) + batch.read_namespaced_job.return_value = mocker.Mock() # stale-but-alive + mocker.patch("zetta_utils.session.reconcile._write_session_state") + batch.delete_namespaced_job.side_effect = ApiException(status=500) + + from zetta_utils.session import reconcile + + summary = reconcile.run_reconcile() + assert summary["reconciledCount"] == 1 + assert 
summary["cleanupErrors"] == 1 + + +def test_no_last_dispatch_falls_back_to_created_at(mock_apis, mocker): + batch, _, _ = mock_apis + old_ts = datetime.now(timezone.utc) - timedelta(hours=30) + mocker.patch( + "zetta_utils.session.reconcile._query_non_down_sessions", + return_value=[ + { + "sessionId": "never-dispatched", + "state": "preparing", + "lastDispatchAt": None, + "createdAt": old_ts, + } + ], + ) + batch.read_namespaced_job.return_value = mocker.Mock() + write_mock = mocker.patch("zetta_utils.session.reconcile._write_session_state") + + from zetta_utils.session import reconcile + + summary = reconcile.run_reconcile() + assert summary["staleByTime"] == 1 + write_mock.assert_called_once() + + +def test_loki_line_emitted(mock_apis, mocker, caplog): + mocker.patch("zetta_utils.session.reconcile._query_non_down_sessions", return_value=[]) + from zetta_utils.session import reconcile + + with caplog.at_level("INFO", logger="zetta_utils.session.reconcile"): + reconcile.run_reconcile() + assert any("sessions.reconcile.scan_complete" in r.message for r in caplog.records) + + +def test_is_stale_by_time_both_none_returns_false(mocker): + """When both lastDispatchAt and createdAt are None, _is_stale_by_time -> False.""" + from zetta_utils.session import reconcile + + result = reconcile._is_stale_by_time( + {"lastDispatchAt": None, "createdAt": None}, datetime.now(timezone.utc) + ) + assert result is False + + +def test_is_stale_by_time_old_created_at_returns_true(mocker): + """When lastDispatchAt is None but createdAt is old, _is_stale_by_time -> True.""" + from zetta_utils.session import reconcile + + old_ts = datetime.now(timezone.utc) - timedelta(hours=30) + result = reconcile._is_stale_by_time( + {"lastDispatchAt": None, "createdAt": old_ts}, datetime.now(timezone.utc) + ) + assert result is True + + +def test_is_master_missing_410_returns_true(mock_apis, mocker): + """ApiException(status=410) from read_namespaced_job -> _is_master_missing True.""" + batch, _, _ = 
mock_apis + batch.read_namespaced_job.side_effect = ApiException(status=410) + + from zetta_utils.session import reconcile + + assert reconcile._is_master_missing(batch, "sid") is True + + +def test_is_master_missing_500_reraises(mock_apis, mocker): + """ApiException(status=500) from read_namespaced_job is re-raised.""" + batch, _, _ = mock_apis + batch.read_namespaced_job.side_effect = ApiException(status=500) + + from zetta_utils.session import reconcile + + with pytest.raises(ApiException): + reconcile._is_master_missing(batch, "sid") + + +def test_query_non_down_sessions_yields_rows(mocker): + """_query_non_down_sessions yields rows with sessionId from each snapshot.""" + from zetta_utils.session import reconcile + + snap_a = mocker.MagicMock() + snap_a.id = "sess-aaa" + snap_a.to_dict.return_value = {"state": "ready"} + + snap_b = mocker.MagicMock() + snap_b.id = "sess-bbb" + snap_b.to_dict.return_value = None + + mock_query = mocker.MagicMock() + mock_query.stream.return_value = [snap_a, snap_b] + + mock_collection = mocker.MagicMock() + mock_where = mocker.MagicMock() + mock_where.return_value = mock_query + mock_collection.where = mock_where + + mock_db = mocker.MagicMock() + mock_db.collection.return_value = mock_collection + + mocker.patch("zetta_utils.session.reconcile._get_sessions_db", return_value=mock_db) + + rows = list(reconcile._query_non_down_sessions()) + + assert len(rows) == 2 + assert rows[0]["sessionId"] == "sess-aaa" + assert rows[0]["state"] == "ready" + assert rows[1]["sessionId"] == "sess-bbb" + + where_call = mock_collection.where.call_args + passed_filter = where_call.kwargs.get("filter") or where_call.args[0] + assert isinstance(passed_filter, FieldFilter) + assert passed_filter.field_path == "state" + assert passed_filter.value == "down" + + +def test_write_session_state_down_includes_timestamps(mocker): + """state='down' payload includes terminatedAt and terminationReason.""" + from google.cloud import firestore + + from 
zetta_utils.session import reconcile + + mock_doc = mocker.MagicMock() + mock_collection = mocker.MagicMock() + mock_collection.document.return_value = mock_doc + mock_db = mocker.MagicMock() + mock_db.collection.return_value = mock_collection + mocker.patch("zetta_utils.session.reconcile._get_sessions_db", return_value=mock_db) + + reconcile._write_session_state("sess-123", "down", reason="reconcile_stale_24h") + + mock_collection.document.assert_called_once_with("sess-123") + set_call = mock_doc.set.call_args + payload = set_call.args[0] + assert payload["state"] == "down" + assert payload["terminationReason"] == "reconcile_stale_24h" + assert payload["terminatedAt"] is firestore.SERVER_TIMESTAMP + assert set_call.kwargs.get("merge") is True + + +def test_write_session_state_non_down_minimal_payload(mocker): + """state != 'down' payload contains only the state key.""" + from zetta_utils.session import reconcile + + mock_doc = mocker.MagicMock() + mock_collection = mocker.MagicMock() + mock_collection.document.return_value = mock_doc + mock_db = mocker.MagicMock() + mock_db.collection.return_value = mock_collection + mocker.patch("zetta_utils.session.reconcile._get_sessions_db", return_value=mock_db) + + reconcile._write_session_state("sess-456", "ready", reason="unused") + + set_call = mock_doc.set.call_args + payload = set_call.args[0] + assert payload == {"state": "ready"} + assert set_call.kwargs.get("merge") is True diff --git a/tests/unit/session/test_worker_app.py b/tests/unit/session/test_worker_app.py new file mode 100644 index 000000000..d805ab1a1 --- /dev/null +++ b/tests/unit/session/test_worker_app.py @@ -0,0 +1,49 @@ +# pylint: disable=import-error,wrong-import-position,import-outside-toplevel +import pytest + +pytest.importorskip("fastapi") + +from starlette.routing import Mount + + +def test_worker_app_exposes_healthz(): + from web_api.app import worker + + paths = {getattr(route, "path", None) for route in worker.app.routes} + assert "/healthz" in 
paths + + +def test_worker_app_exposes_run_spec(): + from web_api.app import worker + + paths = {getattr(route, "path", None) for route in worker.app.routes} + assert "/run_spec/" in paths + + +def test_worker_app_has_no_portal_routers(): + from web_api.app import worker + + mount_paths = {route.path for route in worker.app.routes if isinstance(route, Mount)} + portal_mounts = { + "/alignment", + "/annotations", + "/collections", + "/layer_groups", + "/layers", + "/painting", + "/precomputed", + "/segmentation", + "/sessions", + "/tasks", + } + assert portal_mounts.isdisjoint(mount_paths) + + +def test_worker_app_has_auth_middleware(): + from web_api.app import worker + from web_api.app.auth import check_authorized_user + + middleware_funcs = [ + getattr(m, "kwargs", {}).get("dispatch") for m in worker.app.user_middleware + ] + assert check_authorized_user in middleware_funcs diff --git a/web_api/app/auth.py b/web_api/app/auth.py new file mode 100644 index 000000000..06b9434ac --- /dev/null +++ b/web_api/app/auth.py @@ -0,0 +1,38 @@ +# pylint: disable=all # type: ignore +import os + +from fastapi import Request, Response +from google.auth.transport import requests + +# from google.cloud import iap_v1 +# from google.iam.v1 import iam_policy_pb2 +from google.oauth2 import id_token + + +async def check_authorized_user(request: Request, call_next): + if request.method != "OPTIONS" and request.url.path != "/healthz": + try: + token = request.headers["authorization"].split()[-1] + except (KeyError, IndexError): + return Response(content="Missing auth token.", status_code=401) + + client_id = os.environ["OAUTH_CLIENT_ID"] + try: + idinfo = id_token.verify_oauth2_token(token, requests.Request(), client_id) + except Exception as exc: # pylint: disable=broad-exception-caught + return Response(content=str(exc), status_code=401) + + if not idinfo["email"].endswith("@zetta.ai"): + return Response(content="User not authorized.", status_code=401) + # user = 
f"user:{idinfo['email']}" + # client = iap_v1.IdentityAwareProxyAdminServiceClient() + + # iap_resource = os.environ["IAP_RESOURCE"] + # request = iam_policy_pb2.GetIamPolicyRequest(resource=iap_resource) + # policy = client.get_iam_policy(request=request) + # members = set(policy.bindings[0].members) + # if user not in members: + # return Response(content="User not authorized.", status_code=401) + + response = await call_next(request) + return response diff --git a/web_api/app/main.py b/web_api/app/main.py index 4f1c4bd16..032c0a84f 100644 --- a/web_api/app/main.py +++ b/web_api/app/main.py @@ -1,23 +1,19 @@ # pylint: disable=all # type: ignore -import os import sys -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from google.auth.transport import requests - -# from google.cloud import iap_v1 -# from google.iam.v1 import iam_policy_pb2 -from google.oauth2 import id_token from .alignment import api as alignment_api from .annotations import api as annotations_api +from .auth import check_authorized_user from .collections import api as collections_api from .layer_groups import api as layer_groups_api from .layers import api as layers_api from .painting import api as painting_api from .precomputed_annotations import api as precomputed_annotations_api from .segmentation import api as segmentation_api +from .session import api as session_api from .tasks import api as tasks_api app = FastAPI() @@ -38,37 +34,11 @@ app.mount("/painting", painting_api) app.mount("/precomputed", precomputed_annotations_api) app.mount("/segmentation", segmentation_api) +app.mount("/sessions", session_api) app.mount("/tasks", tasks_api) -@app.middleware("http") -async def check_authorized_user(request: Request, call_next): - if request.method != "OPTIONS" and request.url.path != "/healthz": - try: - token = request.headers["authorization"].split()[-1] - except (KeyError, IndexError): - return 
Response(content="Missing auth token.", status_code=401) - - client_id = os.environ["OAUTH_CLIENT_ID"] - try: - idinfo = id_token.verify_oauth2_token(token, requests.Request(), client_id) - except Exception as exc: # pylint: disable=broad-exception-caught - return Response(content=str(exc), status_code=401) - - if not idinfo["email"].endswith("@zetta.ai"): - return Response(content="User not authorized.", status_code=401) - # user = f"user:{idinfo['email']}" - # client = iap_v1.IdentityAwareProxyAdminServiceClient() - - # iap_resource = os.environ["IAP_RESOURCE"] - # request = iam_policy_pb2.GetIamPolicyRequest(resource=iap_resource) - # policy = client.get_iam_policy(request=request) - # members = set(policy.bindings[0].members) - # if user not in members: - # return Response(content="User not authorized.", status_code=401) - - response = await call_next(request) - return response +app.middleware("http")(check_authorized_user) @app.get("/") diff --git a/web_api/app/session.py b/web_api/app/session.py new file mode 100644 index 000000000..38f3eee12 --- /dev/null +++ b/web_api/app/session.py @@ -0,0 +1,378 @@ +# pylint: disable=all # type: ignore +""" +Session orchestration router, mounted under ``/sessions`` in the main API. + +Endpoints: + POST / + POST /{session_id}/dispatch + GET /{session_id}/status + DELETE /{session_id} + +Auth is enforced by the main API's OAuth middleware. State is 100% +Firestore-backed (main DB). HTTP proxying to the master uses aiohttp. 
+""" + +import asyncio +import logging +import os +import uuid +from pathlib import Path +from typing import Literal + +import aiohttp +import yaml +from fastapi import FastAPI, HTTPException +from google.cloud import firestore +from google.cloud.firestore_v1.base_query import FieldFilter +from pydantic import BaseModel + +from kubernetes import client as k8s_client +from zetta_utils.cloud_management.resource_allocation.k8s import service +from zetta_utils.session import _get_sessions_db + +log = logging.getLogger(__name__) + +WORKLOAD_NAMESPACE = os.environ.get("WORKLOAD_NAMESPACE", "sessions") + + +def _sessions_image_tag() -> str: + return os.environ["SESSIONS_IMAGE_TAG"] + + +def _master_template_path() -> str: + return os.environ["SESSION_MASTER_TEMPLATE_PATH"] + + +def _master_service_template_path() -> str: + return os.environ["SESSION_MASTER_SERVICE_TEMPLATE_PATH"] + + +class CreateSessionBody(BaseModel): + ownerType: str + ownerId: str + initialPreload: Literal["none", "try", "inference", "training", "all"] = "try" + jobType: str | None = None + config: dict | None = None + + +class CreateSessionResponse(BaseModel): + sessionId: str + controlEndpoint: str + workerEndpoint: str + state: Literal["preparing", "ready"] + + +class DispatchBody(BaseModel): + specUrl: str + runId: str + jobType: str + requiredPreload: Literal["none", "try", "inference", "training", "all"] = "try" + + +# ---- Firestore helpers -------------------------------------------------- + + +def _read_session_row(session_id: str) -> dict | None: + """Read the ``sessions/`` document; ``None`` if absent.""" + snap = _get_sessions_db().collection("sessions").document(session_id).get() + return snap.to_dict() if snap.exists else None # type: ignore[union-attr] + + +def _write_session_state(session_id: str, state: str, *, reason: str | None = None) -> None: + """Merge ``state`` onto ``sessions/``. + + When transitioning to ``down``, also stamps ``terminatedAt`` and + ``terminationReason``. 
+ """ + payload: dict = {"state": state} + if state == "down": + payload["terminatedAt"] = firestore.SERVER_TIMESTAMP + if reason is not None: + payload["terminationReason"] = reason + _get_sessions_db().collection("sessions").document(session_id).set(payload, merge=True) + + +def _write_queue_doc(session_id: str, dispatch_id: str, fields: dict) -> None: + """Write ``sessions//queue/``.""" + ( + _get_sessions_db() + .collection("sessions") + .document(session_id) + .collection("queue") + .document(dispatch_id) + .set(fields) + ) + + +def _queue_depth(session_id: str) -> int: + """Count documents under ``sessions//queue``.""" + docs = ( + _get_sessions_db().collection("sessions").document(session_id).collection("queue").stream() + ) + return sum(1 for _ in docs) + + +def _reserve_or_get_existing( + owner_type: str, owner_id: str, session_id: str, new_row: dict +) -> dict | None: + """Atomically reuse-or-reserve a session for an owner. + + In one Firestore transaction: if an active session (state in + ``{preparing, ready}``) for the owner already exists, return it (carrying + its ``sessionId`` and ``state``); otherwise write ``new_row`` at + ``sessions/`` and return ``None``. + + Guards the lookup-then-create race so two concurrent ``POST /`` for the + same owner cannot both spawn a master. Firestore requires all reads before + any write inside a transaction, and writes go through the transaction + object. Needs the composite index + ``(ownerType, ownerId, state, createdAt DESC)``. 
+ """ + db = _get_sessions_db() + sessions = db.collection("sessions") + transaction = db.transaction() + + @firestore.transactional + def _run(txn) -> dict | None: + query = ( + sessions.where(filter=FieldFilter("ownerType", "==", owner_type)) + .where(filter=FieldFilter("ownerId", "==", owner_id)) + .where(filter=FieldFilter("state", "in", ["preparing", "ready"])) + .order_by("createdAt", direction=firestore.Query.DESCENDING) + .limit(1) + ) + found = list(query.stream(transaction=txn)) + if found: + row = found[0].to_dict() or {} + row["sessionId"] = found[0].id + return row + txn.set(sessions.document(session_id), new_row) + return None + + return _run(transaction) + + +# ---- Proxy helpers ------------------------------------------------------ + + +async def _safe_detail(response: aiohttp.ClientResponse) -> str: + """Best-effort extract the master's error ``detail``. + + Reads a JSON ``detail`` field when present; falls back to the HTTP reason + phrase when the body is not JSON or the field is absent. 
+ """ + try: + return (await response.json()).get("detail", response.reason) + except Exception: # pylint: disable=broad-exception-caught + return response.reason or "session master error" + + +# ---- K8s rendering ------------------------------------------------------ + + +def _build_endpoints(session_id: str) -> tuple[str, str]: + """Return the ``(controlEndpoint, workerEndpoint)`` cluster DNS URLs.""" + control = f"http://session-master-{session_id}.{WORKLOAD_NAMESPACE}.svc.cluster.local/" + worker = f"http://session-worker-{session_id}.{WORKLOAD_NAMESPACE}.svc.cluster.local/" + return control, worker + + +def _render_master_job(*, session_id: str) -> dict: + """Load the master Job YAML template and substitute placeholders.""" + raw = Path(_master_template_path()).read_text() + substituted = raw.replace("${SESSION_ID}", session_id).replace( + "${SESSIONS_IMAGE_TAG}", _sessions_image_tag() + ) + return yaml.safe_load(substituted) + + +def _render_master_service(*, session_id: str, job_uid: str) -> dict: + """Load the master Service YAML template and substitute placeholders.""" + raw = Path(_master_service_template_path()).read_text() + substituted = raw.replace("${SESSION_ID}", session_id).replace("${MASTER_JOB_UID}", job_uid) + return yaml.safe_load(substituted) + + +# ---- Endpoints ---------------------------------------------------------- + + +api = FastAPI() + + +@api.post("/", response_model=CreateSessionResponse, status_code=201) +async def create_session(body: CreateSessionBody) -> CreateSessionResponse: + session_id = str(uuid.uuid4()) + control, worker_url = _build_endpoints(session_id) + + config = dict(body.config or {}) + config.setdefault("idleTtlSec", 3600) + config.setdefault("maxDispatches", 50) + + new_row = { + "ownerType": body.ownerType, + "ownerId": body.ownerId, + "state": "preparing", + "controlEndpoint": control, + "workerEndpoint": worker_url, + "initialPreload": body.initialPreload, + "jobType": body.jobType, + "config": config, + 
"createdAt": firestore.SERVER_TIMESTAMP, + } + + # Atomic reuse-or-reserve closes the concurrent-create race: either we + # reserve session_id (returns None) or we get back a pre-existing/raced + # active session and skip Job creation entirely. + existing = _reserve_or_get_existing(body.ownerType, body.ownerId, session_id, new_row) + if existing is not None: + ctrl, wkr = _build_endpoints(existing["sessionId"]) + log.info( + "sessions.session.reused", + extra={ + "sessionId": existing["sessionId"], + "ownerType": body.ownerType, + "ownerId": body.ownerId, + }, + ) + return CreateSessionResponse( + sessionId=existing["sessionId"], + controlEndpoint=ctrl, + workerEndpoint=wkr, + state=existing["state"], + ) + + try: + job = k8s_client.BatchV1Api().create_namespaced_job( + namespace=WORKLOAD_NAMESPACE, + body=_render_master_job(session_id=session_id), + ) + # The master Service is what makes controlEndpoint resolve. + # ownerReferences point at the Job, so it is cascade-GC'd when the Job + # is deleted (terminate) or TTL-reaped after the master exits. 
+ service.create_namespaced_service( + namespace=WORKLOAD_NAMESPACE, + body=_render_master_service(session_id=session_id, job_uid=job.metadata.uid), + ) + except Exception as e: # pylint: disable=broad-exception-caught + _write_session_state(session_id, "down", reason="manager_job_create_failed") + log.error( + "sessions.manager.session_create_failed", + extra={ + "ownerType": body.ownerType, + "ownerId": body.ownerId, + "error": str(e), + }, + ) + raise HTTPException(status_code=502, detail="manager could not create master Job/Service") + + log.info( + "sessions.session.created", + extra={ + "sessionId": session_id, + "ownerType": body.ownerType, + "jobType": body.jobType, + }, + ) + return CreateSessionResponse( + sessionId=session_id, + controlEndpoint=control, + workerEndpoint=worker_url, + state="preparing", + ) + + +@api.post("/{session_id}/dispatch") +async def dispatch(session_id: str, body: DispatchBody) -> dict: + row = _read_session_row(session_id) + if row is None: + raise HTTPException(status_code=404, detail="unknown session") + + if row["state"] == "preparing": + dispatch_id = str(uuid.uuid4()) + _write_queue_doc( + session_id, + dispatch_id, + { + "specUrl": body.specUrl, + "runId": body.runId, + "jobType": body.jobType, + "requiredPreload": body.requiredPreload, + "enqueuedAt": firestore.SERVER_TIMESTAMP, + }, + ) + return {"runId": body.runId, "state": "queued-pre-ready"} + + if row["state"] == "ready": + try: + timeout = aiohttp.ClientTimeout(total=None) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + f"{row['controlEndpoint']}dispatch", + json=body.model_dump(), + ) as response: + if response.status >= 400: + # Surface the master's status + detail (401/409/502...) + # instead of a generic 500. HTTPException propagates + # past the connection-error handler below untouched. 
+ raise HTTPException( + status_code=response.status, + detail=await _safe_detail(response), + ) + return await response.json() + except (aiohttp.ClientConnectionError, asyncio.TimeoutError): + _write_session_state(session_id, "down", reason="proxy_unreachable") + log.warning( + "sessions.manager.proxy_unreachable", + extra={"sessionId": session_id, "endpoint": row["controlEndpoint"]}, + ) + raise HTTPException(status_code=502, detail="session master unreachable") + + raise HTTPException(status_code=409, detail=f"session state={row['state']!r}") + + +@api.get("/{session_id}/status") +async def status(session_id: str) -> dict: + row = _read_session_row(session_id) + if row is None: + raise HTTPException(status_code=404) + + if row["state"] == "preparing": + return {"state": "preparing", "queueDepth": _queue_depth(session_id)} + + try: + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(f"{row['controlEndpoint']}status") as response: + if response.status >= 400: + raise HTTPException( + status_code=response.status, + detail=await _safe_detail(response), + ) + return await response.json() + except (aiohttp.ClientConnectionError, asyncio.TimeoutError): + _write_session_state(session_id, "down", reason="proxy_unreachable") + log.warning( + "sessions.manager.proxy_unreachable", + extra={"sessionId": session_id, "endpoint": row["controlEndpoint"]}, + ) + return {"state": "down"} + + +@api.delete("/{session_id}") +async def terminate(session_id: str) -> dict: + try: + k8s_client.BatchV1Api().delete_namespaced_job( + name=f"session-master-{session_id}", + namespace=WORKLOAD_NAMESPACE, + propagation_policy="Background", + ) + except k8s_client.exceptions.ApiException as e: + if e.status not in (404, 410): + raise + # ownerReferences point the master Service at the Job for cascade GC, but + # delete it explicitly for promptness (swallows 404/410). 
+ service.delete_namespaced_service( + name=f"session-master-{session_id}", namespace=WORKLOAD_NAMESPACE + ) + _write_session_state(session_id, "down", reason="explicit_terminate") + return {"state": "down"} diff --git a/web_api/app/worker.py b/web_api/app/worker.py new file mode 100644 index 000000000..02e617bc4 --- /dev/null +++ b/web_api/app/worker.py @@ -0,0 +1,202 @@ +# pylint: disable=all # type: ignore +import asyncio +import json +import logging +import os +import random +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal + +import fsspec +import numpy as np +import torch +from fastapi import FastAPI, HTTPException, Response +from pydantic import BaseModel + +from zetta_utils import builder, parsing, run, setup_environment +from zetta_utils.common import ctx_managers +from zetta_utils.run import run_ctx_manager + +from .auth import check_authorized_user + +app = FastAPI(redirect_slashes=False) +app.middleware("http")(check_authorized_user) + +_run_spec_semaphore = asyncio.Semaphore(1) +_loaded_preload: str = os.environ.get("INITIAL_PRELOAD", "try") +_PRELOAD_RANK = {"none": 0, "try": 1, "inference": 2, "training": 3, "all": 4} +_RUN_SPEC_DOWNLOAD_TIMEOUT_SEC = int(os.environ.get("RUN_SPEC_DOWNLOAD_TIMEOUT_SEC", "30")) + +log = logging.getLogger(__name__) + +PreloadMode = Literal["none", "try", "inference", "training", "all"] + + +class RunSpecBody(BaseModel): + specUrl: str + runId: str + jobType: str + requiredPreload: PreloadMode = "try" + + +class RunSpecResponse(BaseModel): + result: object + + +@app.get("/healthz") +async def health_check(): + return {"status": "healthy"} + + +@app.post("/run_spec/") +async def run_spec(body: RunSpecBody) -> Response: + global _loaded_preload + + session_id = os.environ.get("SESSION_ID", "") + log.info( + "sessions.worker.dispatch_received", + extra={ + "sessionId": session_id, + "runId": body.runId, + "jobType": body.jobType, + "requiredPreload": 
body.requiredPreload, + }, + ) + log.info( + "sessions.dispatch.semaphore_wait_started", + extra={"sessionId": session_id, "runId": body.runId}, + ) + + queued_at = datetime.now(timezone.utc).timestamp() + with run_ctx_manager( + main_run_process=True, + run_id=body.runId, + spec={}, + queued_at=queued_at, + ) as ctx: + _log_dispatch_state( + "queued", session_id=session_id, run_id=body.runId, job_type=body.jobType + ) + + async with _run_spec_semaphore: + ctx.transition_to_running() + _log_dispatch_state( + "running", session_id=session_id, run_id=body.runId, job_type=body.jobType + ) + try: + _upgrade_preload_if_needed(body.requiredPreload, session_id) + spec = await _download_and_parse_spec(body.specUrl) + with ctx_managers.set_env_ctx_mngr( + ZETTA_RUN_SPEC_PATH=body.specUrl, + CURRENT_BUILD_SPEC=json.dumps(spec), + ): + run.record_run(spec) + result = builder.build(spec) + _log_dispatch_state( + "completed", session_id=session_id, run_id=body.runId, job_type=body.jobType + ) + log.info( + "sessions.worker.dispatch_completed", + extra={ + "sessionId": session_id, + "runId": body.runId, + "outcome": "completed", + }, + ) + _light_cleanup() + + return Response( + content=RunSpecResponse(result=result).model_dump_json(), + media_type="application/json", + ) + except Exception: + _log_dispatch_state( + "failed", session_id=session_id, run_id=body.runId, job_type=body.jobType + ) + log.exception( + "sessions.worker.dispatch_completed", + extra={ + "sessionId": session_id, + "runId": body.runId, + "outcome": "failed", + }, + ) + _light_cleanup() + raise HTTPException(status_code=500, detail="dispatch failed") + + +def _log_dispatch_state(state: str, *, session_id: str, run_id: str, job_type: str) -> None: + log.info( + "sessions.dispatch.total", + extra={ + "sessionId": session_id, + "runId": run_id, + "state": state, + "jobType": job_type, + }, + ) + + +def _upgrade_preload_if_needed(required: PreloadMode, session_id: str) -> None: + """Best-effort preload upgrade. 
On failure, os._exit(1) so K8s recycles. + + Rationale: ``setup_environment(load_mode=...)`` performs module imports + and model downloads. A partial failure mid-import leaves ``sys.modules`` + in an inconsistent state; ``_loaded_preload`` would NOT be updated, and + the next dispatch would retry the upgrade against a polluted process. + Crash-loop is the safe option — K8s recreates the pod with a clean + Python interpreter. + """ + global _loaded_preload + if _PRELOAD_RANK[required] <= _PRELOAD_RANK[_loaded_preload]: + return + log.info( + "sessions.worker.preload_upgraded", + extra={"sessionId": session_id, "from": _loaded_preload, "to": required}, + ) + try: + setup_environment(load_mode=required) + except Exception: + log.exception( + "sessions.worker.preload_upgrade_failed", + extra={"sessionId": session_id, "from": _loaded_preload, "to": required}, + ) + os._exit(1) # pragma: no cover + _loaded_preload = required + + +async def _download_and_parse_spec(spec_url: str) -> dict: + """Download CUE from GCS/HTTP and parse. Bounded by RUN_SPEC_DOWNLOAD_TIMEOUT_SEC. + + fsspec.open(...) is sync; wrap in ``asyncio.to_thread`` then bound with + ``asyncio.wait_for``. On timeout, raises ``asyncio.TimeoutError`` which + propagates as 500 via the outer handler. 
+ """ + + def _blocking_download_and_parse() -> dict: + with tempfile.NamedTemporaryFile("w", suffix=".cue", delete=False) as tmp_f: + with fsspec.open(spec_url, "r", encoding="utf8") as f: + tmp_f.write(f.read()) + tmp_path = tmp_f.name + try: + return parsing.cue.load(tmp_path) + finally: + Path(tmp_path).unlink(missing_ok=True) + + return await asyncio.wait_for( + asyncio.to_thread(_blocking_download_and_parse), + timeout=_RUN_SPEC_DOWNLOAD_TIMEOUT_SEC, + ) + + +def _light_cleanup() -> None: + builder.building.BUILT_OBJECT_ID_REGISTRY.clear() + builder.PARALLEL_BUILD_ALLOWED = False + run.RUN_ID = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + random.seed() + np.random.seed() + torch.seed() diff --git a/zetta_utils/cli/main.py b/zetta_utils/cli/main.py index fdaf1726d..87c2a1894 100644 --- a/zetta_utils/cli/main.py +++ b/zetta_utils/cli/main.py @@ -181,6 +181,24 @@ def show_registry(): logger.critical(pprint.pformat(builder.REGISTRY, indent=4)) +@cli.command() +def session_master() -> None: # pragma: no cover # no logic, delegation + """Run the per-session master process. Driven by env vars.""" + import asyncio # pylint: disable=import-outside-toplevel + + from zetta_utils.session import master # pylint: disable=import-outside-toplevel + + asyncio.run(master.main()) + + +@cli.command() +def session_reconcile() -> None: # pragma: no cover # no logic, delegation + """One-shot reconcile scan. 
Driven by env vars.""" + from zetta_utils.session import reconcile # pylint: disable=import-outside-toplevel + + reconcile.run_reconcile() + + for cmd in run_info_cli.commands.values(): cli.add_command(cmd) diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py index 90cdb2501..034dc69b6 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/pod.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/pod.py @@ -398,6 +398,69 @@ def get_mazepa_pod_spec( ) +def create_namespaced_pod( + *, + namespace: str, + body: Mapping, + k8s_core_v1_api: k8s_client.CoreV1Api | None = None, +) -> k8s_client.V1Pod: + """Create a Pod via ``CoreV1Api``. + + Mirrors the call shape of ``BatchV1Api.create_namespaced_job`` — keyword-only + arguments and an optional explicit API client for testability. + + :param namespace: namespace to create the Pod in. + :param body: a dict matching ``V1Pod`` (must include ``metadata.name``, + ``metadata.namespace``, ``spec.containers``, etc.). The caller renders + this from a Pod template. + :param k8s_core_v1_api: optional override; tests inject a mock. + :return: the created ``V1Pod``. + """ + api = k8s_core_v1_api or k8s_client.CoreV1Api() + return api.create_namespaced_pod(namespace=namespace, body=body) + + +def read_namespaced_pod_status( + *, + name: str, + namespace: str, + k8s_core_v1_api: k8s_client.CoreV1Api | None = None, +) -> k8s_client.V1Pod: + """Read a Pod's status via ``CoreV1Api``. + + :param name: Pod name. + :param namespace: Pod namespace. + :param k8s_core_v1_api: optional override; tests inject a mock. + :return: the ``V1Pod`` carrying its status subresource. 
+ """ + api = k8s_core_v1_api or k8s_client.CoreV1Api() + return api.read_namespaced_pod_status(name=name, namespace=namespace) + + +def delete_namespaced_pod( + *, + name: str, + namespace: str, + k8s_core_v1_api: k8s_client.CoreV1Api | None = None, +) -> None: + """Best-effort delete of a Pod via ``CoreV1Api``. + + Swallows ``ApiException`` with status 404 or 410 (the Pod is already gone), + re-raising any other status. + + :param name: Pod name. + :param namespace: Pod namespace. + :param k8s_core_v1_api: optional override; tests inject a mock. + """ + api = k8s_core_v1_api or k8s_client.CoreV1Api() + try: + api.delete_namespaced_pod(name=name, namespace=namespace) + except ApiException as e: + if e.status in (404, 410): + return + raise + + def _wait_for_pod_start( pod_name: str, namespace: str, diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/service.py b/zetta_utils/cloud_management/resource_allocation/k8s/service.py index 3adf240e3..a9428a339 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/service.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/service.py @@ -3,7 +3,9 @@ """ from contextlib import contextmanager -from typing import Dict, List, Optional +from typing import Dict, List, Mapping, Optional + +from kubernetes.client.exceptions import ApiException from kubernetes import client as k8s_client from zetta_utils import log @@ -47,6 +49,51 @@ def get_service( return k8s_client.V1Service(metadata=meta, spec=service_spec) +def create_namespaced_service( + *, + namespace: str, + body: Mapping, + k8s_core_v1_api: k8s_client.CoreV1Api | None = None, +) -> k8s_client.V1Service: + """Create a Service via ``CoreV1Api``. + + Mirrors :func:`create_namespaced_pod`'s call shape — keyword-only arguments + and an optional explicit API client for testability. + + :param namespace: namespace to create the Service in. + :param body: a dict matching ``V1Service``. The caller renders this from a + Service template. 
def delete_namespaced_service(
    *,
    name: str,
    namespace: str,
    k8s_core_v1_api: k8s_client.CoreV1Api | None = None,
) -> None:
    """Delete a Service through ``CoreV1Api``, tolerating an absent Service.

    ``ApiException`` with status 404 or 410 is swallowed; any other status is
    re-raised unchanged.

    :param name: name of the Service to delete.
    :param namespace: namespace the Service lives in.
    :param k8s_core_v1_api: client to use instead of a newly built one.
    """
    if k8s_core_v1_api:
        core_api = k8s_core_v1_api
    else:
        core_api = k8s_client.CoreV1Api()

    try:
        core_api.delete_namespaced_service(name=name, namespace=namespace)
    except ApiException as exc:
        if exc.status not in (404, 410):
            raise
        # 404 / 410: nothing left to delete.
        return
""" @@ -110,10 +112,30 @@ def update_run_info(run_id: str, info: DBRowDataT) -> None: RUN_DB[(run_id, col_keys)] = info -def _check_run_id_conflict(): - assert RUN_ID is not None - if RUN_ID in RUN_DB: - raise ValueError(f"RUN_ID {RUN_ID} already exists in database.") +def _check_run_id_conflict( + run_id: str, + *, + allowed_prior_state: str | None = None, +) -> None: + """ + Raise ValueError if a run-info row already exists for ``run_id``. + + :param run_id: the run id to check. + :param allowed_prior_state: if not ``None``, an existing row whose + ``state`` equals this value is permitted (no raise). Used by the + queued-state path to transition from QUEUED to RUNNING. + """ + if run_id not in RUN_DB: + return + if allowed_prior_state is None: + raise ValueError(f"RUN_ID {run_id} already exists in database.") + row = RUN_DB[(run_id, (RunInfo.STATE.value,))] + current = row.get(RunInfo.STATE.value) + if current != allowed_prior_state: + raise ValueError( + f"RUN_ID {run_id} already exists with state={current!r}; " + f"only state={allowed_prior_state!r} would be permitted" + ) def _send_heartbeat(run_id: str, bucket_egress_warned: set) -> None: @@ -166,14 +188,33 @@ def _cleanup_pod_stats(run_id: str) -> None: logger.warning(f"Failed to cleanup pod stats: {e}") +@attrs.define +class RunCtx: + run_id: str + _state: RunState = attrs.field(alias="_state") + + def transition_to_running(self) -> None: + if self._state == RunState.RUNNING: + return + if self._state != RunState.QUEUED: + raise RuntimeError( + f"transition_to_running called from state={self._state.value!r}; " + f"only QUEUED is permitted" + ) + _check_run_id_conflict(self.run_id, allowed_prior_state=RunState.QUEUED.value) + update_run_info(self.run_id, {RunInfo.STATE.value: RunState.RUNNING.value}) + self._state = RunState.RUNNING + + @contextmanager -def run_ctx_manager( +def run_ctx_manager( # pylint: disable=too-many-statements main_run_process: bool, run_id: str | None = None, spec: dict | list | None 
= None, heartbeat_interval: int = DEFAULT_HEARTBEAT_INTERVAL_SEC, update_costs_interval: int = DEFAULT_UPDATE_COSTS_INTERVAL_SEC, pod_stats_interval: int = DEFAULT_POD_STATS_INTERVAL_SEC, + queued_at: float | None = None, ): bucket_egress_warned: set[str] = set() @@ -189,22 +230,36 @@ def run_ctx_manager( heartbeat_sender = None update_costs_repeater = None pod_stats_repeater = None + run_ctx = None if main_run_process: - _check_run_id_conflict() - - # Register run only when heartbeat is enabled. - # Auxiliary processes should not modify the main process entry. - status = RunState.RUNNING.value - info: DBRowDataT = { - RunInfo.ZETTA_USER.value: os.environ["ZETTA_USER"], - RunInfo.TIMESTAMP.value: time.time(), - RunInfo.STATE.value: status, - RunInfo.PARAMS.value: " ".join(sys.argv[1:]), - } - _record_run(spec) - update_run_info(RUN_ID, info) - assert heartbeat_interval > 0 + _check_run_id_conflict(run_id) + + if queued_at is None: + # Register run only when heartbeat is enabled. + # Auxiliary processes should not modify the main process entry. 
def _get_sessions_db() -> firestore.Client:
    """Return the process-wide Firestore client, building it on first call.

    The client is cached in the module-level ``_sessions_db`` so later calls
    reuse the same instance. The project comes from
    ``SESSIONS_FIRESTORE_PROJECT`` (falling back to
    ``constants.DEFAULT_PROJECT``) and the database from
    ``SESSIONS_FIRESTORE_DATABASE`` (``None`` when unset).
    """
    global _sessions_db  # pylint: disable=global-statement
    if _sessions_db is not None:
        return _sessions_db

    project_id = os.environ.get("SESSIONS_FIRESTORE_PROJECT", constants.DEFAULT_PROJECT)
    database_id = os.environ.get("SESSIONS_FIRESTORE_DATABASE")
    _sessions_db = firestore.Client(project=project_id, database=database_id)
    return _sessions_db
def _workload_namespace() -> str:
    """Return the namespace the worker Pod/Service are managed in.

    Reads ``WORKLOAD_NAMESPACE`` and falls back to ``"sessions"`` when the
    variable is unset *or set to the empty string*. An empty value would
    otherwise be passed to the Kubernetes API calls as the namespace.
    """
    return os.environ.get("WORKLOAD_NAMESPACE") or "sessions"
def _get_shutdown_event() -> asyncio.Event:
    """Return the shared shutdown ``asyncio.Event``, creating it lazily.

    The event is stored in the module-level ``_shutdown_event`` so every
    caller observes the same instance; a ``set()`` issued before anything
    waits on it therefore stays visible.
    """
    global _shutdown_event  # pylint: disable=global-statement
    existing = _shutdown_event
    if existing is not None:
        return existing
    _shutdown_event = asyncio.Event()
    return _shutdown_event
async def _boot() -> None:
    """Bring the session from ``preparing`` to ``ready``.

    Reads the session row, creates the worker Pod and Service from the
    rendered templates, waits for the worker ``/healthz``, drains the
    pre-ready queue, writes ``state="ready"`` and starts the idle timer.

    :raises SystemExit: with code 2 when the session row's ``state`` is not
        ``"preparing"``.
    """
    global _worker_endpoint, _idle_ttl_sec  # pylint: disable=global-statement

    session_id = _session_id()
    namespace = _workload_namespace()
    log.info("sessions.master.boot_start", extra={"sessionId": session_id})
    row = _read_session_row(session_id)
    if row.get("state") != "preparing":
        log.error(
            "sessions.master.unexpected_initial_state",
            extra={"sessionId": session_id, "state": row.get("state")},
        )
        raise SystemExit(2)

    # A field stored as an explicit null comes back as ``None`` rather than
    # being absent, so ``dict.get(key, default)`` alone does not protect
    # against it. Treat null the same as missing.
    config = row.get("config") or {}
    idle_ttl = config.get("idleTtlSec")
    _idle_ttl_sec = int(idle_ttl) if idle_ttl is not None else 3600
    initial_preload = row.get("initialPreload") or "try"

    # Render the worker template; create Pod + Service with ownerReferences
    # pointing at THIS master Pod (downward-API env).
    worker_body = _render_worker_template(initial_preload=initial_preload)
    pod.create_namespaced_pod(namespace=namespace, body=worker_body)

    worker_svc_body = _render_worker_service()
    service.create_namespaced_service(namespace=namespace, body=worker_svc_body)

    _worker_endpoint = f"http://session-worker-{session_id}.{namespace}.svc.cluster.local/"

    await _wait_for_worker_healthz()
    await _drain_pre_ready_queue()
    _write_session_state("ready")
    log.info("sessions.master.boot_complete", extra={"sessionId": session_id})
    _start_idle_timer()
def _classify_worker_failure() -> Literal["permanent", "transient"]:
    """Classify a worker outage from the worker Pod's status.

    Rules, in order:

    * the Pod cannot be read because it is gone (``ApiException`` 404 or
      410): ``"permanent"``;
    * the Pod is returned without a status object: ``"transient"`` (the Pod
      exists but nothing has been reported yet, so keep polling);
    * ``phase == "Failed"``: ``"permanent"``; the first terminated
      container's exit code is logged for diagnosis only and does not affect
      the verdict;
    * ``phase`` is ``"Pending"`` or ``"Running"``: ``"transient"``;
    * any other phase (e.g. ``"Succeeded"``, ``"Unknown"``): ``"permanent"``.

    :raises kubernetes.client.exceptions.ApiException: for any API error
        other than 404/410.
    """
    session_id = _session_id()
    try:
        worker = pod.read_namespaced_pod_status(
            name=f"session-worker-{session_id}",
            namespace=_workload_namespace(),
        )
    except k8s_client.exceptions.ApiException as e:
        # 410 is treated like 404, matching the delete helpers in this package.
        if e.status in (404, 410):
            log.warning(
                "sessions.master.worker_404",
                extra={"sessionId": session_id, "context": "classify"},
            )
            return "permanent"
        raise

    status = worker.status
    if status is None:
        # Nothing reported yet; avoid AttributeError on ``status.phase``.
        return "transient"

    phase = status.phase
    if phase == "Failed":
        exit_code = None
        for cs in status.container_statuses or []:
            term = getattr(cs.state, "terminated", None)
            if term and term.exit_code is not None:
                exit_code = term.exit_code
                break
        log.error(
            "sessions.worker.permanent_failure",
            extra={"sessionId": session_id, "exitCode": exit_code},
        )
        return "permanent"

    if phase in ("Pending", "Running"):
        return "transient"

    # Unknown / Succeeded — treat as permanent (worker is gone).
    return "permanent"
# Number of dispatches currently being forwarded to the worker. Only touched
# from coroutines on the single event loop, with no await between the
# read-modify-write steps.
_inflight_dispatches: int = 0


async def _dispatch_logic(body: dict, *, authorization: str | None) -> dict:
    """Forward a dispatch to the worker, pausing the idle timer around it.

    The idle timer is cancelled when a dispatch starts and restarted only
    once the *last* in-flight dispatch has finished. Restarting it as soon as
    any one dispatch completes would let the timer run, and eventually tear
    the worker down, while another overlapping dispatch is still waiting on
    or executing in the worker.

    :param body: dispatch payload forwarded to the worker.
    :param authorization: caller's ``Authorization`` header value, if any.
    :return: the worker's JSON response.
    """
    global _inflight_dispatches  # pylint: disable=global-statement
    _cancel_idle_timer()
    _inflight_dispatches += 1
    try:
        return await _forward_dispatch_to_worker(body, user_token=authorization)
    finally:
        _inflight_dispatches -= 1
        if _inflight_dispatches == 0:
            _start_idle_timer()
def _build_app() -> web.Application:
    """Create the aiohttp application and register the three master routes.

    ``POST /dispatch``, ``GET /status`` and ``POST /terminate`` are bound to
    the ``dispatch``, ``status`` and ``terminate`` handlers respectively.
    """
    application = web.Application()
    application.add_routes(
        [
            web.post("/dispatch", dispatch),
            web.get("/status", status),
            web.post("/terminate", terminate),
        ]
    )
    return application
async def _idle_timer_body() -> None:
    """Sleep for the idle TTL, then shut the session down and stop serving.

    Cancellation during the sleep (a dispatch arrived) ends the task quietly.
    Once the timer fires, the serve loop is always asked to stop, even when
    the shutdown itself raises. This coroutine runs as a fire-and-forget task,
    so an exception escaping it would otherwise be lost and leave the master
    serving indefinitely.
    """
    try:
        await asyncio.sleep(_idle_ttl_sec)
    except asyncio.CancelledError:
        return
    log.info(
        "sessions.master.idle_timer_fired",
        extra={"sessionId": _session_id(), "idleTtlSec": _idle_ttl_sec},
    )
    try:
        await _on_shutdown(reason="idle_timer")
    except Exception:  # pylint: disable=broad-exception-caught
        # Nobody awaits this task; log here or the failure is invisible.
        log.exception(
            "sessions.master.idle_shutdown_failed",
            extra={"sessionId": _session_id()},
        )
    finally:
        _request_serve_stop()
async def _on_shutdown(*, reason: str) -> None:
    """Run the one-shot session teardown.

    Writes ``state="down"`` with ``reason``, then deletes the worker Pod and
    the worker Service. The first call wins; later calls return immediately.

    Each step is attempted even if an earlier one fails, because the one-shot
    guard means there is no second chance to delete the worker. Failures are
    logged, and the first one is re-raised after all steps have run so the
    caller still sees the error.

    :param reason: termination reason recorded on the session row.
    :raises Exception: the first exception raised by any teardown step.
    """
    global _shutdown_started  # pylint: disable=global-statement
    if _shutdown_started:
        return
    _shutdown_started = True
    session_id = _session_id()
    namespace = _workload_namespace()
    worker_name = f"session-worker-{session_id}"

    steps = (
        ("write_state", lambda: _write_session_state("down", reason=reason)),
        (
            "delete_worker_pod",
            lambda: pod.delete_namespaced_pod(name=worker_name, namespace=namespace),
        ),
        (
            "delete_worker_service",
            lambda: service.delete_namespaced_service(name=worker_name, namespace=namespace),
        ),
    )

    first_error: Exception | None = None
    for step_name, run_step in steps:
        try:
            run_step()
        except Exception as e:  # pylint: disable=broad-exception-caught
            log.exception(
                "sessions.master.shutdown_step_failed",
                extra={"sessionId": session_id, "step": step_name},
            )
            if first_error is None:
                first_error = e
    if first_error is not None:
        raise first_error
+ +Predicates per session row: + (1) state != "down" AND lastDispatchAt < now() - 24h + OR + (2) state != "down" AND BatchV1Api.read_namespaced_job(name=session-master-) + raises ApiException(status in (404, 410)) + +On match: write state="down" + terminatedAt + terminationReason, then clean up +based on which predicate fired: + - master MISSING (Job already 404/410): the orphan case cascade-GC missed. + Best-effort delete the worker Pod, worker Service, and master Service. + - stale-but-ALIVE (Job still exists): delete the master Job (Background + propagation); ownerReferences cascade-GC the worker Pod/Service + master + Service. This actually terminates the still-running master. +Swallow 404/410 on every delete; count other errors as cleanupErrors. +""" + +import logging +import os +from datetime import datetime, timedelta, timezone + +from google.cloud import firestore +from google.cloud.firestore_v1.base_query import FieldFilter + +from kubernetes import client as k8s_client +from zetta_utils.cloud_management.resource_allocation.k8s import pod, service +from zetta_utils.session import _get_sessions_db + +log = logging.getLogger(__name__) + +WORKLOAD_NAMESPACE = os.environ.get("WORKLOAD_NAMESPACE", "sessions") +STALE_AFTER = timedelta(hours=24) + + +def run_reconcile() -> dict: + total = 0 + stale_by_time = 0 + stale_by_missing_master = 0 + reconciled = 0 + cleanup_errors = 0 + + batch_v1 = k8s_client.BatchV1Api() + now = datetime.now(timezone.utc) + + for row in _query_non_down_sessions(): + total += 1 + session_id = row["sessionId"] + is_stale_by_time = _is_stale_by_time(row, now) + is_missing_master = _is_master_missing(batch_v1, session_id) + + if not (is_stale_by_time or is_missing_master): + continue + + # Missing-master takes precedence: it dictates orphan cleanup and is the + # more actionable signal when a row is both stale and missing. 
+ if is_missing_master: + stale_by_missing_master += 1 + reason = "reconcile_master_missing" + else: + stale_by_time += 1 + reason = "reconcile_stale_24h" + + _write_session_state(session_id, "down", reason=reason) + + try: + if is_missing_master: + pod.delete_namespaced_pod( + name=f"session-worker-{session_id}", namespace=WORKLOAD_NAMESPACE + ) + service.delete_namespaced_service( + name=f"session-worker-{session_id}", namespace=WORKLOAD_NAMESPACE + ) + service.delete_namespaced_service( + name=f"session-master-{session_id}", namespace=WORKLOAD_NAMESPACE + ) + else: + # Stale but alive: delete the master Job; ownerReferences + # cascade-GC the worker Pod/Service and master Service. + _delete_master_job(batch_v1, session_id) + except Exception as e: # pylint: disable=broad-exception-caught + cleanup_errors += 1 + log.warning( + "sessions.reconcile.cleanup_error", + extra={"sessionId": session_id, "error": str(e)}, + ) + + log.info( + "sessions.reconcile.reconciled_count", + extra={"sessionId": session_id, "reason": reason}, + ) + log.info("sessions.session.terminated", extra={"sessionId": session_id, "reason": reason}) + reconciled += 1 + + summary = { + "totalRowsScanned": total, + "staleByTime": stale_by_time, + "staleByMissingMaster": stale_by_missing_master, + "reconciledCount": reconciled, + "cleanupErrors": cleanup_errors, + } + log.info("sessions.reconcile.scan_complete", extra=summary) + return summary + + +def _is_stale_by_time(row: dict, now: datetime) -> bool: + last = row.get("lastDispatchAt") or row.get("createdAt") + if last is None: + return False + return (now - last) > STALE_AFTER + + +def _is_master_missing(batch_v1: k8s_client.BatchV1Api, session_id: str) -> bool: + try: + batch_v1.read_namespaced_job( + name=f"session-master-{session_id}", namespace=WORKLOAD_NAMESPACE + ) + return False + except k8s_client.exceptions.ApiException as e: + if e.status in (404, 410): + return True + raise + + +def _delete_master_job(batch_v1: 
def _query_non_down_sessions():
    """Yield every ``sessions`` document whose ``state`` is not ``"down"``.

    Each yielded dict holds the document's fields (an empty dict when the
    snapshot has none) plus its document id under ``"sessionId"``. The query
    filters on the single ``state`` field only.
    """
    not_down = FieldFilter("state", "!=", "down")
    sessions = _get_sessions_db().collection("sessions")
    for snapshot in sessions.where(filter=not_down).stream():
        fields = snapshot.to_dict()
        if not fields:
            fields = {}
        fields["sessionId"] = snapshot.id
        yield fields