3131by other components of the library.
3232"""
3333
34+ import json
3435import sys
3536from types import MappingProxyType , TracebackType
3637
5152)
5253from typesense .node_manager import NodeManager
5354from typesense .request_handler import RequestHandler
55+ from typesense .stream_handlers import (
56+ JSONDict ,
57+ StreamChunk ,
58+ combine_stream_chunks ,
59+ is_message_chunk ,
60+ parse_sse_line ,
61+ )
62+ from typesense .types .document import StreamConfig
5463
5564if sys .version_info >= (3 , 11 ):
5665 import typing
@@ -168,9 +177,9 @@ def __enter__(self) -> "ApiCall":
168177
169178 def __exit__ (
170179 self ,
171- exc_type : typing .Optional [ typing . Type [BaseException ]] ,
172- exc_val : typing . Optional [ BaseException ] ,
173- exc_tb : typing . Optional [ TracebackType ] ,
180+ exc_type : typing .Type [BaseException ] | None ,
181+ exc_val : BaseException | None ,
182+ exc_tb : TracebackType | None ,
174183 ) -> None :
175184 """Async context manager exit."""
176185 self ._client .close ()
@@ -186,6 +195,8 @@ def get(
186195 entity_type : typing .Type [TEntityDict ],
187196 as_json : typing .Literal [False ],
188197 params : typing .Union [TParams , None ] = None ,
198+ stream_config : StreamConfig [TEntityDict ] | None = None ,
199+ is_streaming_request : bool = False ,
189200 ) -> str :
190201 """
191202 Execute an async GET request to the Typesense API.
@@ -207,6 +218,8 @@ def get(
207218 entity_type : typing .Type [TEntityDict ],
208219 as_json : typing .Literal [True ] = True ,
209220 params : typing .Union [TParams , None ] = None ,
221+ stream_config : StreamConfig [TEntityDict ] | None = None ,
222+ is_streaming_request : bool = False ,
210223 ) -> TEntityDict :
211224 """
212225 Execute an async GET request to the Typesense API.
@@ -227,6 +240,8 @@ def get(
227240 entity_type : typing .Type [TEntityDict ],
228241 as_json : typing .Union [typing .Literal [True ], typing .Literal [False ]] = True ,
229242 params : typing .Union [TParams , None ] = None ,
243+ stream_config : StreamConfig [TEntityDict ] | None = None ,
244+ is_streaming_request : bool = False ,
230245 ) -> typing .Union [TEntityDict , str ]:
231246 """
232247 Execute an async GET request to the Typesense API.
@@ -246,6 +261,8 @@ def get(
246261 entity_type ,
247262 as_json ,
248263 params = params ,
264+ stream_config = stream_config ,
265+ is_streaming_request = is_streaming_request ,
249266 )
250267
251268 @typing .overload
@@ -414,6 +431,8 @@ def _execute_request(
414431 as_json : typing .Literal [True ],
415432 last_exception : typing .Union [None , Exception ] = None ,
416433 num_retries : int = 0 ,
434+ stream_config : StreamConfig [TEntityDict ] | None = None ,
435+ is_streaming_request : bool = False ,
417436 ** kwargs : typing .Unpack [SessionFunctionKwargs [TParams , TBody ]],
418437 ) -> TEntityDict :
419438 """Execute an async request with retry logic."""
@@ -427,6 +446,8 @@ def _execute_request(
427446 as_json : typing .Literal [False ],
428447 last_exception : typing .Union [None , Exception ] = None ,
429448 num_retries : int = 0 ,
449+ stream_config : StreamConfig [TEntityDict ] | None = None ,
450+ is_streaming_request : bool = False ,
430451 ** kwargs : typing .Unpack [SessionFunctionKwargs [TParams , TBody ]],
431452 ) -> str :
432453 """Execute an async request with retry logic."""
@@ -439,6 +460,8 @@ def _execute_request(
439460 as_json : typing .Union [typing .Literal [True ], typing .Literal [False ]] = True ,
440461 last_exception : typing .Union [None , Exception ] = None ,
441462 num_retries : int = 0 ,
463+ stream_config : StreamConfig [TEntityDict ] | None = None ,
464+ is_streaming_request : bool = False ,
442465 ** kwargs : typing .Unpack [SessionFunctionKwargs [TParams , TBody ]],
443466 ) -> typing .Union [TEntityDict , str ]:
444467 """
@@ -470,6 +493,10 @@ def _execute_request(
470493 node , url , request_kwargs = self ._prepare_request_params (endpoint , ** kwargs )
471494
472495 try :
496+ if is_streaming_request and method == "GET" :
497+ return self ._handle_streaming_get (
498+ url , entity_type , stream_config , ** request_kwargs
499+ )
473500 return self ._make_request_and_process_response (
474501 method ,
475502 url ,
@@ -479,13 +506,22 @@ def _execute_request(
479506 )
480507 except _SERVER_ERRORS as server_error :
481508 self .node_manager .set_node_health (node , is_healthy = False )
509+ if is_streaming_request and stream_config :
510+ on_error = stream_config .get ("on_error" )
511+ if on_error :
512+ try :
513+ on_error (server_error )
514+ except Exception :
515+ pass
482516 return self ._execute_request (
483517 method ,
484518 endpoint ,
485519 entity_type ,
486520 as_json ,
487521 last_exception = server_error ,
488522 num_retries = num_retries + 1 ,
523+ stream_config = stream_config ,
524+ is_streaming_request = is_streaming_request ,
489525 ** kwargs ,
490526 )
491527
@@ -516,6 +552,73 @@ def _make_request_and_process_response(
516552 else typing .cast (str , request_response )
517553 )
518554
555+ def _handle_streaming_get (
556+ self ,
557+ url : str ,
558+ entity_type : typing .Type [TEntityDict ],
559+ stream_config : StreamConfig [TEntityDict ] | None ,
560+ ** kwargs : typing .Unpack [SessionFunctionKwargs [TParams , TBody ]],
561+ ) -> TEntityDict :
562+ """Perform an async streaming GET, parse SSE lines, invoke callbacks, return combined result."""
563+ headers : typing .Dict [str , str ] = {
564+ self .request_handler .api_key_header_name : self .config .api_key ,
565+ "Accept" : "text/event-stream" ,
566+ }
567+ headers .update (self .config .additional_headers )
568+ extra_headers = kwargs .get ("headers" )
569+ if extra_headers :
570+ headers .update (extra_headers )
571+
572+ params = kwargs .get ("params" )
573+ content : typing .Union [str , bytes , None ] = None
574+ if body := kwargs .get ("data" ):
575+ if isinstance (body , (str , bytes )):
576+ content = body
577+ else :
578+ content = json .dumps (body )
579+
580+ all_chunks : typing .List [StreamChunk ] = []
581+ with self ._client .stream (
582+ "GET" ,
583+ url ,
584+ params = params ,
585+ content = content ,
586+ headers = headers ,
587+ timeout = self .config .connection_timeout_seconds ,
588+ ) as response :
589+ if response .status_code < 200 or response .status_code >= 300 :
590+ response .read ()
591+ error_message = self .request_handler ._get_error_message (response )
592+ raise self .request_handler ._get_exception (response .status_code )(
593+ response .status_code ,
594+ error_message ,
595+ )
596+ for line in response .iter_lines ():
597+ chunk = parse_sse_line (line )
598+ if chunk is not None :
599+ all_chunks .append (chunk )
600+ if stream_config and is_message_chunk (chunk ):
601+ on_chunk = stream_config .get ("on_chunk" )
602+ if on_chunk :
603+ try :
604+ on_chunk (chunk )
605+ except Exception :
606+ pass
607+
608+ self .node_manager .set_node_health (
609+ self .node_manager .get_node (),
610+ is_healthy = True ,
611+ )
612+ final : JSONDict = combine_stream_chunks (all_chunks )
613+ if stream_config :
614+ on_complete = stream_config .get ("on_complete" )
615+ if on_complete :
616+ try :
617+ on_complete (typing .cast (TEntityDict , final ))
618+ except Exception :
619+ pass
620+ return typing .cast (TEntityDict , final )
621+
519622 def _prepare_request_params (
520623 self ,
521624 endpoint : str ,
0 commit comments