diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index 729b213..d873599 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -109,8 +109,14 @@ class PredictRequest(BaseModel): kind: str = Field(default="demo") analysis_type: AnalysisType = Field(default="deforestation") bbox: Optional[list[float]] = None - start_date: Optional[str] = None - end_date: Optional[str] = None + start_date: Optional[str] = Field( + default=None, + description="Start date in YYYY-MM-DD format. Must be earlier than end_date.", + ) + end_date: Optional[str] = Field( + default=None, + description="End date in YYYY-MM-DD format. Must be later than start_date.", + ) @field_validator("bbox") @classmethod @@ -563,9 +569,6 @@ async def predict_json( org: dict[str, Any] = Depends(require_api_key), ) -> dict[str, Any]: """Run prediction using bounding box and date range.""" - if body.start_date and body.end_date and body.start_date > body.end_date: - raise HTTPException(status_code=400, detail="start_date must be before end_date") - created_at = _utc_now_iso() bbox_json = json.dumps(body.bbox) if body.bbox else None diff --git a/tests/test_api.py b/tests/test_api.py index 1593b40..da9c49c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,7 @@ """Tests for ClimateVision API endpoints.""" +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient @@ -40,3 +42,63 @@ def test_predict_json_accepts_dev_key(client: TestClient) -> None: ) # Should pass auth; inference may fail due to missing models/GEE assert response.status_code in (200, 500) + + +def test_predict_valid_date_range_reaches_inference(client: TestClient) -> None: + """POST /api/predict with valid date range should reach the inference layer.""" + payload = { + "bbox": [-60.0, -15.0, -45.0, -5.0], + "start_date": "2023-01-01", + "end_date": "2023-06-30", + "analysis_type": "deforestation", + } + fake_result = { + "region": {"bbox": payload["bbox"]}, + "inference": {"forest_percentage": 72.3}, + "analysis_type": "deforestation", + } + with patch( + "climatevision.api.main.run_inference_from_gee", return_value=fake_result + ) as mock_infer: + response = client.post( + "/api/predict", + json=payload, + headers={"X-API-Key": "cv_dev"}, + ) + assert response.status_code == 200 + mock_infer.assert_called_once() + + +def test_predict_reversed_date_range_returns_422(client: TestClient) -> None: + """POST /api/predict with start_date > end_date should return 422.""" + payload = { + "bbox": [-60.0, -15.0, -45.0, -5.0], + "start_date": "2026-06-01", + "end_date": "2026-01-01", + "analysis_type": "deforestation", + } + response = client.post( + "/api/predict", + json=payload, + headers={"X-API-Key": "cv_dev"}, + ) + assert response.status_code == 422 + body = response.json() + error_messages = [e["msg"] for e in body["detail"]] + assert any("start_date" in msg or "end_date" in msg for msg in error_messages) + + +def test_predict_equal_dates_returns_422(client: TestClient) -> None: + """POST /api/predict with start_date == end_date should return 422.""" + payload = { + "bbox": [-60.0, -15.0, -45.0, -5.0], + "start_date": "2023-06-01", + "end_date": "2023-06-01", + "analysis_type": "deforestation", + } + response = client.post( + "/api/predict", + json=payload, + headers={"X-API-Key": "cv_dev"}, + ) + assert response.status_code == 422