diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py index a7712858e..0791e15e5 100644 --- a/libs/labelbox/src/labelbox/schema/model_run.py +++ b/libs/labelbox/src/labelbox/schema/model_run.py @@ -16,6 +16,8 @@ Union, ) +from lbox.exceptions import InternalServerError, ResourceNotFoundError + from labelbox.orm.db_object import DbObject, experimental from labelbox.orm.model import Entity, Field, Relationship from labelbox.orm.query import results_query_part @@ -65,6 +67,53 @@ class Status(Enum): COMPLETE = "COMPLETE" FAILED = "FAILED" + def _get_cost_and_usage(self) -> Dict[str, Any]: + """Lazily fetches and caches cost and data row count for this Model Run. + + Returns an empty dict when no cost/usage information is available. + """ + if getattr(self, "_cost_and_usage", None) is None: + query_str = """ + query GetModelRunCostInfoPyApi($modelRunId: ID!) { + modelFoundryModelRunInfo(where: {modelRunId: $modelRunId}) { + cost + status + totalDataRows + } + } + """ + try: + res = self.client.execute(query_str, {"modelRunId": self.uid}) + except (ResourceNotFoundError, InternalServerError): + # No cost/usage info available; cache the empty result. + # Transient errors (network, timeout, rate limit) are not + # caught so they propagate and the next access can retry. + res = None + self._cost_and_usage = (res or {}).get( + "modelFoundryModelRunInfo" + ) or {} + return self._cost_and_usage + + @property + def total_cost(self) -> Optional[float]: + """Total cost (USD) of this Model Run. + + ``None`` if cost is not available for this run. + """ + return self._get_cost_and_usage().get("cost") + + @property + def total_data_rows(self) -> Optional[int]: + """Number of data rows processed by this Model Run. + + ``None`` if not available for this run. + """ + return self._get_cost_and_usage().get("totalDataRows") + + def refresh_cost_and_usage(self) -> None: + """Clears the cached cost/usage so the next access re-fetches live data.""" + self._cost_and_usage = None + def upsert_labels( self, label_ids: Optional[List[str]] = None, diff --git a/libs/labelbox/tests/unit/test_unit_model_run.py b/libs/labelbox/tests/unit/test_unit_model_run.py new file mode 100644 index 000000000..747f1aa00 --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_model_run.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock + +import pytest +from lbox.exceptions import ( + InternalServerError, + NetworkError, + ResourceNotFoundError, +) + +from labelbox.schema.model_run import ModelRun + + +def _make_model_run(client): + return ModelRun( + client, + { + "id": "model-run-1", + "name": "test run", + "createdAt": "2021-06-01T00:00:00.000Z", + "updatedAt": "2021-06-01T00:00:00.000Z", + "createdBy": "user-1", + "modelId": "model-1", + "trainingMetadata": {}, + "modelAppId": "app-1", + }, + ) + + +def test_total_cost_and_data_rows_are_fetched_and_cached(): + client = MagicMock() + client.execute.return_value = { + "modelFoundryModelRunInfo": { + "cost": 3.5, + "status": "finished", + "totalDataRows": 12, + } + } + model_run = _make_model_run(client) + + assert model_run.total_cost == 3.5 + assert model_run.total_data_rows == 12 + + # Cost/usage is rehydrated once and cached across property reads. + assert client.execute.call_count == 1 + # The model run id is passed to the query. + assert client.execute.call_args[0][1] == {"modelRunId": "model-run-1"} + + +def test_refresh_cost_and_usage_refetches(): + client = MagicMock() + client.execute.return_value = { + "modelFoundryModelRunInfo": { + "cost": 1.0, + "status": "finished", + "totalDataRows": 1, + } + } + model_run = _make_model_run(client) + + assert model_run.total_cost == 1.0 + model_run.refresh_cost_and_usage() + assert model_run.total_cost == 1.0 + assert client.execute.call_count == 2 + + +@pytest.mark.parametrize( + "error", + [ + ResourceNotFoundError(message="model run not found"), + InternalServerError("no model job for run"), + ], +) +def test_cost_and_usage_none_for_non_foundry_run(error): + client = MagicMock() + client.execute.side_effect = error + model_run = _make_model_run(client) + + assert model_run.total_cost is None + assert model_run.total_data_rows is None + + +@pytest.mark.parametrize( + "execute_result", + [ + None, # execute() can return None instead of a payload + {"modelFoundryModelRunInfo": None}, + ], +) +def test_cost_and_usage_none_when_payload_missing(execute_result): + client = MagicMock() + client.execute.return_value = execute_result + model_run = _make_model_run(client) + + assert model_run.total_cost is None + assert model_run.total_data_rows is None + + +def test_transient_errors_propagate_and_are_not_cached(): + client = MagicMock() + client.execute.side_effect = NetworkError(Exception("boom")) + model_run = _make_model_run(client) + + with pytest.raises(NetworkError): + _ = model_run.total_cost + + # The failure is not cached, so a later successful access recovers. + client.execute.side_effect = None + client.execute.return_value = { + "modelFoundryModelRunInfo": { + "cost": 2.0, + "status": "finished", + "totalDataRows": 5, + } + } + assert model_run.total_cost == 2.0 + assert model_run.total_data_rows == 5