From 56fb60ac3af09d8f880a4574e39b7f7a238efeb2 Mon Sep 17 00:00:00 2001 From: Nathan Peterson <2295306+MrPrezident@users.noreply.github.com> Date: Fri, 18 Apr 2025 09:37:43 -0400 Subject: [PATCH 1/2] feat: keep parser callback for queries --- tests/test_query.py | 75 ++++++++++++++++++++ tree_sitter/binding/node.c | 131 +++++++++++++++++++++++++++-------- tree_sitter/binding/parser.c | 2 - 3 files changed, 178 insertions(+), 30 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index eac140ef..fa988e4f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -219,6 +219,81 @@ def test_text_predicates(self): self.assertEqual(captures2[0][0], "function-name") self.assertEqual(captures2[0][1][0].text, b"fun2") + def test_text_predicates_with_callback(self): + parser = Parser(self.javascript) + source = b""" + keypair_object = { + key1: value1, + equal: equal + } + + function fun1(arg) { + return 1; + } + + function fun2(arg) { + return 2; + } + """ + + def read_callable_byte_offset(byte_offset, point): + return source[byte_offset: byte_offset + 1] + + def read_callable_point(byte_offset, point): + row, col = point + lines = source.split(b"\n") + if row >= len(lines): + return b"" + line = lines[row] + if col >= len(line): + return b"\n" + return line[col:col + 1] + + tree1 = parser.parse(read_callable_byte_offset) + root_node1 = tree1.root_node + tree2 = parser.parse(read_callable_point) + root_node2 = tree2.root_node + + # function with name equal to 'fun1' -> test for #eq? @capture string + query1 = Query( + self.javascript, + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? @function-name fun1)) + """ + ) + cursor = QueryCursor(query1) + captures1 = list(cursor.captures(root_node1).items()) + self.assertEqual(1, len(captures1)) + self.assertEqual(captures1[0][0], "function-name") + self.assertEqual(captures1[0][1][0].text, b"fun1") + + captures2 = list(cursor.captures(root_node2).items()) + self.assertEqual(1, len(captures2)) + self.assertEqual(captures2[0][0], "function-name") + self.assertEqual(captures2[0][1][0].text, b"fun1") + + # functions with name not equal to 'fun1' -> test for #not-eq? @capture string + query2 = Query( + self.javascript, + """ + ((function_declaration + name: (identifier) @function-name) + (#not-eq? @function-name fun1)) + """ + ) + cursor = QueryCursor(query2) + captures3 = list(cursor.captures(root_node1).items()) + self.assertEqual(1, len(captures3)) + self.assertEqual(captures3[0][0], "function-name") + self.assertEqual(captures3[0][1][0].text, b"fun2") + + captures4 = list(cursor.captures(root_node2).items()) + self.assertEqual(1, len(captures4)) + self.assertEqual(captures4[0][0], "function-name") + self.assertEqual(captures4[0][1][0].text, b"fun2") + def test_text_predicates_errors(self): with self.assertRaises(QueryError): Query( diff --git a/tree_sitter/binding/node.c b/tree_sitter/binding/node.c index 1d55ed8a..86dfe0ed 100644 --- a/tree_sitter/binding/node.c +++ b/tree_sitter/binding/node.c @@ -557,36 +557,111 @@ PyObject *node_get_text(Node *self, void *Py_UNUSED(payload)) { Py_RETURN_NONE; } - PyObject *start_byte = PyLong_FromUnsignedLong(ts_node_start_byte(self->node)); - if (start_byte == NULL) { - PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte"); - return NULL; - } - PyObject *end_byte = PyLong_FromUnsignedLong(ts_node_end_byte(self->node)); - if (end_byte == NULL) { - Py_DECREF(start_byte); - PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte"); - return NULL; - } - PyObject *slice = PySlice_New(start_byte, end_byte, NULL); - Py_DECREF(start_byte); - Py_DECREF(end_byte); - if (slice == NULL) { - return NULL; - } - PyObject *node_mv = PyMemoryView_FromObject(tree->source); - if (node_mv == NULL) { + PyObject *result = NULL; + size_t start_offset = (size_t)ts_node_start_byte(self->node); + size_t end_offset = (size_t)ts_node_end_byte(self->node); + + // Case 1: source is a byte buffer + if (!PyCallable_Check(tree->source)) { + PyObject *start_byte = PyLong_FromSize_t(start_offset), + *end_byte = PyLong_FromSize_t(end_offset); + PyObject *slice = PySlice_New(start_byte, end_byte, NULL); + Py_XDECREF(start_byte); + Py_XDECREF(end_byte); + if (slice == NULL) { + return NULL; + } + + PyObject *node_mv = PyMemoryView_FromObject(tree->source); + if (node_mv == NULL) { + Py_DECREF(slice); + return NULL; + } + + PyObject *node_slice = PyObject_GetItem(node_mv, slice); Py_DECREF(slice); - return NULL; - } - PyObject *node_slice = PyObject_GetItem(node_mv, slice); - Py_DECREF(slice); - Py_DECREF(node_mv); - if (node_slice == NULL) { - return NULL; + Py_DECREF(node_mv); + if (node_slice == NULL) { + return NULL; + } + + result = PyBytes_FromObject(node_slice); + Py_DECREF(node_slice); + } else { + // Case 2: source is a callable + PyObject *collected_bytes = PyByteArray_FromStringAndSize(NULL, 0); + if (collected_bytes == NULL) { + return NULL; + } + TSPoint start_point = ts_node_start_point(self->node); + TSPoint current_point = start_point; + + for (size_t current_offset = start_offset; current_offset < end_offset;) { + PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset); + PyObject *row = PyLong_FromSize_t((size_t)current_point.row); + PyObject *column = PyLong_FromSize_t((size_t)current_point.column); + PyObject *point_obj = PyTuple_Pack(2, row, column); + Py_XDECREF(row); + Py_XDECREF(column); + if (!point_obj) { + Py_XDECREF(collected_bytes); + return NULL; + } + + PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj); + Py_XDECREF(byte_offset_obj); + Py_XDECREF(point_obj); + PyObject *rv = PyObject_Call(tree->source, args, NULL); + Py_XDECREF(args); + if (rv == NULL || rv == Py_None || !PyBytes_Check(rv)) { + Py_XDECREF(rv); + Py_XDECREF(collected_bytes); + return NULL; + } + + PyObject *rv_bytearray = PyByteArray_FromObject(rv); + if (rv_bytearray == NULL) { + Py_XDECREF(collected_bytes); + Py_XDECREF(rv); + return NULL; + } + + PyObject *new_collected_bytes = PyByteArray_Concat(collected_bytes, rv_bytearray); + Py_DECREF(rv_bytearray); + Py_DECREF(collected_bytes); + if (new_collected_bytes == NULL) { + Py_XDECREF(rv); + return NULL; + } + collected_bytes = new_collected_bytes; + + // Update current_point and current_offset + Py_ssize_t bytes_read = PyBytes_Size(rv); + const char *rv_str = PyBytes_AsString(rv); // Retrieve the string pointer once + for (Py_ssize_t i = 0; i < bytes_read; ++i) { + if (rv_str[i] == '\n') { + ++current_point.row; + current_point.column = 0; + } else { + ++current_point.column; + } + } + current_offset += bytes_read; + } + + PyObject *start_byte = PyLong_FromSize_t(0); + PyObject *end_byte = PyLong_FromSize_t(end_offset - start_offset); + PyObject *slice = PySlice_New(start_byte, end_byte, NULL); + Py_XDECREF(start_byte); + Py_XDECREF(end_byte); + if (slice == NULL) { + Py_DECREF(collected_bytes); + return NULL; + } + result = PyObject_GetItem(collected_bytes, slice); + Py_DECREF(slice); + Py_XDECREF(collected_bytes); } - PyObject *result = PyBytes_FromObject(node_slice); - Py_DECREF(node_slice); return result; } diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c index 09384144..cd76e5df 100644 --- a/tree_sitter/binding/parser.c +++ b/tree_sitter/binding/parser.c @@ -180,8 +180,6 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { } Py_XDECREF(payload.previous_retval); - source_or_callback = Py_None; - keep_text = false; } else { PyErr_Format(PyExc_TypeError, "source must be a bytestring or a callable, not %s", source_or_callback->ob_type->tp_name); From 296baa24365200eb9b55fee62d44c2d721a0b580 Mon Sep 17 00:00:00 2001 From: Nathan Peterson <2295306+MrPrezident@users.noreply.github.com> Date: Sun, 4 May 2025 18:35:14 -0500 Subject: [PATCH 2/2] fix args to read_callable in node_get_text --- tests/test_query.py | 42 ++++++++++++++++++++++++-------------- tree_sitter/binding/node.c | 35 ++++++++++++++++--------------- 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index fa988e4f..807e2191 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -234,6 +234,10 @@ def test_text_predicates_with_callback(self): function fun2(arg) { return 2; } + + function fun3(arg) { + return 3; + } """ def read_callable_byte_offset(byte_offset, point): @@ -260,19 +264,23 @@ def read_callable_point(byte_offset, point): """ ((function_declaration name: (identifier) @function-name) - (#eq? @function-name fun1)) + (#match? @function-name "fun[12]")) """ ) - cursor = QueryCursor(query1) - captures1 = list(cursor.captures(root_node1).items()) + cursor1 = QueryCursor(query1) + captures1 = cursor1.captures(root_node1) self.assertEqual(1, len(captures1)) - self.assertEqual(captures1[0][0], "function-name") - self.assertEqual(captures1[0][1][0].text, b"fun1") + self.assertIn("function-name", captures1) + self.assertEqual(2, len(captures1["function-name"])) + self.assertEqual(captures1["function-name"][0].text, b"fun1") + self.assertEqual(captures1["function-name"][1].text, b"fun2") - captures2 = list(cursor.captures(root_node2).items()) + captures2 = cursor1.captures(root_node2) self.assertEqual(1, len(captures2)) - self.assertEqual(captures2[0][0], "function-name") - self.assertEqual(captures2[0][1][0].text, b"fun1") + self.assertIn("function-name", captures2) + self.assertEqual(2, len(captures2["function-name"])) + self.assertEqual(captures2["function-name"][0].text, b"fun1") + self.assertEqual(captures2["function-name"][1].text, b"fun2") # functions with name not equal to 'fun1' -> test for #not-eq? @capture string query2 = Query( @@ -283,16 +291,20 @@ def read_callable_point(byte_offset, point): (#not-eq? @function-name fun1)) """ ) - cursor = QueryCursor(query2) - captures3 = list(cursor.captures(root_node1).items()) + cursor2 = QueryCursor(query2) + captures3 = cursor2.captures(root_node1) self.assertEqual(1, len(captures3)) - self.assertEqual(captures3[0][0], "function-name") - self.assertEqual(captures3[0][1][0].text, b"fun2") + self.assertIn("function-name", captures3) + self.assertEqual(2, len(captures3["function-name"])) + self.assertEqual(captures3["function-name"][0].text, b"fun2") + self.assertEqual(captures3["function-name"][1].text, b"fun3") - captures4 = list(cursor.captures(root_node2).items()) + captures4 = cursor2.captures(root_node2) self.assertEqual(1, len(captures4)) - self.assertEqual(captures4[0][0], "function-name") - self.assertEqual(captures4[0][1][0].text, b"fun2") + self.assertIn("function-name", captures4) + self.assertEqual(2, len(captures4["function-name"])) + self.assertEqual(captures4["function-name"][0].text, b"fun2") + self.assertEqual(captures4["function-name"][1].text, b"fun3") def test_text_predicates_errors(self): with self.assertRaises(QueryError): diff --git a/tree_sitter/binding/node.c b/tree_sitter/binding/node.c index 86dfe0ed..83ace612 100644 --- a/tree_sitter/binding/node.c +++ b/tree_sitter/binding/node.c @@ -597,31 +597,30 @@ PyObject *node_get_text(Node *self, void *Py_UNUSED(payload)) { TSPoint current_point = start_point; for (size_t current_offset = start_offset; current_offset < end_offset;) { + // Form arguments to callable. PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset); - PyObject *row = PyLong_FromSize_t((size_t)current_point.row); - PyObject *column = PyLong_FromSize_t((size_t)current_point.column); - PyObject *point_obj = PyTuple_Pack(2, row, column); - Py_XDECREF(row); - Py_XDECREF(column); - if (!point_obj) { - Py_XDECREF(collected_bytes); + if (!byte_offset_obj) { + Py_DECREF(collected_bytes); + return NULL; + } + PyObject *position_obj = POINT_NEW(GET_MODULE_STATE(self), current_point); + if (!position_obj) { + Py_DECREF(byte_offset_obj); + Py_DECREF(collected_bytes); return NULL; } - PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj); + PyObject *args = PyTuple_Pack(2, byte_offset_obj, position_obj); Py_XDECREF(byte_offset_obj); - Py_XDECREF(point_obj); + Py_XDECREF(position_obj); + + // Call callable. PyObject *rv = PyObject_Call(tree->source, args, NULL); Py_XDECREF(args); - if (rv == NULL || rv == Py_None || !PyBytes_Check(rv)) { - Py_XDECREF(rv); - Py_XDECREF(collected_bytes); - return NULL; - } PyObject *rv_bytearray = PyByteArray_FromObject(rv); if (rv_bytearray == NULL) { - Py_XDECREF(collected_bytes); + Py_DECREF(collected_bytes); Py_XDECREF(rv); return NULL; } @@ -660,9 +659,11 @@ PyObject *node_get_text(Node *self, void *Py_UNUSED(payload)) { } result = PyObject_GetItem(collected_bytes, slice); Py_DECREF(slice); - Py_XDECREF(collected_bytes); + Py_DECREF(collected_bytes); } - return result; + PyObject *bytes_result = PyBytes_FromObject(result); + Py_DECREF(result); + return bytes_result; } Py_hash_t node_hash(Node *self) {