|
1 | | -import asyncio |
2 | 1 | from inspect import isawaitable |
3 | 2 | from graphene_django.settings import graphene_settings |
4 | 3 | from graphql.execution.executors.asyncio import AsyncioExecutor |
@@ -64,49 +63,34 @@ async def on_start(self, connection_context, op_id, params): |
64 | 63 | if isawaitable(execution_result): |
65 | 64 | execution_result = await execution_result |
66 | 65 |
|
67 | | - if not hasattr(execution_result, "__aiter__"): |
| 66 | + if hasattr(execution_result, "__aiter__"): |
| 67 | + iterator = await execution_result.__aiter__() |
| 68 | + connection_context.register_operation(op_id, iterator) |
| 69 | + async for single_result in iterator: |
| 70 | + if not connection_context.has_operation(op_id): |
| 71 | + break |
| 72 | + await self.send_execution_result( |
| 73 | + connection_context, op_id, single_result |
| 74 | + ) |
| 75 | + else: |
68 | 76 | await self.send_execution_result( |
69 | 77 | connection_context, op_id, execution_result |
70 | 78 | ) |
71 | | - await self.on_operation_complete(connection_context, op_id) |
72 | | - return |
73 | | - |
74 | | - task = asyncio.ensure_future( |
75 | | - self.run_op(connection_context, op_id, execution_result) |
76 | | - ) |
77 | | - connection_context.register_operation(op_id, task) |
78 | | - |
79 | | - async def run_op(self, connection_context, op_id, aiterable): |
80 | | - async for single_result in aiterable: |
81 | | - if not connection_context.has_operation(op_id): |
82 | | - break |
83 | | - await self.send_execution_result(connection_context, op_id, single_result) |
84 | 79 | await self.on_operation_complete(connection_context, op_id) |
85 | 80 |
|
86 | 81 | async def on_close(self, connection_context): |
87 | | - # Unsubscribe from all the connection's current operations in parallel. |
88 | | - unsubscribes = [ |
| 82 | + for op_id in connection_context.operations: |
89 | 83 | self.unsubscribe(connection_context, op_id) |
90 | | - for op_id in connection_context.operations |
91 | | - ] |
92 | | - cancelled_tasks = [task for task in await asyncio.gather(*unsubscribes) if task] |
93 | | - # Wait around for all the tasks to actually cancel. |
94 | | - if cancelled_tasks: |
95 | | - await asyncio.wait(cancelled_tasks) |
96 | 84 |
|
97 | 85 | async def on_stop(self, connection_context, op_id): |
98 | | - task = await self.unsubscribe(connection_context, op_id) |
99 | | - if task: |
100 | | - await asyncio.wait([task]) |
| 86 | + await self.unsubscribe(connection_context, op_id) |
101 | 87 |
|
102 | 88 | async def unsubscribe(self, connection_context, op_id): |
103 | | - op = None |
104 | 89 | if connection_context.has_operation(op_id): |
105 | 90 | op = connection_context.get_operation(op_id) |
106 | | - op.cancel() |
| 91 | + op.dispose() |
107 | 92 | connection_context.remove_operation(op_id) |
108 | 93 | await self.on_operation_complete(connection_context, op_id) |
109 | | - return op |
110 | 94 |
|
111 | 95 | async def on_operation_complete(self, connection_context, op_id): |
112 | 96 | await self.send_message(connection_context, op_id, GQL_COMPLETE) |
|
0 commit comments