1- import asyncio
2- from inspect import isawaitable
31from graphene_django .settings import graphene_settings
4- from graphql .execution .executors .asyncio import AsyncioExecutor
5- from ..base import BaseConnectionContext , BaseSubscriptionServer
6- from ..constants import GQL_CONNECTION_ACK , GQL_CONNECTION_ERROR , GQL_COMPLETE
2+ from ..base_async import BaseAsyncConnectionContext , BaseAsyncSubscriptionServer
73from ..observable_aiter import setup_observable_extension
84
95setup_observable_extension ()
106
117
12- class ChannelsConnectionContext (BaseConnectionContext ):
8+ class ChannelsConnectionContext (BaseAsyncConnectionContext ):
139 def __init__ (self , * args , ** kwargs ):
1410 super (ChannelsConnectionContext , self ).__init__ (* args , ** kwargs )
1511 self .socket_closed = False
@@ -27,88 +23,11 @@ async def close(self, code):
2723 await self .ws .close (code = code )
2824
2925
30- class ChannelsSubscriptionServer (BaseSubscriptionServer ):
31- def get_graphql_params (self , connection_context , payload ):
32- payload ["context" ] = connection_context .request_context
33- params = super (ChannelsSubscriptionServer , self ).get_graphql_params (
34- connection_context , payload
35- )
36- return dict (params , return_promise = True , executor = AsyncioExecutor ())
37-
26+ class ChannelsSubscriptionServer (BaseAsyncSubscriptionServer ):
3827 async def handle (self , ws , request_context = None ):
3928 connection_context = ChannelsConnectionContext (ws , request_context )
4029 await self .on_open (connection_context )
4130 return connection_context
4231
43- async def send_message (
44- self , connection_context , op_id = None , op_type = None , payload = None
45- ):
46- message = {}
47- if op_id is not None :
48- message ["id" ] = op_id
49- if op_type is not None :
50- message ["type" ] = op_type
51- if payload is not None :
52- message ["payload" ] = payload
53-
54- assert message , "You need to send at least one thing"
55- return await connection_context .send (message )
56-
57- async def on_open (self , connection_context ):
58- pass
59-
60- async def on_connect (self , connection_context , payload ):
61- pass
62-
63- async def on_connection_init (self , connection_context , op_id , payload ):
64- try :
65- await self .on_connect (connection_context , payload )
66- await self .send_message (connection_context , op_type = GQL_CONNECTION_ACK )
67- except Exception as e :
68- await self .send_error (connection_context , op_id , e , GQL_CONNECTION_ERROR )
69- await connection_context .close (1011 )
70-
71- async def on_start (self , connection_context , op_id , params ):
72- execution_result = self .execute (connection_context .request_context , params )
73-
74- if isawaitable (execution_result ):
75- execution_result = await execution_result
76-
77- if hasattr (execution_result , "__aiter__" ):
78- iterator = await execution_result .__aiter__ ()
79- connection_context .register_operation (op_id , iterator )
80- async for single_result in iterator :
81- if not connection_context .has_operation (op_id ):
82- break
83- await self .send_execution_result (
84- connection_context , op_id , single_result
85- )
86- else :
87- await self .send_execution_result (
88- connection_context , op_id , execution_result
89- )
90- await self .on_operation_complete (connection_context , op_id )
91-
92- async def on_close (self , connection_context ):
93- unsubscribes = [
94- self .unsubscribe (connection_context , op_id )
95- for op_id in connection_context .operations
96- ]
97- if unsubscribes :
98- await asyncio .wait (unsubscribes )
99-
100- async def on_stop (self , connection_context , op_id ):
101- await self .unsubscribe (connection_context , op_id )
102-
103- async def unsubscribe (self , connection_context , op_id ):
104- if connection_context .has_operation (op_id ):
105- op = connection_context .get_operation (op_id )
106- op .dispose ()
107- connection_context .remove_operation (op_id )
108- await self .on_operation_complete (connection_context , op_id )
109-
110- async def on_operation_complete (self , connection_context , op_id ):
111- await self .send_message (connection_context , op_id , GQL_COMPLETE )
112-
11332
11433subscription_server = ChannelsSubscriptionServer (schema = graphene_settings .SCHEMA )
0 commit comments