|
1 | | -from asyncio import ensure_future |
| 1 | +import asyncio |
2 | 2 | from inspect import isawaitable |
3 | 3 | from graphene_django.settings import graphene_settings |
4 | 4 | from graphql.execution.executors.asyncio import AsyncioExecutor |
|
10 | 10 |
|
11 | 11 |
|
12 | 12 | class ChannelsConnectionContext(BaseConnectionContext): |
13 | | - |
14 | 13 | async def send(self, data): |
15 | 14 | await self.ws.send_json(data) |
16 | 15 |
|
@@ -73,25 +72,38 @@ async def on_start(self, connection_context, op_id, params): |
73 | 72 | return |
74 | 73 |
|
75 | 74 | iterator = await execution_result.__aiter__() |
76 | | - ensure_future(self.run_op(connection_context, op_id, iterator)) |
| 75 | + task = asyncio.ensure_future(self.run_op(connection_context, op_id, iterator)) |
| 76 | + connection_context.register_operation(op_id, task) |
77 | 77 |
|
78 | 78 | async def run_op(self, connection_context, op_id, iterator): |
79 | | - connection_context.register_operation(op_id, iterator) |
80 | 79 | async for single_result in iterator: |
81 | 80 | if not connection_context.has_operation(op_id): |
82 | 81 | break |
83 | | - await self.send_execution_result( |
84 | | - connection_context, op_id, single_result |
85 | | - ) |
| 82 | + await self.send_execution_result(connection_context, op_id, single_result) |
86 | 83 | await self.send_message(connection_context, op_id, GQL_COMPLETE) |
87 | 84 |
|
88 | 85 | async def on_close(self, connection_context): |
89 | 86 | remove_operations = list(connection_context.operations.keys()) |
| 87 | + cancelled_tasks = [] |
90 | 88 | for op_id in remove_operations: |
91 | | - self.unsubscribe(connection_context, op_id) |
| 89 | + task = await self.unsubscribe(connection_context, op_id) |
| 90 | + if task: |
| 91 | + cancelled_tasks.append(task) |
| 92 | + # Wait around for all the tasks to actually cancel. |
| 93 | + await asyncio.gather(*cancelled_tasks, return_exceptions=True) |
92 | 94 |
|
93 | 95 | async def on_stop(self, connection_context, op_id): |
94 | | - self.unsubscribe(connection_context, op_id) |
| 96 | + task = await self.unsubscribe(connection_context, op_id) |
| 97 | + await asyncio.gather(task, return_exceptions=True) |
| 98 | + |
| 99 | + async def unsubscribe(self, connection_context, op_id): |
| 100 | + op = None |
| 101 | + if connection_context.has_operation(op_id): |
| 102 | + op = connection_context.get_operation(op_id) |
| 103 | + op.cancel() |
| 104 | + connection_context.remove_operation(op_id) |
| 105 | + self.on_operation_complete(connection_context, op_id) |
| 106 | + return op |
95 | 107 |
|
96 | 108 |
|
97 | 109 | subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) |
0 commit comments