|
1 | | -from asgiref.sync import async_to_sync |
| 1 | +from inspect import isawaitable |
2 | 2 | from graphene_django.settings import graphene_settings |
3 | 3 | from graphql.execution.executors.asyncio import AsyncioExecutor |
4 | | -from rx import Observer, Observable |
| 4 | +from rx import Observer |
5 | 5 | from ..base import BaseConnectionContext, BaseSubscriptionServer |
6 | | -from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR |
| 6 | +from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE |
| 7 | +from ..observable_aiter import setup_observable_extension |
| 8 | + |
| 9 | +setup_observable_extension() |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class SubscriptionObserver(Observer): |
@@ -76,24 +79,25 @@ async def on_connection_init(self, connection_context, op_id, payload): |
76 | 79 | await connection_context.close(1011) |
77 | 80 |
|
78 | 81 | async def on_start(self, connection_context, op_id, params): |
79 | | - try: |
80 | | - execution_result = await self.execute( |
81 | | - connection_context.request_context, params |
| 82 | + execution_result = self.execute(connection_context.request_context, params) |
| 83 | + |
| 84 | + if isawaitable(execution_result): |
| 85 | + execution_result = await execution_result |
| 86 | + |
| 87 | + if not hasattr(execution_result, "__aiter__"): |
| 88 | + await self.send_execution_result( |
| 89 | + connection_context, op_id, execution_result |
82 | 90 | ) |
83 | | - assert isinstance( |
84 | | - execution_result, Observable |
85 | | - ), "A subscription must return an observable" |
86 | | - execution_result.subscribe( |
87 | | - SubscriptionObserver( |
88 | | - connection_context, |
89 | | - op_id, |
90 | | - async_to_sync(self.send_execution_result), |
91 | | - async_to_sync(self.send_error), |
92 | | - async_to_sync(self.on_close), |
| 91 | + else: |
| 92 | + iterator = await execution_result.__aiter__() |
| 93 | + connection_context.register_operation(op_id, iterator) |
| 94 | + async for single_result in iterator: |
| 95 | + if not connection_context.has_operation(op_id): |
| 96 | + break |
| 97 | + await self.send_execution_result( |
| 98 | + connection_context, op_id, single_result |
93 | 99 | ) |
94 | | - ) |
95 | | - except Exception as e: |
96 | | - self.send_error(connection_context, op_id, str(e)) |
| 100 | + await self.send_message(connection_context, op_id, GQL_COMPLETE) |
97 | 101 |
|
98 | 102 | async def on_close(self, connection_context): |
99 | 103 | remove_operations = list(connection_context.operations.keys()) |
|
0 commit comments