Skip to content

Commit d09828c

Browse files
committed
move retry logic to storage clients
1 parent 5f20eb4 commit d09828c

20 files changed

Lines changed: 336 additions & 196 deletions

src/crawlee/_utils/retry.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from datetime import timedelta
5+
from functools import wraps
6+
from typing import TYPE_CHECKING, ParamSpec, TypeVar
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Awaitable, Callable
10+
11+
P = ParamSpec('P')
12+
T = TypeVar('T')
13+
14+
15+
def retry_on_error(
16+
*exception_types: type[Exception],
17+
max_attempts: int = 3,
18+
base_delay: timedelta = timedelta(milliseconds=500),
19+
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
20+
"""Retry an async function with exponential backoff on specified exceptions.
21+
22+
Args:
23+
*exception_types: Exception types to catch and retry on.
24+
max_attempts: Maximum number of attempts including the first one.
25+
base_delay: Base delay between retries; doubles on each subsequent attempt.
26+
"""
27+
28+
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
29+
30+
if max_attempts < 1:
31+
raise ValueError('max_attempts must be at least 1')
32+
33+
@wraps(func)
34+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
35+
base_delay_seconds = base_delay.total_seconds()
36+
for attempt in range(max_attempts):
37+
try:
38+
return await func(*args, **kwargs)
39+
except Exception as exc: # noqa: PERF203
40+
if not isinstance(exc, exception_types) or attempt >= max_attempts - 1:
41+
raise
42+
await asyncio.sleep(base_delay_seconds * (2**attempt))
43+
raise RuntimeError('Unreachable')
44+
45+
return wrapper
46+
47+
return decorator

src/crawlee/errors.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
'RequestHandlerError',
1919
'ServiceConflictError',
2020
'SessionError',
21-
'StorageWriteError',
2221
'UserDefinedErrorHandlerError',
2322
]
2423

@@ -117,12 +116,3 @@ class ContextPipelineInterruptedError(Exception):
117116
@docs_group('Errors')
118117
class RequestCollisionError(Exception):
119118
"""Raised when a request cannot be processed due to a conflict with required resources."""
120-
121-
122-
@docs_group('Errors')
123-
class StorageWriteError(Exception):
124-
"""Raised when a write operation to a storage fails."""
125-
126-
def __init__(self, cause: Exception) -> None:
127-
super().__init__(str(cause))
128-
self.cause = cause

src/crawlee/storage_clients/_file_system/_dataset_client.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from crawlee._utils.crypto import crypto_random_object_id
1616
from crawlee._utils.file import atomic_write, json_dumps
1717
from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
18-
from crawlee.errors import StorageWriteError
1918
from crawlee.storage_clients._base import DatasetClient
2019
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata
2120

@@ -223,24 +222,21 @@ async def purge(self) -> None:
223222
@override
224223
async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
225224
async with self._lock:
226-
try:
227-
new_item_count = self._metadata.item_count
228-
if isinstance(data, list):
229-
for item in data:
230-
new_item_count += 1
231-
await self._push_item(item, new_item_count)
232-
else:
225+
new_item_count = self._metadata.item_count
226+
if isinstance(data, list):
227+
for item in data:
233228
new_item_count += 1
234-
await self._push_item(data, new_item_count)
229+
await self._push_item(item, new_item_count)
230+
else:
231+
new_item_count += 1
232+
await self._push_item(data, new_item_count)
235233

236-
# now update metadata under the same lock
237-
await self._update_metadata(
238-
update_accessed_at=True,
239-
update_modified_at=True,
240-
new_item_count=new_item_count,
241-
)
242-
except OSError as e:
243-
raise StorageWriteError(e) from e
234+
# now update metadata under the same lock
235+
await self._update_metadata(
236+
update_accessed_at=True,
237+
update_modified_at=True,
238+
new_item_count=new_item_count,
239+
)
244240

245241
@override
246242
async def get_data(

src/crawlee/storage_clients/_file_system/_key_value_store_client.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from crawlee._utils.crypto import crypto_random_object_id
1818
from crawlee._utils.file import atomic_write, infer_mime_type, json_dumps
1919
from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
20-
from crawlee.errors import StorageWriteError
2120
from crawlee.storage_clients._base import KeyValueStoreClient
2221
from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata
2322

@@ -329,20 +328,17 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No
329328
record_metadata_content = await json_dumps(record_metadata.model_dump())
330329

331330
async with self._lock:
332-
try:
333-
# Ensure the key-value store directory exists.
334-
await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True)
331+
# Ensure the key-value store directory exists.
332+
await asyncio.to_thread(self.path_to_kvs.mkdir, parents=True, exist_ok=True)
335333

336-
# Write the value to the file.
337-
await atomic_write(record_path, value_bytes)
334+
# Write the value to the file.
335+
await atomic_write(record_path, value_bytes)
338336

339-
# Write the record metadata to the file.
340-
await atomic_write(record_metadata_filepath, record_metadata_content)
337+
# Write the record metadata to the file.
338+
await atomic_write(record_metadata_filepath, record_metadata_content)
341339

342-
# Update the KVS metadata to record the access and modification.
343-
await self._update_metadata(update_accessed_at=True, update_modified_at=True)
344-
except OSError as e:
345-
raise StorageWriteError(e) from e
340+
# Update the KVS metadata to record the access and modification.
341+
await self._update_metadata(update_accessed_at=True, update_modified_at=True)
346342

347343
@override
348344
async def delete_value(self, *, key: str) -> None:

src/crawlee/storage_clients/_redis/_dataset_client.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from redis.exceptions import RedisError
77
from typing_extensions import NotRequired, override
88

9-
from crawlee.errors import StorageWriteError
9+
from crawlee._utils.retry import retry_on_error
1010
from crawlee.storage_clients._base import DatasetClient
1111
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata
1212

@@ -104,14 +104,17 @@ async def open(
104104
instance_kwargs={},
105105
)
106106

107+
@retry_on_error(RedisError)
107108
@override
108109
async def get_metadata(self) -> DatasetMetadata:
109110
return await self._get_metadata(DatasetMetadata)
110111

112+
@retry_on_error(RedisError)
111113
@override
112114
async def drop(self) -> None:
113115
await self._drop(extra_keys=[self._items_key])
114116

117+
@retry_on_error(RedisError)
115118
@override
116119
async def purge(self) -> None:
117120
await self._purge(
@@ -121,24 +124,22 @@ async def purge(self) -> None:
121124
),
122125
)
123126

127+
@retry_on_error(RedisError)
124128
@override
125129
async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
126130
if isinstance(data, dict):
127131
data = [data]
128132

129-
try:
130-
async with self._get_pipeline() as pipe:
131-
pipe.json().arrappend(self._items_key, '$', *data)
132-
await self._update_metadata(
133-
pipe,
134-
**_DatasetMetadataUpdateParams(
135-
update_accessed_at=True, update_modified_at=True, delta_item_count=len(data)
136-
),
137-
)
138-
139-
except RedisError as e:
140-
raise StorageWriteError(e) from e
133+
async with self._get_pipeline() as pipe:
134+
pipe.json().arrappend(self._items_key, '$', *data)
135+
await self._update_metadata(
136+
pipe,
137+
**_DatasetMetadataUpdateParams(
138+
update_accessed_at=True, update_modified_at=True, delta_item_count=len(data)
139+
),
140+
)
141141

142+
@retry_on_error(RedisError)
142143
@override
143144
async def get_data(
144145
self,

src/crawlee/storage_clients/_redis/_key_value_store_client.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing_extensions import override
99

1010
from crawlee._utils.file import infer_mime_type
11-
from crawlee.errors import StorageWriteError
11+
from crawlee._utils.retry import retry_on_error
1212
from crawlee.storage_clients._base import KeyValueStoreClient
1313
from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata
1414

@@ -102,21 +102,25 @@ async def open(
102102
instance_kwargs={},
103103
)
104104

105+
@retry_on_error(RedisError)
105106
@override
106107
async def get_metadata(self) -> KeyValueStoreMetadata:
107108
return await self._get_metadata(KeyValueStoreMetadata)
108109

110+
@retry_on_error(RedisError)
109111
@override
110112
async def drop(self) -> None:
111113
await self._drop(extra_keys=[self._items_key, self._metadata_items_key])
112114

115+
@retry_on_error(RedisError)
113116
@override
114117
async def purge(self) -> None:
115118
await self._purge(
116119
extra_keys=[self._items_key, self._metadata_items_key],
117120
metadata_kwargs=MetadataUpdateParams(update_accessed_at=True, update_modified_at=True),
118121
)
119122

123+
@retry_on_error(RedisError)
120124
@override
121125
async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None:
122126
# Special handling for None values
@@ -143,25 +147,20 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No
143147
content_type=content_type,
144148
size=size,
145149
)
146-
try:
147-
async with self._get_pipeline() as pipe:
148-
# redis-py typing issue
149-
await await_redis_response(pipe.hset(self._items_key, key, value_bytes)) # ty: ignore[invalid-argument-type]
150-
151-
await await_redis_response(
152-
pipe.hset(
153-
self._metadata_items_key,
154-
key,
155-
item_metadata.model_dump_json(),
156-
)
157-
)
158-
await self._update_metadata(
159-
pipe, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True)
150+
async with self._get_pipeline() as pipe:
151+
# redis-py typing issue
152+
await await_redis_response(pipe.hset(self._items_key, key, value_bytes)) # ty: ignore[invalid-argument-type]
153+
154+
await await_redis_response(
155+
pipe.hset(
156+
self._metadata_items_key,
157+
key,
158+
item_metadata.model_dump_json(),
160159
)
160+
)
161+
await self._update_metadata(pipe, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True))
161162

162-
except RedisError as e:
163-
raise StorageWriteError(e) from e
164-
163+
@retry_on_error(RedisError)
165164
@override
166165
async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
167166
serialized_metadata_item = await await_redis_response(self._redis.hget(self._metadata_items_key, key))
@@ -207,6 +206,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
207206

208207
return KeyValueStoreRecord(value=value, **metadata_item.model_dump())
209208

209+
@retry_on_error(RedisError)
210210
@override
211211
async def delete_value(self, *, key: str) -> None:
212212
async with self._get_pipeline() as pipe:
@@ -258,6 +258,7 @@ async def iterate_keys(
258258
async def get_public_url(self, *, key: str) -> str:
259259
raise NotImplementedError('Public URLs are not supported for memory key-value stores.')
260260

261+
@retry_on_error(RedisError)
261262
@override
262263
async def record_exists(self, *, key: str) -> bool:
263264
async with self._get_pipeline(with_execute=False) as pipe:

src/crawlee/storage_clients/_redis/_request_queue_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from logging import getLogger
77
from typing import TYPE_CHECKING, Any, Literal
88

9+
from redis.exceptions import RedisError
910
from typing_extensions import NotRequired, override
1011

1112
from crawlee import Request
1213
from crawlee._utils.crypto import crypto_random_object_id
14+
from crawlee._utils.retry import retry_on_error
1315
from crawlee.storage_clients._base import RequestQueueClient
1416
from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata
1517

@@ -207,10 +209,12 @@ async def open(
207209
instance_kwargs={'dedup_strategy': dedup_strategy, 'bloom_error_rate': bloom_error_rate},
208210
)
209211

212+
@retry_on_error(RedisError)
210213
@override
211214
async def get_metadata(self) -> RequestQueueMetadata:
212215
return await self._get_metadata(RequestQueueMetadata)
213216

217+
@retry_on_error(RedisError)
214218
@override
215219
async def drop(self) -> None:
216220
if self._dedup_strategy == 'bloom':
@@ -222,6 +226,7 @@ async def drop(self) -> None:
222226
extra_keys.extend([self._queue_key, self._data_key, self._in_progress_key])
223227
await self._drop(extra_keys=extra_keys)
224228

229+
@retry_on_error(RedisError)
225230
@override
226231
async def purge(self) -> None:
227232
if self._dedup_strategy == 'bloom':
@@ -349,6 +354,7 @@ async def add_batch_of_requests(
349354
unprocessed_requests=[],
350355
)
351356

357+
@retry_on_error(RedisError)
352358
@override
353359
async def fetch_next_request(self) -> Request | None:
354360
if self._pending_fetch_cache:
@@ -377,6 +383,7 @@ async def fetch_next_request(self) -> Request | None:
377383

378384
return requests[0]
379385

386+
@retry_on_error(RedisError)
380387
@override
381388
async def get_request(self, unique_key: str) -> Request | None:
382389
request_data = await await_redis_response(self._redis.hget(self._data_key, unique_key))
@@ -386,6 +393,7 @@ async def get_request(self, unique_key: str) -> Request | None:
386393

387394
return None
388395

396+
@retry_on_error(RedisError)
389397
@override
390398
async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
391399
# Check if the request is in progress.
@@ -424,6 +432,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest |
424432
was_already_handled=True,
425433
)
426434

435+
@retry_on_error(RedisError)
427436
@override
428437
async def reclaim_request(
429438
self,
@@ -469,6 +478,7 @@ async def reclaim_request(
469478
was_already_handled=False,
470479
)
471480

481+
@retry_on_error(RedisError)
472482
@override
473483
async def is_empty(self) -> bool:
474484
"""Check if the queue is empty.

src/crawlee/storage_clients/_sql/_client_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def get_session(self, *, with_simple_commit: bool = False) -> AsyncIterato
205205
except SQLAlchemyError as e:
206206
logger.warning(f'Error occurred during session transaction: {e}')
207207
await session.rollback()
208+
raise
208209
else:
209210
yield session
210211

0 commit comments

Comments
 (0)