|
3 | 3 | import asyncio |
4 | 4 | import json |
5 | 5 | from typing import TYPE_CHECKING |
| 6 | +from unittest.mock import AsyncMock, patch |
6 | 7 |
|
7 | 8 | import pytest |
8 | 9 | from sqlalchemy import inspect, select |
| 10 | +from sqlalchemy.exc import SQLAlchemyError |
9 | 11 | from sqlalchemy.ext.asyncio import create_async_engine |
10 | 12 |
|
11 | 13 | from crawlee.configuration import Configuration |
| 14 | +from crawlee.errors import StorageWriteError |
12 | 15 | from crawlee.storage_clients import SqlStorageClient |
13 | 16 | from crawlee.storage_clients._sql._db_models import KeyValueStoreMetadataDb, KeyValueStoreRecordDb |
14 | 17 | from crawlee.storage_clients.models import KeyValueStoreMetadata |
@@ -281,3 +284,21 @@ async def test_data_persistence_across_reopens(configuration: Configuration) -> |
281 | 284 | assert record.value == test_value |
282 | 285 |
|
283 | 286 | await reopened_client.drop() |
| 287 | + |
| 288 | + |
| 289 | +async def test_error_handling_on_set_failure(kvs_client: SqlKeyValueStoreClient) -> None: |
| 290 | + """Test that StorageWriteError is raised when SQL writing fails.""" |
| 291 | + with patch( |
| 292 | + 'crawlee.storage_clients._sql._key_value_store_client.SqlKeyValueStoreClient.get_session', |
| 293 | + ) as mock_get_session: |
| 294 | + mock_session = AsyncMock() |
| 295 | + mock_session.__aenter__ = AsyncMock(return_value=mock_session) |
| 296 | + mock_session.__aexit__ = AsyncMock(return_value=None) |
| 297 | + mock_session.execute = AsyncMock(side_effect=SQLAlchemyError('db error')) |
| 298 | + mock_get_session.return_value = mock_session |
| 299 | + |
| 300 | + with pytest.raises(StorageWriteError) as exc_info: |
| 301 | + await kvs_client.set_value(key='test', value='test-value') |
| 302 | + |
| 303 | + assert isinstance(exc_info.value.cause, SQLAlchemyError) |
| 304 | + assert str(exc_info.value.cause) == 'db error' |
0 commit comments