Skip to content

Commit 39648dc

Browse files
committed
WASM support
Add wasm support
1 parent e2519fd commit 39648dc

12 files changed

+416
-13
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ The package has no library dependencies and provides pre-compiled wheels for all
1515
1616
```sh
1717
pip install tree-sitter
18+
# For wasm support
19+
pip install tree-sitter[wasm]
1820
```
1921

2022
## Usage
@@ -39,6 +41,22 @@ from tree_sitter import Language, Parser
3941
PY_LANGUAGE = Language(tspython.language())
4042
```
4143

44+
#### Wasm support
45+
46+
If you install the `wasm` extra, tree-sitter can use wasmtime to load languages compiled to wasm and parse with them. For example:
47+
48+
```python
49+
from pathlib import Path
50+
from wasmtime import Engine
51+
from tree_sitter import Language, Parser
52+
53+
engine = Engine()
54+
wasm_bytes = Path("my_language.wasm").read_bytes()
55+
MY_LANGUAGE = Language.from_wasm("my_language", engine, wasm_bytes)
56+
```
57+
58+
Languages loaded this way work identically to native-binary languages.
59+
4260
### Basic parsing
4361

4462
Create a `Parser` and configure it to use a language:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ tests = [
3838
"tree-sitter-python",
3939
"tree-sitter-rust",
4040
]
41+
wasm = ["wasmtime>=23"]
4142

4243
[tool.ruff]
4344
target-version = "py39"

setup.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,33 @@
2222
"tree_sitter/binding/range.c",
2323
"tree_sitter/binding/tree.c",
2424
"tree_sitter/binding/tree_cursor.c",
25+
"tree_sitter/binding/wasmtime.c",
2526
"tree_sitter/binding/module.c",
2627
],
2728
include_dirs=[
2829
"tree_sitter/binding",
2930
"tree_sitter/core/lib/include",
3031
"tree_sitter/core/lib/src",
32+
"tree_sitter/core/lib/src/wasm",
3133
],
3234
define_macros=[
3335
("PY_SSIZE_T_CLEAN", None),
3436
("TREE_SITTER_HIDE_SYMBOLS", None),
37+
("TREE_SITTER_FEATURE_WASM", None),
3538
],
36-
undef_macros=[
37-
"TREE_SITTER_FEATURE_WASM",
38-
],
39-
extra_compile_args=[
40-
"-std=c11",
41-
"-fvisibility=hidden",
42-
"-Wno-cast-function-type",
43-
"-Werror=implicit-function-declaration",
44-
] if system() != "Windows" else [
45-
"/std:c11",
46-
"/wd4244",
47-
],
39+
extra_compile_args=(
40+
[
41+
"-std=c11",
42+
"-fvisibility=hidden",
43+
"-Wno-cast-function-type",
44+
"-Werror=implicit-function-declaration",
45+
]
46+
if system() != "Windows"
47+
else [
48+
"/std:c11",
49+
"/wd4244",
50+
]
51+
),
4852
)
4953
],
5054
)

tests/test_wasm.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import importlib.resources
from unittest import TestCase

from tree_sitter import Language, Parser, Tree

# wasmtime is an optional dependency: probe for it once, then define whichever
# test case matches the environment the suite is running in.
try:
    import wasmtime
except ImportError:
    wasmtime = None


if wasmtime is not None:

    class TestWasm(TestCase):
        """Runs when wasmtime is importable: load a wasm grammar and parse with it."""

        @classmethod
        def setUpClass(cls):
            # The compiled JavaScript grammar is shipped inside the tests package.
            grammar = importlib.resources.files("tests").joinpath(
                "wasm/tree-sitter-javascript.wasm"
            )
            wasm_bytes = grammar.read_bytes()
            cls.javascript = Language.from_wasm(
                "javascript", wasmtime.Engine(), wasm_bytes
            )

        def test_parser(self):
            tree = Parser(self.javascript).parse(b"test")
            self.assertIsInstance(tree, Tree)

        def test_language_is_wasm(self):
            self.assertEqual(self.javascript.is_wasm, True)

else:

    class TestWasmDisabled(TestCase):
        """Runs when wasmtime is missing: from_wasm must refuse with a clear error."""

        def test_parser(self):
            with self.assertRaisesRegex(
                RuntimeError, "wasmtime module is not loaded"
            ):
                Language.from_wasm("javascript", None, b"")
630 KB
Binary file not shown.

tree_sitter/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
MIN_COMPATIBLE_LANGUAGE_VERSION,
1515
)
1616

17-
Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns."
17+
18+
Point.__doc__ = (
19+
"A position in a multi-line text document, in terms of rows and columns."
20+
)
1821
Point.row.__doc__ = "The zero-based row of the document."
1922
Point.column.__doc__ = "The zero-based column of the document."
2023

tree_sitter/binding/language.c

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "types.h"
22

3+
extern void wasm_engine_delete(TSWasmEngine *engine);
4+
extern TSWasmEngine *wasmtime_engine_clone(TSWasmEngine *engine);
5+
36
int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
47
PyObject *language;
58
if (!PyArg_ParseTuple(args, "O:__init__", &language)) {
@@ -25,10 +28,119 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
2528
}
2629

2730
// Deallocator for Language objects: releases the wasm engine handle (if this
// language holds one), then the underlying TSLanguage, then the Python
// object's own storage.
void language_dealloc(Language *self) {
    TSWasmEngine *owned_engine = self->wasm_engine;
    if (owned_engine) {
        wasm_engine_delete(owned_engine);
    }

    ts_language_delete(self->language);

    PyTypeObject *type = Py_TYPE(self);
    type->tp_free(self);
}
3137

38+
// ctypes.cast(managed_pointer.ptr(), ctypes.c_void_p).value
39+
static void *get_managed_pointer(PyObject *cast, PyObject *c_void_p, PyObject *managed_pointer) {
40+
void *ptr = NULL;
41+
PyObject *ptr_method = NULL;
42+
PyObject *ptr_result = NULL;
43+
PyObject *cast_result = NULL;
44+
PyObject *value_attr = NULL;
45+
46+
// Call .ptr() method on the managed pointer
47+
ptr_method = PyObject_GetAttrString(managed_pointer, "ptr");
48+
if (ptr_method == NULL) {
49+
goto cleanup;
50+
}
51+
ptr_result = PyObject_CallObject(ptr_method, NULL);
52+
if (ptr_result == NULL) {
53+
goto cleanup;
54+
}
55+
56+
// Call cast function
57+
cast_result = PyObject_CallFunctionObjArgs(cast, ptr_result, c_void_p, NULL);
58+
if (cast_result == NULL) {
59+
goto cleanup;
60+
}
61+
62+
// Get the 'value' attribute from the cast result
63+
value_attr = PyObject_GetAttrString(cast_result, "value");
64+
if (value_attr == NULL) {
65+
goto cleanup;
66+
}
67+
68+
// Convert the value attribute to a C void pointer
69+
ptr = PyLong_AsVoidPtr(value_attr);
70+
71+
cleanup:
72+
Py_XDECREF(value_attr);
73+
Py_XDECREF(cast_result);
74+
Py_XDECREF(ptr_result);
75+
Py_XDECREF(ptr_method);
76+
77+
if (PyErr_Occurred()) {
78+
return NULL;
79+
}
80+
81+
return ptr;
82+
}
83+
84+
/*
 * Language.from_wasm(name, engine, wasm) -- class method.
 *
 * Builds a Language object from a grammar compiled to wasm.
 *
 * Arguments (parsed with the format "sO!y#"):
 *   name   - str, forwarded to ts_wasm_store_load_language as the language name.
 *   engine - instance of the wasmtime Engine type cached in the module state.
 *   wasm   - bytes-like object containing the compiled wasm module.
 *
 * Returns a new Language reference, or NULL with a Python exception set:
 *   RuntimeError - the module state has no wasmtime Engine type, the engine
 *                  handle could not be obtained or cloned, or tree-sitter
 *                  failed to create the store or load the language.
 *   ValueError   - the wasm buffer does not fit tree-sitter's 32-bit length.
 *   TypeError    - argument parsing failed.
 */
PyObject *language_from_wasm(PyTypeObject *cls, PyObject *args) {
    ModuleState *state = (ModuleState *)PyType_GetModuleState(cls);
    TSWasmError error;
    TSWasmEngine *engine = NULL;
    TSWasmStore *wasm_store = NULL;
    TSLanguage *language = NULL;
    Language *self = NULL;
    char *name;
    PyObject *py_engine = NULL;
    char *wasm;
    Py_ssize_t wasm_length;

    if (state->wasmtime_engine_type == NULL) {
        PyErr_SetString(PyExc_RuntimeError, "wasmtime module is not loaded");
        return NULL;
    }
    if (!PyArg_ParseTuple(args, "sO!y#:from_wasm", &name, state->wasmtime_engine_type, &py_engine,
                          &wasm, &wasm_length)) {
        return NULL;
    }

    // ts_wasm_store_load_language takes a 32-bit length; refuse larger buffers
    // instead of silently truncating them.
    if ((unsigned long long)wasm_length > UINT32_MAX) {
        PyErr_SetString(PyExc_ValueError, "wasm module is too large");
        return NULL;
    }

    // The handle extracted from the Python object is only borrowed: clone it so
    // this Language owns its own engine, which language_dealloc later deletes.
    TSWasmEngine *borrowed_engine =
        (TSWasmEngine *)get_managed_pointer(state->ctypes_cast, state->c_void_p, py_engine);
    if (borrowed_engine == NULL) {
        // A null address can come back without a pending exception; never
        // return NULL to the interpreter without one.
        if (!PyErr_Occurred()) {
            PyErr_SetString(PyExc_RuntimeError, "wasmtime engine pointer is NULL");
        }
        goto fail;
    }
    engine = wasmtime_engine_clone(borrowed_engine);
    if (engine == NULL) {
        PyErr_SetString(PyExc_RuntimeError, "Failed to clone wasmtime engine");
        goto fail;
    }

    wasm_store = ts_wasm_store_new(engine, &error);
    if (wasm_store == NULL) {
        PyErr_Format(PyExc_RuntimeError, "Failed to create TSWasmStore: %s", error.message);
        goto fail;
    }

    language = (TSLanguage *)ts_wasm_store_load_language(wasm_store, name, wasm,
                                                         (uint32_t)wasm_length, &error);
    if (language == NULL) {
        PyErr_Format(PyExc_RuntimeError, "Failed to load language: %s", error.message);
        goto fail;
    }

    self = (Language *)cls->tp_alloc(cls, 0);
    if (self == NULL) {
        goto fail;
    }

    // The store was only needed to load the language and is not kept anywhere,
    // so release it here. Per the tree-sitter API a wasm language is not tied
    // to the store that loaded it.
    ts_wasm_store_delete(wasm_store);

    self->language = language;
    self->wasm_engine = engine;
    self->version = ts_language_version(self->language);
#if HAS_LANGUAGE_NAMES
    self->name = ts_language_name(self->language);
#endif
    return (PyObject *)self;

fail:
    // `language` is still NULL on every path except a tp_alloc failure.
    ts_language_delete(language);
    if (wasm_store != NULL) {
        ts_wasm_store_delete(wasm_store);
    }
    if (engine != NULL) {
        wasm_engine_delete(engine);
    }
    return NULL;
}
143+
32144
PyObject *language_repr(Language *self) {
33145
#if HAS_LANGUAGE_NAMES
34146
if (self->name == NULL) {
@@ -77,6 +189,10 @@ PyObject *language_get_field_count(Language *self, void *Py_UNUSED(payload)) {
77189
return PyLong_FromUnsignedLong(ts_language_field_count(self->language));
78190
}
79191

192+
// Getter for the `is_wasm` attribute: True when ts_language_is_wasm reports
// that the wrapped language is wasm-based, False otherwise.
PyObject *language_is_wasm(Language *self, void *Py_UNUSED(payload)) {
    if (ts_language_is_wasm(self->language)) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}
195+
80196
PyObject *language_node_kind_for_id(Language *self, PyObject *args) {
81197
TSSymbol symbol;
82198
if (!PyArg_ParseTuple(args, "H:node_kind_for_id", &symbol)) {
@@ -185,6 +301,9 @@ PyObject *language_query(Language *self, PyObject *args) {
185301
return PyObject_CallFunction((PyObject *)state->query_type, "Os#", self, source, length);
186302
}
187303

304+
// Docstring attached to the `from_wasm` entry of language_methods (.ml_doc).
// The first line is the text signature, terminated by the "--" marker line;
// the text after the blank line is the user-visible description.
// NOTE(review): from_wasm is registered with METH_CLASS, yet the signature
// names its first parameter `self` -- confirm this is the intended convention.
PyDoc_STRVAR(language_from_wasm_doc,
305+
"from_wasm(self, name, engine, wasm, /)\n--\n\n"
306+
"Load a language compiled as wasm.");
188307
PyDoc_STRVAR(language_node_kind_for_id_doc,
189308
"node_kind_for_id(self, id, /)\n--\n\n"
190309
"Get the name of the node kind for the given numerical id.");
@@ -215,6 +334,12 @@ PyDoc_STRVAR(
215334
"Create a new :class:`Query` from a string containing one or more S-expression patterns.");
216335

217336
static PyMethodDef language_methods[] = {
337+
{
338+
.ml_name = "from_wasm",
339+
.ml_meth = (PyCFunction)language_from_wasm,
340+
.ml_flags = METH_CLASS | METH_VARARGS,
341+
.ml_doc = language_from_wasm_doc,
342+
},
218343
{
219344
.ml_name = "node_kind_for_id",
220345
.ml_meth = (PyCFunction)language_node_kind_for_id,
@@ -286,6 +411,8 @@ static PyGetSetDef language_accessors[] = {
286411
PyDoc_STR("The number of valid states in this language."), NULL},
287412
{"field_count", (getter)language_get_field_count, NULL,
288413
PyDoc_STR("The number of distinct field names in this language."), NULL},
414+
{"is_wasm", (getter)language_is_wasm, NULL,
415+
PyDoc_STR("Check if the language came from a wasm module."), NULL},
289416
{NULL},
290417
};
291418

tree_sitter/binding/module.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <wasm.h>
12
#include "types.h"
23

34
extern PyType_Spec capture_eq_capture_type_spec;
@@ -15,6 +16,8 @@ extern PyType_Spec range_type_spec;
1516
extern PyType_Spec tree_cursor_type_spec;
1617
extern PyType_Spec tree_type_spec;
1718

19+
void tsp_load_wasmtime_symbols();
20+
1821
// TODO(0.24): drop Python 3.9 support
1922
#if PY_MINOR_VERSION > 9
2023
#define AddObjectRef PyModule_AddObjectRef
@@ -54,12 +57,17 @@ static void module_free(void *self) {
5457
Py_XDECREF(state->query_type);
5558
Py_XDECREF(state->range_type);
5659
Py_XDECREF(state->query_capture_type);
60+
Py_XDECREF(state->query_match_type);
5761
Py_XDECREF(state->capture_eq_capture_type);
5862
Py_XDECREF(state->capture_eq_string_type);
5963
Py_XDECREF(state->capture_match_string_type);
6064
Py_XDECREF(state->lookahead_iterator_type);
65+
Py_XDECREF(state->lookahead_names_iterator_type);
6166
Py_XDECREF(state->re_compile);
6267
Py_XDECREF(state->namedtuple);
68+
Py_XDECREF(state->wasmtime_engine_type);
69+
Py_XDECREF(state->ctypes_cast);
70+
Py_XDECREF(state->c_void_p);
6371
}
6472

6573
static struct PyModuleDef module_definition = {
@@ -137,6 +145,34 @@ PyMODINIT_FUNC PyInit__binding(void) {
137145
goto cleanup;
138146
}
139147

148+
PyObject *wasmtime_engine = import_attribute("wasmtime", "Engine");
149+
if (wasmtime_engine == NULL) {
150+
// No worries, disable functionality.
151+
PyErr_Clear();
152+
} else {
153+
// Ensure wasmtime_engine is a PyTypeObject
154+
if (!PyType_Check(wasmtime_engine)) {
155+
PyErr_SetString(PyExc_TypeError, "wasmtime.Engine is not a type");
156+
goto cleanup;
157+
}
158+
state->wasmtime_engine_type = (PyTypeObject *)wasmtime_engine;
159+
160+
tsp_load_wasmtime_symbols();
161+
if (PyErr_Occurred()) {
162+
goto cleanup;
163+
}
164+
165+
state->ctypes_cast = import_attribute("ctypes", "cast");
166+
if (state->ctypes_cast == NULL) {
167+
goto cleanup;
168+
}
169+
170+
state->c_void_p = import_attribute("ctypes", "c_void_p");
171+
if (state->c_void_p == NULL) {
172+
goto cleanup;
173+
}
174+
}
175+
140176
PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column");
141177
PyObject *point_kwargs = PyDict_New();
142178
PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter"));

0 commit comments

Comments
 (0)