Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions libs/labelbox/src/labelbox/schema/model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
Union,
)

from lbox.exceptions import InternalServerError, ResourceNotFoundError

from labelbox.orm.db_object import DbObject, experimental
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.orm.query import results_query_part
Expand Down Expand Up @@ -65,6 +67,53 @@ class Status(Enum):
COMPLETE = "COMPLETE"
FAILED = "FAILED"

def _get_cost_and_usage(self) -> Dict[str, Any]:
"""Lazily fetches and caches cost and data row count for this Model Run.

Returns an empty dict when no cost/usage information is available.
"""
if getattr(self, "_cost_and_usage", None) is None:
query_str = """
query GetModelRunCostInfoPyApi($modelRunId: ID!) {
modelFoundryModelRunInfo(where: {modelRunId: $modelRunId}) {
cost
status
totalDataRows
}
}
"""
try:
res = self.client.execute(query_str, {"modelRunId": self.uid})
except (ResourceNotFoundError, InternalServerError):
# No cost/usage info available; cache the empty result.
# Transient errors (network, timeout, rate limit) are not
# caught so they propagate and the next access can retry.
res = None
self._cost_and_usage = (res or {}).get(
"modelFoundryModelRunInfo"
) or {}
return self._cost_and_usage

@property
def total_cost(self) -> Optional[float]:
"""Total cost (USD) of this Model Run.

``None`` if cost is not available for this run.
"""
return self._get_cost_and_usage().get("cost")

@property
def total_data_rows(self) -> Optional[int]:
"""Number of data rows processed by this Model Run.

``None`` if not available for this run.
"""
return self._get_cost_and_usage().get("totalDataRows")

def refresh_cost_and_usage(self) -> None:
"""Clears the cached cost/usage so the next access re-fetches live data."""
self._cost_and_usage = None

def upsert_labels(
self,
label_ids: Optional[List[str]] = None,
Expand Down
116 changes: 116 additions & 0 deletions libs/labelbox/tests/unit/test_unit_model_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from unittest.mock import MagicMock

import pytest
from lbox.exceptions import (
InternalServerError,
NetworkError,
ResourceNotFoundError,
)

from labelbox.schema.model_run import ModelRun


def _make_model_run(client):
return ModelRun(
client,
{
"id": "model-run-1",
"name": "test run",
"createdAt": "2021-06-01T00:00:00.000Z",
"updatedAt": "2021-06-01T00:00:00.000Z",
"createdBy": "user-1",
"modelId": "model-1",
"trainingMetadata": {},
"modelAppId": "app-1",
},
)


def test_total_cost_and_data_rows_are_fetched_and_cached():
client = MagicMock()
client.execute.return_value = {
"modelFoundryModelRunInfo": {
"cost": 3.5,
"status": "finished",
"totalDataRows": 12,
}
}
model_run = _make_model_run(client)

assert model_run.total_cost == 3.5
assert model_run.total_data_rows == 12

# Cost/usage is rehydrated once and cached across property reads.
assert client.execute.call_count == 1
# The model run id is passed to the query.
assert client.execute.call_args[0][1] == {"modelRunId": "model-run-1"}


def test_refresh_cost_and_usage_refetches():
client = MagicMock()
client.execute.return_value = {
"modelFoundryModelRunInfo": {
"cost": 1.0,
"status": "finished",
"totalDataRows": 1,
}
}
model_run = _make_model_run(client)

assert model_run.total_cost == 1.0
model_run.refresh_cost_and_usage()
assert model_run.total_cost == 1.0
assert client.execute.call_count == 2


@pytest.mark.parametrize(
"error",
[
ResourceNotFoundError(message="model run not found"),
InternalServerError("no model job for run"),
],
)
def test_cost_and_usage_none_for_non_foundry_run(error):
client = MagicMock()
client.execute.side_effect = error
model_run = _make_model_run(client)

assert model_run.total_cost is None
assert model_run.total_data_rows is None


@pytest.mark.parametrize(
"execute_result",
[
None, # execute() can return None instead of a payload
{"modelFoundryModelRunInfo": None},
],
)
def test_cost_and_usage_none_when_payload_missing(execute_result):
client = MagicMock()
client.execute.return_value = execute_result
model_run = _make_model_run(client)

assert model_run.total_cost is None
assert model_run.total_data_rows is None


def test_transient_errors_propagate_and_are_not_cached():
client = MagicMock()
client.execute.side_effect = NetworkError(Exception("boom"))
model_run = _make_model_run(client)

with pytest.raises(NetworkError):
_ = model_run.total_cost

# The failure is not cached, so a later successful access recovers.
client.execute.side_effect = None
client.execute.return_value = {
"modelFoundryModelRunInfo": {
"cost": 2.0,
"status": "finished",
"totalDataRows": 5,
}
}
assert model_run.total_cost == 2.0
assert model_run.total_data_rows == 5
Loading