Skip to content

Commit 7cd6ab4

Browse files
authored
Async API (#598)
* Refactor generating methods/props that are async * Better logic * Codegen + update _classes.py * fix issue, plus add tests * add comment * Add docs * apply codegenn to _api.py * Tweak for prop * Backwards compat * fix codegen test * Fix tests * Replace method usage, and disbale backwards compat * forgot one * fix * format * Logic to disable sync method for portability testing * format and enable backwards compat again * codegen
1 parent bd8dab7 commit 7cd6ab4

39 files changed

+567
-215
lines changed

codegen/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ In some cases we may want to deviate from the WebGPU API, because well ... Pytho
6969
Other changes include:
7070

7171
* Where in JS the input args are provided via a dict, we use kwargs directly. Nevertheless, some input args have subdicts (and sub-sub-dicts)
72-
* For methods that are async in IDL, we also provide sync methods. The Async method names have an "_async" suffix.
72+
* For methods that are async in JavaScript (i.e return a `Promise`), we provide both an asynchronous and synchronous variant, indicated by an `_async` and `_sync` suffix.
7373

7474
### Codegen summary
7575

codegen/apipatcher.py

+134-34
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def patch_properties(self, classname, i1, i2):
179179
elif "@apidiff.hide" in pre_lines:
180180
pass # continue as normal
181181
old_line = self.lines[j1]
182-
new_line = f" def {propname}(self):"
182+
new_line = self.get_property_def(classname, propname)
183183
if old_line != new_line:
184184
fixme_line = " # FIXME: was " + old_line.split("def ", 1)[-1]
185185
lines = [fixme_line, new_line]
@@ -241,7 +241,7 @@ def get_missing_properties(self, classname, seen_props):
241241
if propname not in seen_props:
242242
lines.append(" # FIXME: new prop to implement")
243243
lines.append(" @property")
244-
lines.append(f" def {propname}(self):")
244+
lines.append(self.get_property_def(classname, propname))
245245
lines.append(" raise NotImplementedError()")
246246
lines.append("")
247247
return lines
@@ -265,16 +265,105 @@ class IdlPatcherMixin:
265265
def __init__(self):
266266
super().__init__()
267267
self.idl = get_idl_parser()
268+
self.detect_async_props_and_methods()
269+
270+
def detect_async_props_and_methods(self):
271+
272+
self.async_idl_names = async_idl_names = {} # (sync-name, async-name)
273+
274+
for classname, interface in self.idl.classes.items():
275+
for namedict in [interface.attributes, interface.functions]:
276+
for name_idl, idl_line in namedict.items():
277+
idl_result = idl_line.split(name_idl)[0]
278+
if "Promise" in idl_result:
279+
# We found an async property or method.
280+
name_idl_base = name_idl
281+
if name_idl.endswith("Async"):
282+
name_idl_base = name_idl[:-5]
283+
key = classname, name_idl_base
284+
# Now we determine the kind
285+
if name_idl_base != name_idl and name_idl_base in namedict:
286+
# Has both
287+
async_idl_names[key] = name_idl_base, name_idl
288+
else:
289+
# Only has async
290+
async_idl_names[key] = None, name_idl
291+
292+
def get_idl_name_variants(self, classname, base_name):
293+
"""Returns the names of an idl prop/method for its sync and async variant.
294+
Either can be None.
295+
"""
296+
# Must be a base name, without the suffix
297+
assert not base_name.lower().endswith(("sync", "async"))
298+
299+
key = classname, base_name
300+
default = base_name, None
301+
return self.async_idl_names.get(key, default)
302+
303+
def name2idl(self, classname, name_py):
304+
"""Map a python propname/methodname to the idl variant.
305+
Take async into account.
306+
"""
307+
if name_py == "__init__":
308+
return "constructor"
309+
310+
# Get idl base name
311+
if name_py.endswith(("_sync", "_async")):
312+
name_idl_base = to_camel_case(name_py.rsplit("_", 1)[0])
313+
else:
314+
name_idl_base = to_camel_case(name_py)
268315

269-
def name2idl(self, name):
270-
m = {"__init__": "constructor"}
271-
name = m.get(name, name)
272-
return to_camel_case(name)
316+
# Get idl variant names
317+
idl_sync, idl_async = self.get_idl_name_variants(classname, name_idl_base)
273318

274-
def name2py(self, name):
275-
m = {"constructor": "__init__"}
276-
name = m.get(name, name)
277-
return to_snake_case(name)
319+
# Triage
320+
if idl_sync and idl_async:
321+
if name_py.endswith("_async"):
322+
return idl_async
323+
elif name_py.endswith("_sync"):
324+
return name_idl_base + "InvalidVariant"
325+
else:
326+
return idl_sync
327+
elif idl_async:
328+
if name_py.endswith("_async"):
329+
return idl_async
330+
elif name_py.endswith("_sync"):
331+
return idl_async
332+
else:
333+
return name_idl_base + "InvalidVariant"
334+
else: # idl_sync only
335+
if name_py.endswith("_async"):
336+
return name_idl_base + "InvalidVariant"
337+
elif name_py.endswith("_sync"):
338+
return name_idl_base + "InvalidVariant"
339+
else:
340+
return idl_sync
341+
342+
def name2py_names(self, classname, name_idl):
343+
"""Map a idl propname/methodname to the python variants.
344+
Take async into account. Returns a list with one or two names;
345+
for async props/methods Python has the sync and the async variant.
346+
"""
347+
348+
if name_idl == "constructor":
349+
return ["__init__"]
350+
351+
# Get idl base name
352+
name_idl_base = name_idl
353+
if name_idl.endswith("Async"):
354+
name_idl_base = name_idl[:-5]
355+
name_py_base = to_snake_case(name_idl_base)
356+
357+
# Get idl variant names
358+
idl_sync, idl_async = self.get_idl_name_variants(classname, name_idl_base)
359+
360+
if idl_sync and idl_async:
361+
return [to_snake_case(idl_sync), name_py_base + "_async"]
362+
elif idl_async:
363+
return [name_py_base + "_sync", name_py_base + "_async"]
364+
else:
365+
assert idl_sync == name_idl_base
366+
return [name_py_base]
278367

279368
def class_is_known(self, classname):
280369
return classname in self.idl.classes
@@ -295,22 +384,28 @@ def get_class_def(self, classname):
295384
bases = "" if not bases else f"({', '.join(bases)})"
296385
return f"class {classname}{bases}:"
297386

387+
def get_property_def(self, classname, propname):
388+
attributes = self.idl.classes[classname].attributes
389+
name_idl = self.name2idl(classname, propname)
390+
assert name_idl in attributes
391+
392+
line = "def " + to_snake_case(propname) + "(self):"
393+
if propname.endswith("_async"):
394+
line = "async " + line
395+
return " " + line
396+
298397
def get_method_def(self, classname, methodname):
299-
# Get the corresponding IDL line
300398
functions = self.idl.classes[classname].functions
301-
name_idl = self.name2idl(methodname)
302-
if methodname.endswith("_async") and name_idl not in functions:
303-
name_idl = self.name2idl(methodname.replace("_async", ""))
304-
elif name_idl not in functions and name_idl + "Async" in functions:
305-
name_idl += "Async"
306-
idl_line = functions[name_idl]
399+
name_idl = self.name2idl(classname, methodname)
400+
assert name_idl in functions
307401

308402
# Construct preamble
309403
preamble = "def " + to_snake_case(methodname) + "("
310-
if "async" in methodname:
404+
if methodname.endswith("_async"):
311405
preamble = "async " + preamble
312406

313407
# Get arg names and types
408+
idl_line = functions[name_idl]
314409
args = idl_line.split("(", 1)[1].split(")", 1)[0].split(",")
315410
args = [arg.strip() for arg in args if arg.strip()]
316411
raw_defaults = [arg.partition("=")[2].strip() for arg in args]
@@ -361,28 +456,31 @@ def _arg_from_struct_field(self, field):
361456
return result
362457

363458
def prop_is_known(self, classname, propname):
364-
propname_idl = self.name2idl(propname)
365-
return propname_idl in self.idl.classes[classname].attributes
459+
attributes = self.idl.classes[classname].attributes
460+
propname_idl = self.name2idl(classname, propname)
461+
return propname_idl if propname_idl in attributes else None
366462

367463
def method_is_known(self, classname, methodname):
368464
functions = self.idl.classes[classname].functions
369-
name_idl = self.name2idl(methodname)
370-
if "_async" in methodname and name_idl not in functions:
371-
name_idl = self.name2idl(methodname.replace("_async", ""))
372-
elif name_idl not in functions and name_idl + "Async" in functions:
373-
name_idl += "Async"
374-
return name_idl if name_idl in functions else None
465+
methodname_idl = self.name2idl(classname, methodname)
466+
return methodname_idl if methodname_idl in functions else None
375467

376468
def get_class_names(self):
377469
return list(self.idl.classes.keys())
378470

379471
def get_required_prop_names(self, classname):
380-
propnames_idl = self.idl.classes[classname].attributes.keys()
381-
return [self.name2py(x) for x in propnames_idl]
472+
attributes = self.idl.classes[classname].attributes
473+
names = []
474+
for name_idl in attributes.keys():
475+
names.extend(self.name2py_names(classname, name_idl))
476+
return names
382477

383478
def get_required_method_names(self, classname):
384-
methodnames_idl = self.idl.classes[classname].functions.keys()
385-
return [self.name2py(x) for x in methodnames_idl]
479+
functions = self.idl.classes[classname].functions
480+
names = []
481+
for name_idl in functions.keys():
482+
names.extend(self.name2py_names(classname, name_idl))
483+
return names
386484

387485

388486
class BaseApiPatcher(IdlPatcherMixin, AbstractApiPatcher):
@@ -398,14 +496,16 @@ def get_class_comment(self, classname):
398496
return None
399497

400498
def get_prop_comment(self, classname, propname):
401-
if self.prop_is_known(classname, propname):
402-
propname_idl = self.name2idl(propname)
403-
return " # IDL: " + self.idl.classes[classname].attributes[propname_idl]
499+
attributes = self.idl.classes[classname].attributes
500+
name_idl = self.prop_is_known(classname, propname)
501+
if name_idl:
502+
return " # IDL: " + attributes[name_idl]
404503

405504
def get_method_comment(self, classname, methodname):
505+
functions = self.idl.classes[classname].functions
406506
name_idl = self.method_is_known(classname, methodname)
407507
if name_idl:
408-
return " # IDL: " + self.idl.classes[classname].functions[name_idl]
508+
return " # IDL: " + functions[name_idl]
409509

410510

411511
class BackendApiPatcher(AbstractApiPatcher):

codegen/idlparser.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ class IdlParser:
6767
* enums: a dict mapping the (Pythonic) enum name to a dict of field-value pairs.
6868
* structs: a dict mapping the (Pythonic) struct name to a dict of StructField
6969
objects.
70-
* functions: a dict mapping the (normalized) func name to the line defining the
71-
function.
70+
* classes: a dict mapping the (normalized) class name an Interface object.
7271
7372
"""
7473

codegen/tests/test_codegen_apipatcher.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33

44
from codegen.utils import blacken
5-
from codegen.apipatcher import CommentRemover, AbstractCommentInjector
5+
from codegen.apipatcher import CommentRemover, AbstractCommentInjector, IdlPatcherMixin
66

77

88
def dedent(code):
@@ -110,6 +110,60 @@ def eggs(self):
110110
assert code2 == code3
111111

112112

113+
def test_async_api_logic():
114+
115+
class Object(object):
116+
pass
117+
118+
class OtherIdlPatcherMixin(IdlPatcherMixin):
119+
def __init__(self):
120+
cls = Object()
121+
cls.attributes = {
122+
"prop1": "x prop1 bla",
123+
"prop2": "Promise<x> prop2 bla",
124+
}
125+
cls.functions = {
126+
"method1": "x method1 bla",
127+
"method2": "Promise<x> method2 bla",
128+
"method3Async": "Promise<x> method3 bla",
129+
"method3": "x method3 bla",
130+
}
131+
132+
self.idl = Object()
133+
self.idl.classes = {"Foo": cls}
134+
135+
patcher = OtherIdlPatcherMixin()
136+
patcher.detect_async_props_and_methods()
137+
138+
# Normal prop
139+
assert patcher.name2idl("Foo", "prop1") == "prop1"
140+
assert patcher.name2idl("Foo", "prop1_sync") == "prop1InvalidVariant"
141+
assert patcher.name2idl("Foo", "prop1_async") == "prop1InvalidVariant"
142+
143+
# Unknow prop, name still works
144+
assert patcher.name2idl("Foo", "prop_unknown") == "propUnknown"
145+
146+
# Async prop
147+
assert patcher.name2idl("Foo", "prop2_async") == "prop2"
148+
assert patcher.name2idl("Foo", "prop2_sync") == "prop2"
149+
assert patcher.name2idl("Foo", "prop2") == "prop2InvalidVariant"
150+
151+
# Normal method
152+
assert patcher.name2idl("Foo", "method1") == "method1"
153+
assert patcher.name2idl("Foo", "method1_sync") == "method1InvalidVariant"
154+
assert patcher.name2idl("Foo", "method1_async") == "method1InvalidVariant"
155+
156+
# Async method
157+
assert patcher.name2idl("Foo", "method2_async") == "method2"
158+
assert patcher.name2idl("Foo", "method2_sync") == "method2"
159+
assert patcher.name2idl("Foo", "method2") == "method2InvalidVariant"
160+
161+
# Async method that also has sync variant in JS
162+
assert patcher.name2idl("Foo", "method3_async") == "method3Async"
163+
assert patcher.name2idl("Foo", "method3") == "method3"
164+
assert patcher.name2idl("Foo", "method3_sync") == "method3InvalidVariant"
165+
166+
113167
if __name__ == "__main__":
114168
for func in list(globals().values()):
115169
if callable(func) and func.__name__.startswith("test_"):

codegen/tests/test_codegen_result.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
""" Test some aspects of the generated code.
2+
"""
3+
4+
from codegen.files import read_file
5+
6+
7+
def test_async_methods_and_props():
8+
# Test that only and all aync methods are suffixed with '_async'
9+
10+
for fname in ["_classes.py", "backends/wgpu_native/_api.py"]:
11+
code = read_file(fname)
12+
for line in code.splitlines():
13+
line = line.strip()
14+
if line.startswith("def "):
15+
assert not line.endswith("_async"), line
16+
elif line.startswith("async def "):
17+
name = line.split("def", 1)[1].split("(")[0].strip()
18+
assert name.endswith("_async"), line

docs/backends.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ It also works out of the box, because the wgpu-native DLL is shipped with wgpu-p
4444

4545
The wgpu_native backend provides a few extra functionalities:
4646

47-
.. py:function:: wgpu.backends.wgpu_native.request_device(adapter, trace_path, *, label="", required_features, required_limits, default_queue)
47+
.. py:function:: wgpu.backends.wgpu_native.request_device_sync(adapter, trace_path, *, label="", required_features, required_limits, default_queue)
4848
4949
An alternative to :func:`wgpu.GPUAdapter.request_adapter`, that streams a trace
5050
of all low level calls to disk, so the visualization can be replayed (also on other systems),
@@ -88,7 +88,7 @@ You must tell the adapter to create a device that supports push constants,
8888
and you must tell it the number of bytes of push constants that you are using.
8989
Overestimating is okay::
9090

91-
device = adapter.request_device(
91+
device = adapter.request_device_sync(
9292
required_features=["push-constants"],
9393
required_limits={"max-push-constant-size": 256},
9494
)

docs/guide.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ you can obtain a device.
4343

4444
.. code-block:: py
4545
46-
adapter = wgpu.gpu.request_adapter(power_preference="high-performance")
47-
device = adapter.request_device()
46+
adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
47+
device = adapter.request_device_sync()
4848
4949
The ``wgpu.gpu`` object is the API entrypoint (:class:`wgpu.GPU`). It contains just a handful of functions,
5050
including ``request_adapter()``. The device is used to create most other GPU objects.
@@ -232,7 +232,7 @@ You can run your application via RenderDoc, which is able to capture a
232232
frame, including all API calls, objects and the complete pipeline state,
233233
and display all of that information within a nice UI.
234234

235-
You can use ``adapter.request_device()`` to provide a directory path
235+
You can use ``adapter.request_device_sync()`` to provide a directory path
236236
where a trace of all API calls will be written. This trace can then be used
237237
to re-play your use-case elsewhere (it's cross-platform).
238238

docs/start.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ You can verify whether the `"DiscreteGPU"` adapters are found:
9999
import wgpu
100100
import pprint
101101
102-
for a in wgpu.gpu.enumerate_adapters():
102+
for a in wgpu.gpu.enumerate_adapters_sync():
103103
pprint.pprint(a.info)
104104
105105
If you are using a remote frame buffer via `jupyter-rfb <https://github.com/vispy/jupyter_rfb>`_ we also recommend installing the following for optimal performance:

0 commit comments

Comments
 (0)