Skip to content

Commit d85df83

Browse files
committed
Use new abstracted base code for django channels 2
1 parent cb5126a commit d85df83

2 files changed

Lines changed: 6 additions & 98 deletions

File tree

graphql_ws/django/consumers.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ def default(self, o):
2222

2323

2424
class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):
25-
def __init__(self, *args, **kwargs):
26-
super().__init__(*args, **kwargs)
27-
self.futures = []
2825

2926
async def connect(self):
3027
self.connection_context = None
@@ -37,22 +34,14 @@ async def connect(self):
3734
await self.close()
3835

3936
async def disconnect(self, code):
40-
for future in self.futures:
41-
# Ensure any running message tasks are cancelled.
42-
future.cancel()
4337
if self.connection_context:
4438
self.connection_context.socket_closed = True
45-
close_future = subscription_server.on_close(self.connection_context)
46-
await asyncio.gather(close_future, *self.futures)
39+
await subscription_server.on_close(self.connection_context)
4740

4841
async def receive_json(self, content):
49-
self.futures.append(
50-
asyncio.ensure_future(
51-
subscription_server.on_message(self.connection_context, content)
52-
)
42+
self.connection_context.remember_task(
43+
subscription_server.on_message(self.connection_context, content)
5344
)
54-
# Clean up any completed futures.
55-
self.futures = [future for future in self.futures if not future.done()]
5645

5746
@classmethod
5847
async def encode_json(cls, content):

graphql_ws/django/subscriptions.py

Lines changed: 3 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1-
import asyncio
2-
from inspect import isawaitable
31
from graphene_django.settings import graphene_settings
4-
from graphql.execution.executors.asyncio import AsyncioExecutor
5-
from ..base import BaseConnectionContext, BaseSubscriptionServer
6-
from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE
2+
from ..base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer
73
from ..observable_aiter import setup_observable_extension
84

95
setup_observable_extension()
106

117

12-
class ChannelsConnectionContext(BaseConnectionContext):
8+
class ChannelsConnectionContext(BaseAsyncConnectionContext):
139
def __init__(self, *args, **kwargs):
1410
super(ChannelsConnectionContext, self).__init__(*args, **kwargs)
1511
self.socket_closed = False
@@ -27,88 +23,11 @@ async def close(self, code):
2723
await self.ws.close(code=code)
2824

2925

30-
class ChannelsSubscriptionServer(BaseSubscriptionServer):
31-
def get_graphql_params(self, connection_context, payload):
32-
payload["context"] = connection_context.request_context
33-
params = super(ChannelsSubscriptionServer, self).get_graphql_params(
34-
connection_context, payload
35-
)
36-
return dict(params, return_promise=True, executor=AsyncioExecutor())
37-
26+
class ChannelsSubscriptionServer(BaseAsyncSubscriptionServer):
3827
async def handle(self, ws, request_context=None):
3928
connection_context = ChannelsConnectionContext(ws, request_context)
4029
await self.on_open(connection_context)
4130
return connection_context
4231

43-
async def send_message(
44-
self, connection_context, op_id=None, op_type=None, payload=None
45-
):
46-
message = {}
47-
if op_id is not None:
48-
message["id"] = op_id
49-
if op_type is not None:
50-
message["type"] = op_type
51-
if payload is not None:
52-
message["payload"] = payload
53-
54-
assert message, "You need to send at least one thing"
55-
return await connection_context.send(message)
56-
57-
async def on_open(self, connection_context):
58-
pass
59-
60-
async def on_connect(self, connection_context, payload):
61-
pass
62-
63-
async def on_connection_init(self, connection_context, op_id, payload):
64-
try:
65-
await self.on_connect(connection_context, payload)
66-
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
67-
except Exception as e:
68-
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
69-
await connection_context.close(1011)
70-
71-
async def on_start(self, connection_context, op_id, params):
72-
execution_result = self.execute(connection_context.request_context, params)
73-
74-
if isawaitable(execution_result):
75-
execution_result = await execution_result
76-
77-
if hasattr(execution_result, "__aiter__"):
78-
iterator = await execution_result.__aiter__()
79-
connection_context.register_operation(op_id, iterator)
80-
async for single_result in iterator:
81-
if not connection_context.has_operation(op_id):
82-
break
83-
await self.send_execution_result(
84-
connection_context, op_id, single_result
85-
)
86-
else:
87-
await self.send_execution_result(
88-
connection_context, op_id, execution_result
89-
)
90-
await self.on_operation_complete(connection_context, op_id)
91-
92-
async def on_close(self, connection_context):
93-
unsubscribes = [
94-
self.unsubscribe(connection_context, op_id)
95-
for op_id in connection_context.operations
96-
]
97-
if unsubscribes:
98-
await asyncio.wait(unsubscribes)
99-
100-
async def on_stop(self, connection_context, op_id):
101-
await self.unsubscribe(connection_context, op_id)
102-
103-
async def unsubscribe(self, connection_context, op_id):
104-
if connection_context.has_operation(op_id):
105-
op = connection_context.get_operation(op_id)
106-
op.dispose()
107-
connection_context.remove_operation(op_id)
108-
await self.on_operation_complete(connection_context, op_id)
109-
110-
async def on_operation_complete(self, connection_context, op_id):
111-
await self.send_message(connection_context, op_id, GQL_COMPLETE)
112-
11332

11433
subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA)

0 commit comments

Comments
 (0)