@@ -88,6 +88,20 @@ def remember_task(self, task):
8888 task for task in self .pending_tasks if task .done ()
8989 )
9090
91+ async def unsubscribe (self , op_id ):
92+ super ().unsubscribe (op_id )
93+
94+ async def unsubscribe_all (self ):
95+ awaitables = [self .unsubscribe (op_id ) for op_id in list (self .operations )]
96+ for task in self .pending_tasks :
97+ task .cancel ()
98+ awaitables .append (task )
99+ if awaitables :
100+ try :
101+ await asyncio .gather (* awaitables )
102+ except asyncio .CancelledError :
103+ pass
104+
91105
92106class BaseAsyncSubscriptionServer (base .BaseSubscriptionServer , ABC ):
93107 graphql_executor = AsyncioExecutor
@@ -107,9 +121,6 @@ def process_message(self, connection_context, parsed_message):
107121 connection_context .remember_task (task )
108122 return task
109123
110- async def send_message (self , * args , ** kwargs ):
111- await super ().send_message (* args , ** kwargs )
112-
113124 async def on_open (self , connection_context ):
114125 pass
115126
@@ -125,11 +136,13 @@ async def on_connection_init(self, connection_context, op_id, payload):
125136 await connection_context .close (1011 )
126137
127138 async def on_start (self , connection_context , op_id , params ):
128- execution_result = self .execute (params )
139+ # Attempt to unsubscribe first in case we already have a subscription
140+ # with this id.
141+ await connection_context .unsubscribe (op_id )
129142
130- if is_awaitable (execution_result ):
131- execution_result = await execution_result
143+ execution_result = self .execute (params )
132144
145+ connection_context .register_operation (op_id , execution_result )
133146 if hasattr (execution_result , "__aiter__" ):
134147 iterator = await execution_result .__aiter__ ()
135148 connection_context .register_operation (op_id , iterator )
@@ -142,30 +155,25 @@ async def on_start(self, connection_context, op_id, params):
142155 )
143156 except Exception as e :
144157 await self .send_error (connection_context , op_id , e )
145- connection_context .remove_operation (op_id )
146158 else :
147159 try :
160+ if is_awaitable (execution_result ):
161+ execution_result = await execution_result
148162 await self .send_execution_result (
149163 connection_context , op_id , execution_result
150164 )
151165 except Exception as e :
152166 await self .send_error (connection_context , op_id , e )
153167 await self .send_message (connection_context , op_id , GQL_COMPLETE )
168+ connection_context .remove_operation (op_id )
154169 await self .on_operation_complete (connection_context , op_id )
155170
156- async def on_close (self , connection_context ):
157- awaitables = tuple (
158- self .unsubscribe (connection_context , op_id )
159- for op_id in connection_context .operations
160- ) + tuple (task .cancel () for task in connection_context .pending_tasks )
161- if awaitables :
162- try :
163- await asyncio .gather (* awaitables , loop = self .loop )
164- except asyncio .CancelledError :
165- pass
166-
167- async def on_stop (self , connection_context , op_id ):
168- await self .unsubscribe (connection_context , op_id )
171+ async def send_message (
172+ self , connection_context , op_id = None , op_type = None , payload = None
173+ ):
174+ if op_id is None or connection_context .has_operation (op_id ):
175+ message = self .build_message (op_id , op_type , payload )
176+ return await connection_context .send (message )
169177
170178 async def on_operation_complete (self , connection_context , op_id ):
171179 pass
0 commit comments