Skip to content

Commit 1cf303a

Browse files
rob-blackbournRob BlackbournRob Blackbourn
authored
Added support for subscription (#1107)
* Added support for subscription * Added pre-commit hooks for black and formatted changed files * Checked with flake8 * Integrated changes from master. Co-authored-by: Rob Blackbourn <[email protected]> Co-authored-by: Rob Blackbourn <[email protected]>
1 parent 88f79b2 commit 1cf303a

File tree

3 files changed

+113
-4
lines changed

3 files changed

+113
-4
lines changed

docs/execution/execute.rst

+37
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,43 @@ For executing a query against a schema, you can directly call the ``execute`` me
1717
``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred.
1818

1919

20+
For executing a subscription, you can directly call the ``subscribe`` method on it.
21+
This method is async and must be awaited.
22+
23+
.. code:: python
24+
25+
import asyncio
26+
from datetime import datetime
27+
from graphene import ObjectType, String, Schema, Field
28+
29+
# All schema require a query.
30+
class Query(ObjectType):
31+
hello = String()
32+
33+
def resolve_hello(root, info):
34+
return 'Hello, world!'
35+
36+
class Subscription(ObjectType):
37+
time_of_day = Field(String)
38+
39+
async def subscribe_time_of_day(root, info):
40+
while True:
41+
yield { 'time_of_day': datetime.now().isoformat()}
42+
await asyncio.sleep(1)
43+
44+
SCHEMA = Schema(query=Query, subscription=Subscription)
45+
46+
async def main(schema):
47+
48+
subscription = 'subscription { timeOfDay }'
49+
result = await schema.subscribe(subscription)
50+
async for item in result:
51+
print(item.data['timeOfDay'])
52+
53+
asyncio.run(main(SCHEMA))
54+
55+
The ``result`` is an async iterator which yields items in the same manner as a query.
56+
2057
.. _SchemaExecuteContext:
2158

2259
Context

graphene/types/schema.py

+43-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
graphql,
88
graphql_sync,
99
introspection_types,
10+
parse,
1011
print_schema,
12+
subscribe,
1113
GraphQLArgument,
1214
GraphQLBoolean,
1315
GraphQLEnumValue,
@@ -309,20 +311,52 @@ def create_fields_for_type(self, graphene_type, is_input_type=False):
309311
if isinstance(arg.type, NonNull)
310312
else arg.default_value,
311313
)
312-
resolve = field.get_resolver(
313-
self.get_resolver(graphene_type, name, field.default_value)
314-
)
315314
_field = GraphQLField(
316315
field_type,
317316
args=args,
318-
resolve=resolve,
317+
resolve=field.get_resolver(
318+
self.get_resolver_for_type(
319+
graphene_type, "resolve_{}", name, field.default_value
320+
)
321+
),
322+
subscribe=field.get_resolver(
323+
self.get_resolver_for_type(
324+
graphene_type, "subscribe_{}", name, field.default_value
325+
)
326+
),
319327
deprecation_reason=field.deprecation_reason,
320328
description=field.description,
321329
)
322330
field_name = field.name or self.get_name(name)
323331
fields[field_name] = _field
324332
return fields
325333

334+
def get_resolver_for_type(self, graphene_type, pattern, name, default_value):
335+
if not issubclass(graphene_type, ObjectType):
336+
return
337+
func_name = pattern.format(name)
338+
resolver = getattr(graphene_type, func_name, None)
339+
if not resolver:
340+
# If we don't find the resolver in the ObjectType class, then try to
341+
# find it in each of the interfaces
342+
interface_resolver = None
343+
for interface in graphene_type._meta.interfaces:
344+
if name not in interface._meta.fields:
345+
continue
346+
interface_resolver = getattr(interface, func_name, None)
347+
if interface_resolver:
348+
break
349+
resolver = interface_resolver
350+
351+
# Only if is not decorated with classmethod
352+
if resolver:
353+
return get_unbound_function(resolver)
354+
355+
default_resolver = (
356+
graphene_type._meta.default_resolver or get_default_resolver()
357+
)
358+
return partial(default_resolver, name, default_value)
359+
326360
def resolve_type(self, resolve_type_func, type_name, root, info, _type):
327361
type_ = resolve_type_func(root, info)
328362

@@ -468,6 +502,11 @@ async def execute_async(self, *args, **kwargs):
468502
kwargs = normalize_execute_kwargs(kwargs)
469503
return await graphql(self.graphql_schema, *args, **kwargs)
470504

505+
async def subscribe(self, query, *args, **kwargs):
506+
document = parse(query)
507+
kwargs = normalize_execute_kwargs(kwargs)
508+
return await subscribe(self.graphql_schema, document, *args, **kwargs)
509+
471510
def introspect(self):
472511
introspection = self.execute(introspection_query)
473512
if introspection.errors:

tests_asyncio/test_subscribe.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pytest import mark
2+
3+
from graphene import ObjectType, Int, String, Schema, Field
4+
5+
6+
class Query(ObjectType):
7+
hello = String()
8+
9+
def resolve_hello(root, info):
10+
return "Hello, world!"
11+
12+
13+
class Subscription(ObjectType):
14+
count_to_ten = Field(Int)
15+
16+
async def subscribe_count_to_ten(root, info):
17+
count = 0
18+
while count < 10:
19+
count += 1
20+
yield {"count_to_ten": count}
21+
22+
23+
schema = Schema(query=Query, subscription=Subscription)
24+
25+
26+
@mark.asyncio
27+
async def test_subscription():
28+
subscription = "subscription { countToTen }"
29+
result = await schema.subscribe(subscription)
30+
count = 0
31+
async for item in result:
32+
count = item.data["countToTen"]
33+
assert count == 10

0 commit comments

Comments
 (0)