Skip to content

[Draft] WASM support #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ The package has no library dependencies and provides pre-compiled wheels for all

```sh
pip install tree-sitter
# For wasm support
pip install tree-sitter[wasm]
```

## Usage
Expand All @@ -39,6 +41,22 @@ from tree_sitter import Language, Parser
PY_LANGUAGE = Language(tspython.language())
```

#### Wasm support

If you enable the `wasm` extra, then tree-sitter will be able to use wasmtime to load languages compiled to wasm and parse with them. Example:

```python
from pathlib import Path
from wasmtime import Engine
from tree_sitter import Language, Parser

engine = Engine()
wasm_bytes = Path("my_language.wasm").read_bytes()
MY_LANGUAGE = Language.from_wasm("my_language", engine, wasm_bytes)
```

Languages loaded this way work identically to native-binary languages.

### Basic parsing

Create a `Parser` and configure it to use a language:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ tests = [
"tree-sitter-python>=0.23.0",
"tree-sitter-rust>=0.23.0",
]
wasm = ["wasmtime>=25"]

[tool.ruff]
target-version = "py39"
Expand Down
28 changes: 16 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,33 @@
"tree_sitter/binding/range.c",
"tree_sitter/binding/tree.c",
"tree_sitter/binding/tree_cursor.c",
"tree_sitter/binding/wasmtime.c",
"tree_sitter/binding/module.c",
],
include_dirs=[
"tree_sitter/binding",
"tree_sitter/core/lib/include",
"tree_sitter/core/lib/src",
"tree_sitter/core/lib/src/wasm",
],
define_macros=[
("PY_SSIZE_T_CLEAN", None),
("TREE_SITTER_HIDE_SYMBOLS", None),
("TREE_SITTER_FEATURE_WASM", None),
],
undef_macros=[
"TREE_SITTER_FEATURE_WASM",
],
extra_compile_args=[
"-std=c11",
"-fvisibility=hidden",
"-Wno-cast-function-type",
"-Werror=implicit-function-declaration",
] if system() != "Windows" else [
"/std:c11",
"/wd4244",
],
extra_compile_args=(
[
"-std=c11",
"-fvisibility=hidden",
"-Wno-cast-function-type",
"-Werror=implicit-function-declaration",
]
if system() != "Windows"
else [
"/std:c11",
"/wd4244",
]
),
)
],
)
36 changes: 36 additions & 0 deletions tests/test_wasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import importlib.resources
from unittest import TestCase

from tree_sitter import Language, Parser, Tree

try:
import wasmtime

class TestWasm(TestCase):
@classmethod
def setUpClass(cls):
javascript_wasm = (
importlib.resources.files("tests")
.joinpath("wasm/tree-sitter-javascript.wasm")
.read_bytes()
)
engine = wasmtime.Engine()
cls.javascript = Language.from_wasm("javascript", engine, javascript_wasm)

def test_parser(self):
parser = Parser(self.javascript)
self.assertIsInstance(parser.parse(b"test"), Tree)

def test_language_is_wasm(self):
self.assertEqual(self.javascript.is_wasm, True)

except ImportError:

class TestWasmDisabled(TestCase):
def test_parser(self):
def runtest():
Language.from_wasm("javascript", None, b"")

self.assertRaisesRegex(
RuntimeError, "wasmtime module is not loaded", runtest
)
Binary file added tests/wasm/tree-sitter-javascript.wasm
Binary file not shown.
5 changes: 4 additions & 1 deletion tree_sitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
MIN_COMPATIBLE_LANGUAGE_VERSION,
)

Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns."

Point.__doc__ = (
"A position in a multi-line text document, in terms of rows and columns."
)
Point.row.__doc__ = "The zero-based row of the document."
Point.column.__doc__ = "The zero-based column of the document."

Expand Down
127 changes: 127 additions & 0 deletions tree_sitter/binding/language.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "types.h"

extern void wasm_engine_delete(TSWasmEngine *engine);
extern TSWasmEngine *wasmtime_engine_clone(TSWasmEngine *engine);

int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
PyObject *language;
if (!PyArg_ParseTuple(args, "O:__init__", &language)) {
Expand Down Expand Up @@ -30,10 +33,119 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
}

void language_dealloc(Language *self) {
if (self->wasm_engine != NULL) {
wasm_engine_delete(self->wasm_engine);
}
ts_language_delete(self->language);
Py_TYPE(self)->tp_free(self);
}

// ctypes.cast(managed_pointer.ptr(), ctypes.c_void_p).value
static void *get_managed_pointer(PyObject *cast, PyObject *c_void_p, PyObject *managed_pointer) {
void *ptr = NULL;
PyObject *ptr_method = NULL;
PyObject *ptr_result = NULL;
PyObject *cast_result = NULL;
PyObject *value_attr = NULL;

// Call .ptr() method on the managed pointer
ptr_method = PyObject_GetAttrString(managed_pointer, "ptr");
if (ptr_method == NULL) {
goto cleanup;
}
ptr_result = PyObject_CallObject(ptr_method, NULL);
if (ptr_result == NULL) {
goto cleanup;
}

// Call cast function
cast_result = PyObject_CallFunctionObjArgs(cast, ptr_result, c_void_p, NULL);
if (cast_result == NULL) {
goto cleanup;
}

// Get the 'value' attribute from the cast result
value_attr = PyObject_GetAttrString(cast_result, "value");
if (value_attr == NULL) {
goto cleanup;
}

// Convert the value attribute to a C void pointer
ptr = PyLong_AsVoidPtr(value_attr);

cleanup:
Py_XDECREF(value_attr);
Py_XDECREF(cast_result);
Py_XDECREF(ptr_result);
Py_XDECREF(ptr_method);

if (PyErr_Occurred()) {
return NULL;
}

return ptr;
}

PyObject *language_from_wasm(PyTypeObject *cls, PyObject *args) {
ModuleState *state = (ModuleState *)PyType_GetModuleState(cls);
TSWasmError error;
TSWasmStore *wasm_store = NULL;
TSLanguage *language = NULL;
Language *self = NULL;
char *name;
PyObject *py_engine = NULL;
char *wasm;
Py_ssize_t wasm_length;
if (state->wasmtime_engine_type == NULL) {
PyErr_SetString(PyExc_RuntimeError, "wasmtime module is not loaded");
return NULL;
}
if (!PyArg_ParseTuple(args, "sO!y#:from_wasm", &name, state->wasmtime_engine_type, &py_engine, &wasm, &wasm_length)) {
return NULL;
}

TSWasmEngine *engine = (TSWasmEngine *)get_managed_pointer(state->ctypes_cast, state->c_void_p, py_engine);
if (engine == NULL) {
goto fail;
}
engine = wasmtime_engine_clone(engine);
if (engine == NULL) {
goto fail;
}

wasm_store = ts_wasm_store_new(engine, &error);
if (wasm_store == NULL) {
PyErr_Format(PyExc_RuntimeError, "Failed to create TSWasmStore: %s", error.message);
goto fail;
}

language = (TSLanguage *)ts_wasm_store_load_language(wasm_store, name, wasm, wasm_length, &error);
if (language == NULL) {
PyErr_Format(PyExc_RuntimeError, "Failed to load language: %s", error.message);
goto fail;
}

self = (Language *)cls->tp_alloc(cls, 0);
if (self == NULL) {
goto fail;
}

self->language = language;
self->wasm_engine = engine;
self->version = ts_language_version(self->language);
#if HAS_LANGUAGE_NAMES
self->name = ts_language_name(self->language);
#endif
return (PyObject *)self;

fail:
if (engine != NULL) {
wasm_engine_delete(engine);
}
ts_language_delete(language);
return NULL;
}

PyObject *language_repr(Language *self) {
#if HAS_LANGUAGE_NAMES
if (self->name == NULL) {
Expand Down Expand Up @@ -82,6 +194,10 @@ PyObject *language_get_field_count(Language *self, void *Py_UNUSED(payload)) {
return PyLong_FromUnsignedLong(ts_language_field_count(self->language));
}

PyObject *language_is_wasm(Language *self, void *Py_UNUSED(payload)) {
return PyBool_FromLong(ts_language_is_wasm(self->language));
}

PyObject *language_node_kind_for_id(Language *self, PyObject *args) {
TSSymbol symbol;
if (!PyArg_ParseTuple(args, "H:node_kind_for_id", &symbol)) {
Expand Down Expand Up @@ -190,6 +306,9 @@ PyObject *language_query(Language *self, PyObject *args) {
return PyObject_CallFunction((PyObject *)state->query_type, "Os#", self, source, length);
}

PyDoc_STRVAR(language_from_wasm_doc,
"from_wasm(self, name, engine, wasm, /)\n--\n\n"
"Load a language compiled as wasm.");
PyDoc_STRVAR(language_node_kind_for_id_doc,
"node_kind_for_id(self, id, /)\n--\n\n"
"Get the name of the node kind for the given numerical id.");
Expand Down Expand Up @@ -220,6 +339,12 @@ PyDoc_STRVAR(
"Create a new :class:`Query` from a string containing one or more S-expression patterns.");

static PyMethodDef language_methods[] = {
{
.ml_name = "from_wasm",
.ml_meth = (PyCFunction)language_from_wasm,
.ml_flags = METH_CLASS | METH_VARARGS,
.ml_doc = language_from_wasm_doc,
},
{
.ml_name = "node_kind_for_id",
.ml_meth = (PyCFunction)language_node_kind_for_id,
Expand Down Expand Up @@ -291,6 +416,8 @@ static PyGetSetDef language_accessors[] = {
PyDoc_STR("The number of valid states in this language."), NULL},
{"field_count", (getter)language_get_field_count, NULL,
PyDoc_STR("The number of distinct field names in this language."), NULL},
{"is_wasm", (getter)language_is_wasm, NULL,
PyDoc_STR("Check if the language came from a wasm module."), NULL},
{NULL},
};

Expand Down
35 changes: 35 additions & 0 deletions tree_sitter/binding/module.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <wasm.h>
#include "types.h"

extern PyType_Spec language_type_spec;
Expand All @@ -15,6 +16,8 @@ extern PyType_Spec range_type_spec;
extern PyType_Spec tree_cursor_type_spec;
extern PyType_Spec tree_type_spec;

void tsp_load_wasmtime_symbols();

// TODO(0.24): drop Python 3.9 support
#if PY_MINOR_VERSION > 9
#define AddObjectRef PyModule_AddObjectRef
Expand Down Expand Up @@ -62,6 +65,9 @@ static void module_free(void *self) {
Py_XDECREF(state->tree_type);
Py_XDECREF(state->query_error);
Py_XDECREF(state->re_compile);
Py_XDECREF(state->wasmtime_engine_type);
Py_XDECREF(state->ctypes_cast);
Py_XDECREF(state->c_void_p);
}

static struct PyModuleDef module_definition = {
Expand Down Expand Up @@ -147,6 +153,35 @@ PyMODINIT_FUNC PyInit__binding(void) {
if (namedtuple == NULL) {
goto cleanup;
}

PyObject *wasmtime_engine = import_attribute("wasmtime", "Engine");
if (wasmtime_engine == NULL) {
// No worries, disable functionality.
PyErr_Clear();
} else {
// Ensure wasmtime_engine is a PyTypeObject
if (!PyType_Check(wasmtime_engine)) {
PyErr_SetString(PyExc_TypeError, "wasmtime.Engine is not a type");
goto cleanup;
}
state->wasmtime_engine_type = (PyTypeObject *)wasmtime_engine;

tsp_load_wasmtime_symbols();
if (PyErr_Occurred()) {
goto cleanup;
}

state->ctypes_cast = import_attribute("ctypes", "cast");
if (state->ctypes_cast == NULL) {
goto cleanup;
}

state->c_void_p = import_attribute("ctypes", "c_void_p");
if (state->c_void_p == NULL) {
goto cleanup;
}
}

PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column");
PyObject *point_kwargs = PyDict_New();
PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter"));
Expand Down
Loading