Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/climatevision/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,14 @@ class PredictRequest(BaseModel):
kind: str = Field(default="demo")
analysis_type: AnalysisType = Field(default="deforestation")
bbox: Optional[list[float]] = None
start_date: Optional[str] = None
end_date: Optional[str] = None
start_date: Optional[str] = Field(
default=None,
description="Start date in YYYY-MM-DD format. Must be earlier than end_date.",
)
end_date: Optional[str] = Field(
default=None,
description="End date in YYYY-MM-DD format. Must be later than start_date.",
)

@field_validator("bbox")
@classmethod
Expand Down Expand Up @@ -563,9 +569,6 @@ async def predict_json(
org: dict[str, Any] = Depends(require_api_key),
) -> dict[str, Any]:
"""Run prediction using bounding box and date range."""
if body.start_date and body.end_date and body.start_date > body.end_date:
raise HTTPException(status_code=400, detail="start_date must be before end_date")

created_at = _utc_now_iso()
bbox_json = json.dumps(body.bbox) if body.bbox else None

Expand Down
62 changes: 62 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for ClimateVision API endpoints."""

from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -40,3 +42,63 @@ def test_predict_json_accepts_dev_key(client: TestClient) -> None:
)
# Should pass auth; inference may fail due to missing models/GEE
assert response.status_code in (200, 500)


def test_predict_valid_date_range_reaches_inference(client: TestClient) -> None:
"""POST /api/predict with valid date range should reach the inference layer."""
payload = {
"bbox": [-60.0, -15.0, -45.0, -5.0],
"start_date": "2023-01-01",
"end_date": "2023-06-30",
"analysis_type": "deforestation",
}
fake_result = {
"region": {"bbox": payload["bbox"]},
"inference": {"forest_percentage": 72.3},
"analysis_type": "deforestation",
}
with patch(
"climatevision.api.main.run_inference_from_gee", return_value=fake_result
) as mock_infer:
response = client.post(
"/api/predict",
json=payload,
headers={"X-API-Key": "cv_dev"},
)
assert response.status_code == 200
mock_infer.assert_called_once()


def test_predict_reversed_date_range_returns_422(client: TestClient) -> None:
"""POST /api/predict with start_date > end_date should return 422."""
payload = {
"bbox": [-60.0, -15.0, -45.0, -5.0],
"start_date": "2026-06-01",
"end_date": "2026-01-01",
"analysis_type": "deforestation",
}
response = client.post(
"/api/predict",
json=payload,
headers={"X-API-Key": "cv_dev"},
)
assert response.status_code == 422
body = response.json()
error_messages = [e["msg"] for e in body["detail"]]
assert any("start_date" in msg or "end_date" in msg for msg in error_messages)


def test_predict_equal_dates_returns_422(client: TestClient) -> None:
"""POST /api/predict with start_date == end_date should return 422."""
payload = {
"bbox": [-60.0, -15.0, -45.0, -5.0],
"start_date": "2023-06-01",
"end_date": "2023-06-01",
"analysis_type": "deforestation",
}
response = client.post(
"/api/predict",
json=payload,
headers={"X-API-Key": "cv_dev"},
)
assert response.status_code == 422
Loading