Skip to content

Commit e7b0b83

Browse files
committed
feat(streaming): generate sync client for streaming responses
1 parent c6eb0c4 commit e7b0b83

File tree

2 files changed

+117
-4
lines changed

2 files changed

+117
-4
lines changed

src/typesense/sync/api_call.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
by other components of the library.
3232
"""
3333

34+
import json
3435
import sys
3536
from types import MappingProxyType, TracebackType
3637

@@ -51,6 +52,14 @@
5152
)
5253
from typesense.node_manager import NodeManager
5354
from typesense.request_handler import RequestHandler
55+
from typesense.stream_handlers import (
56+
JSONDict,
57+
StreamChunk,
58+
combine_stream_chunks,
59+
is_message_chunk,
60+
parse_sse_line,
61+
)
62+
from typesense.types.document import StreamConfig
5463

5564
if sys.version_info >= (3, 11):
5665
import typing
@@ -168,9 +177,9 @@ def __enter__(self) -> "ApiCall":
168177

169178
def __exit__(
170179
self,
171-
exc_type: typing.Optional[typing.Type[BaseException]],
172-
exc_val: typing.Optional[BaseException],
173-
exc_tb: typing.Optional[TracebackType],
180+
exc_type: typing.Type[BaseException] | None,
181+
exc_val: BaseException | None,
182+
exc_tb: TracebackType | None,
174183
) -> None:
175184
"""Async context manager exit."""
176185
self._client.close()
@@ -186,6 +195,8 @@ def get(
186195
entity_type: typing.Type[TEntityDict],
187196
as_json: typing.Literal[False],
188197
params: typing.Union[TParams, None] = None,
198+
stream_config: StreamConfig[TEntityDict] | None = None,
199+
is_streaming_request: bool = False,
189200
) -> str:
190201
"""
191202
Execute an async GET request to the Typesense API.
@@ -207,6 +218,8 @@ def get(
207218
entity_type: typing.Type[TEntityDict],
208219
as_json: typing.Literal[True] = True,
209220
params: typing.Union[TParams, None] = None,
221+
stream_config: StreamConfig[TEntityDict] | None = None,
222+
is_streaming_request: bool = False,
210223
) -> TEntityDict:
211224
"""
212225
Execute an async GET request to the Typesense API.
@@ -227,6 +240,8 @@ def get(
227240
entity_type: typing.Type[TEntityDict],
228241
as_json: typing.Union[typing.Literal[True], typing.Literal[False]] = True,
229242
params: typing.Union[TParams, None] = None,
243+
stream_config: StreamConfig[TEntityDict] | None = None,
244+
is_streaming_request: bool = False,
230245
) -> typing.Union[TEntityDict, str]:
231246
"""
232247
Execute an async GET request to the Typesense API.
@@ -246,6 +261,8 @@ def get(
246261
entity_type,
247262
as_json,
248263
params=params,
264+
stream_config=stream_config,
265+
is_streaming_request=is_streaming_request,
249266
)
250267

251268
@typing.overload
@@ -414,6 +431,8 @@ def _execute_request(
414431
as_json: typing.Literal[True],
415432
last_exception: typing.Union[None, Exception] = None,
416433
num_retries: int = 0,
434+
stream_config: StreamConfig[TEntityDict] | None = None,
435+
is_streaming_request: bool = False,
417436
**kwargs: typing.Unpack[SessionFunctionKwargs[TParams, TBody]],
418437
) -> TEntityDict:
419438
"""Execute an async request with retry logic."""
@@ -427,6 +446,8 @@ def _execute_request(
427446
as_json: typing.Literal[False],
428447
last_exception: typing.Union[None, Exception] = None,
429448
num_retries: int = 0,
449+
stream_config: StreamConfig[TEntityDict] | None = None,
450+
is_streaming_request: bool = False,
430451
**kwargs: typing.Unpack[SessionFunctionKwargs[TParams, TBody]],
431452
) -> str:
432453
"""Execute an async request with retry logic."""
@@ -439,6 +460,8 @@ def _execute_request(
439460
as_json: typing.Union[typing.Literal[True], typing.Literal[False]] = True,
440461
last_exception: typing.Union[None, Exception] = None,
441462
num_retries: int = 0,
463+
stream_config: StreamConfig[TEntityDict] | None = None,
464+
is_streaming_request: bool = False,
442465
**kwargs: typing.Unpack[SessionFunctionKwargs[TParams, TBody]],
443466
) -> typing.Union[TEntityDict, str]:
444467
"""
@@ -470,6 +493,10 @@ def _execute_request(
470493
node, url, request_kwargs = self._prepare_request_params(endpoint, **kwargs)
471494

472495
try:
496+
if is_streaming_request and method == "GET":
497+
return self._handle_streaming_get(
498+
url, entity_type, stream_config, **request_kwargs
499+
)
473500
return self._make_request_and_process_response(
474501
method,
475502
url,
@@ -479,13 +506,22 @@ def _execute_request(
479506
)
480507
except _SERVER_ERRORS as server_error:
481508
self.node_manager.set_node_health(node, is_healthy=False)
509+
if is_streaming_request and stream_config:
510+
on_error = stream_config.get("on_error")
511+
if on_error:
512+
try:
513+
on_error(server_error)
514+
except Exception:
515+
pass
482516
return self._execute_request(
483517
method,
484518
endpoint,
485519
entity_type,
486520
as_json,
487521
last_exception=server_error,
488522
num_retries=num_retries + 1,
523+
stream_config=stream_config,
524+
is_streaming_request=is_streaming_request,
489525
**kwargs,
490526
)
491527

@@ -516,6 +552,73 @@ def _make_request_and_process_response(
516552
else typing.cast(str, request_response)
517553
)
518554

555+
def _handle_streaming_get(
556+
self,
557+
url: str,
558+
entity_type: typing.Type[TEntityDict],
559+
stream_config: StreamConfig[TEntityDict] | None,
560+
**kwargs: typing.Unpack[SessionFunctionKwargs[TParams, TBody]],
561+
) -> TEntityDict:
562+
"""Perform an async streaming GET, parse SSE lines, invoke callbacks, return combined result."""
563+
headers: typing.Dict[str, str] = {
564+
self.request_handler.api_key_header_name: self.config.api_key,
565+
"Accept": "text/event-stream",
566+
}
567+
headers.update(self.config.additional_headers)
568+
extra_headers = kwargs.get("headers")
569+
if extra_headers:
570+
headers.update(extra_headers)
571+
572+
params = kwargs.get("params")
573+
content: typing.Union[str, bytes, None] = None
574+
if body := kwargs.get("data"):
575+
if isinstance(body, (str, bytes)):
576+
content = body
577+
else:
578+
content = json.dumps(body)
579+
580+
all_chunks: typing.List[StreamChunk] = []
581+
with self._client.stream(
582+
"GET",
583+
url,
584+
params=params,
585+
content=content,
586+
headers=headers,
587+
timeout=self.config.connection_timeout_seconds,
588+
) as response:
589+
if response.status_code < 200 or response.status_code >= 300:
590+
response.read()
591+
error_message = self.request_handler._get_error_message(response)
592+
raise self.request_handler._get_exception(response.status_code)(
593+
response.status_code,
594+
error_message,
595+
)
596+
for line in response.iter_lines():
597+
chunk = parse_sse_line(line)
598+
if chunk is not None:
599+
all_chunks.append(chunk)
600+
if stream_config and is_message_chunk(chunk):
601+
on_chunk = stream_config.get("on_chunk")
602+
if on_chunk:
603+
try:
604+
on_chunk(chunk)
605+
except Exception:
606+
pass
607+
608+
self.node_manager.set_node_health(
609+
self.node_manager.get_node(),
610+
is_healthy=True,
611+
)
612+
final: JSONDict = combine_stream_chunks(all_chunks)
613+
if stream_config:
614+
on_complete = stream_config.get("on_complete")
615+
if on_complete:
616+
try:
617+
on_complete(typing.cast(TEntityDict, final))
618+
except Exception:
619+
pass
620+
return typing.cast(TEntityDict, final)
621+
519622
def _prepare_request_params(
520623
self,
521624
endpoint: str,

src/typesense/sync/documents.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ImportResponseWithId,
4444
SearchParameters,
4545
SearchResponse,
46+
StreamConfigBuilder,
4647
UpdateByFilterParameters,
4748
UpdateByFilterResponse,
4849
)
@@ -333,16 +334,25 @@ def search(self, search_parameters: SearchParameters) -> SearchResponse[TDoc]:
333334
334335
Args:
335336
search_parameters (SearchParameters): The search parameters.
337+
Use conversation_stream=True and optionally stream_config (on_chunk,
338+
on_complete, on_error) for conversational search streaming.
336339
337340
Returns:
338341
SearchResponse[TDoc]: The search response containing matching documents.
339342
"""
340-
stringified_search_params = stringify_search_params(search_parameters)
343+
params_for_api = dict(search_parameters)
344+
stream_config = params_for_api.pop("stream_config", None)
345+
if isinstance(stream_config, StreamConfigBuilder):
346+
stream_config = stream_config.build()
347+
conversation_stream = params_for_api.get("conversation_stream") is True
348+
stringified_search_params = stringify_search_params(params_for_api)
341349
response: SearchResponse[TDoc] = self.api_call.get(
342350
self._endpoint_path("search"),
343351
params=stringified_search_params,
344352
entity_type=SearchResponse,
345353
as_json=True,
354+
stream_config=stream_config,
355+
is_streaming_request=conversation_stream,
346356
)
347357
return response
348358

0 commit comments

Comments
 (0)