@@ -22,6 +22,10 @@ def default(self, o):
2222
2323
2424class GraphQLSubscriptionConsumer (AsyncJsonWebsocketConsumer ):
25+ def __init__ (self , * args , ** kwargs ):
26+ super ().__init__ (* args , ** kwargs )
27+ self .futures = []
28+
2529 async def connect (self ):
2630 self .connection_context = None
2731 if WS_PROTOCOL in self .scope ["subprotocols" ]:
@@ -33,14 +37,22 @@ async def connect(self):
3337 await self .close ()
3438
3539 async def disconnect (self , code ):
40+ for future in self .futures :
41+ # Ensure any running message tasks are cancelled.
42+ future .cancel ()
3643 if self .connection_context :
3744 self .connection_context .socket_closed = True
38- await subscription_server .on_close (self .connection_context )
45+ close_future = subscription_server .on_close (self .connection_context )
46+ await asyncio .gather (close_future , * self .futures )
3947
4048 async def receive_json (self , content ):
41- asyncio .ensure_future (
42- subscription_server .on_message (self .connection_context , content )
49+ self .futures .append (
50+ asyncio .ensure_future (
51+ subscription_server .on_message (self .connection_context , content )
52+ )
4353 )
54+ # Clean up any completed futures.
55+ self .futures = [future for future in self .futures if not future .done ()]
4456
4557 @classmethod
4658 async def encode_json (cls , content ):
0 commit comments