Skip to content

Update Auto-Sync to Python 3.13 and tree-sitter-py 24.0 #2705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions suite/auto-sync/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ name = "autosync"
version = "0.1.0"
dependencies = [
"termcolor >= 2.3.0",
"tree_sitter == 0.22.3",
"tree-sitter-cpp == 0.22.3",
"tree_sitter == 0.24.0",
"tree-sitter-cpp == 0.23.4",
"black >= 24.3.0",
"usort >= 1.0.8",
"setuptools >= 69.2.0",
Expand Down
3 changes: 1 addition & 2 deletions suite/auto-sync/src/autosync/cpptranslator/Configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,4 @@ def ts_set_cpp_language(self) -> None:

def init_parser(self) -> None:
    """Create the tree-sitter parser and bind it to the C++ language.

    The language is handed to the ``Parser`` constructor directly. The
    former two-step form (``Parser()`` followed by ``set_language()``) is
    dropped because it relies on an API the upgraded tree-sitter binding
    no longer offers.
    """
    log.debug("Init parser")
    self.parser = Parser(self.ts_cpp_lang)
18 changes: 10 additions & 8 deletions suite/auto-sync/src/autosync/cpptranslator/CppTranslator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
run_clang_format,
)
from autosync.cpptranslator.patches.isUInt import IsUInt
from autosync.cpptranslator.tree_sitter_compatibility import query_captures_22_3


class Translator:
Expand Down Expand Up @@ -361,7 +362,7 @@ def init_patches(self):
def parse(self, src_path: Path) -> None:
    """Read the file at ``src_path`` and parse it into ``self.tree``.

    :param src_path: Path of the source file to read and parse.

    The source is parsed exactly once. The ``keep_text`` argument is not
    passed because the upgraded tree-sitter binding does not take it.
    """
    # NOTE(review): read_src_file() presumably fills self.src with the file
    # content as bytes - confirm against its definition.
    self.read_src_file(src_path)
    log.debug("Parse source code")
    self.tree = self.parser.parse(self.src)

def patch_src(self, p_list: [(bytes, Node)]) -> None:
if len(p_list) == 0:
Expand Down Expand Up @@ -391,7 +392,7 @@ def patch_src(self, p_list: [(bytes, Node)]) -> None:
old_end_point=old_end_point,
new_end_point=(old_end_point[0], old_end_point[1] + d),
)
self.tree = self.parser.parse(new_src, self.tree, keep_text=True)
self.tree = self.parser.parse(new_src, self.tree)

def apply_patch(self, patch: Patch) -> bool:
"""Tests if the given patch should be applied for the current architecture or file."""
Expand Down Expand Up @@ -435,7 +436,7 @@ def translate(self) -> None:
# Here we bundle these captures together.
query: Query = self.ts_cpp_lang.query(pattern)
captures_bundle: [[(Node, str)]] = list()
for q in query.captures(self.tree.root_node):
for q in query_captures_22_3(query, self.tree.root_node):
if q[1] == patch.get_main_capture_name():
# The main capture the patch is looking for.
captures_bundle.append([q])
Expand All @@ -453,8 +454,6 @@ def translate(self) -> None:
cb: [(Node, str)]
for cb in captures_bundle:
patch_kwargs = self.get_patch_kwargs(patch)
patch_kwargs["tree"] = self.tree
patch_kwargs["ts_cpp_lang"] = self.ts_cpp_lang
bytes_patch: bytes = patch.get_patch(cb, self.src, **patch_kwargs)
p_list.append((bytes_patch, cb[0][0]))
self.patch_src(p_list)
Expand All @@ -480,9 +479,12 @@ def collect_template_instances(self):
self.template_collector.collect()

def get_patch_kwargs(self, patch):
    """Build the keyword arguments handed to ``patch.get_patch()``.

    Every patch receives the current syntax tree (``tree``) and the C++
    language object (``ts_cpp_lang``). ``Includes`` patches additionally
    receive the name of the current input file as ``filename``, but only
    when an input path is set.

    :param patch: The patch the arguments are built for.
    :return: A new dict of keyword arguments.
    """
    default_kwargs = {
        "tree": self.tree,
        "ts_cpp_lang": self.ts_cpp_lang,
    }
    # Guard against a missing input path so ``.name`` is never read from a
    # falsy value (for example when no source file was loaded).
    if isinstance(patch, Includes) and self.current_src_path_in:
        default_kwargs["filename"] = self.current_src_path_in.name
    return default_kwargs

def remark_manual_files(self) -> None:
manual_edited = self.conf["manually_edited_files"]
Expand Down
2 changes: 1 addition & 1 deletion suite/auto-sync/src/autosync/cpptranslator/Differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def parse_file(self, file: Path) -> dict[str:Node]:
with open(file) as f:
content = bytes(f.read(), "utf8")

tree: Tree = self.parser.parse(content, keep_text=True)
tree: Tree = self.parser.parse(content)

node_types_to_diff = [
n["node_type"] for n in self.conf_general["nodes_to_diff"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tree_sitter import Language, Node, Parser, Query

from autosync.cpptranslator.patches.Helper import get_text
from autosync.cpptranslator.tree_sitter_compatibility import query_captures_22_3


class TemplateRefInstance:
Expand Down Expand Up @@ -105,7 +106,7 @@ def collect(self):
src = x["content"]
log.debug(f"Search for template references in {path}")

tree = self.parser.parse(src, keep_text=True)
tree = self.parser.parse(src)
query: Query = self.lang_cpp.query(self.get_template_pattern())
capture_bundles = self.get_capture_bundles(query, tree)

Expand Down Expand Up @@ -278,8 +279,8 @@ def read_files(self):

@staticmethod
def get_capture_bundles(query, tree):
captures_bundle: [[(Node, str)]] = list()
for q in query.captures(tree.root_node):
captures_bundle: list[list[tuple[Node, str]]] = list()
for q in query_captures_22_3(query, tree.root_node):
if q[1] == "templ_ref":
captures_bundle.append([q])
else:
Expand Down
26 changes: 15 additions & 11 deletions suite/auto-sync/src/autosync/cpptranslator/Tests/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import unittest

from pathlib import Path
from tree_sitter import Node, Query

from autosync.cpptranslator import CppTranslator
Expand Down Expand Up @@ -81,6 +82,7 @@
from autosync.cpptranslator.patches.BadConditionCode import BadConditionCode
from autosync.Helper import get_path
from autosync.cpptranslator.patches.isUInt import IsUInt
from autosync.cpptranslator.tree_sitter_compatibility import query_captures_22_3


class TestPatches(unittest.TestCase):
Expand All @@ -94,14 +96,18 @@ def setUpClass(cls):
configurator.get_parser(), configurator.get_cpp_lang(), [], []
)

def check_patching_result(self, patch, syntax, expected, filename=""):
def check_patching_result(
self, patch, syntax, expected, filename: Path = None, tree: dict = None
):
kwargs = self.translator.get_patch_kwargs(patch)
if filename:
kwargs = {"filename": filename}
else:
kwargs = self.translator.get_patch_kwargs(patch)
kwargs["filename"] = filename

query: Query = self.ts_cpp_lang.query(patch.get_search_pattern())
tree = self.parser.parse(syntax)
kwargs["tree"] = tree
captures_bundle: [[(Node, str)]] = list()
for q in query.captures(self.parser.parse(syntax, keep_text=True).root_node):
for q in query_captures_22_3(query, tree.root_node):
if q[1] == patch.get_main_capture_name():
captures_bundle.append([q])
else:
Expand Down Expand Up @@ -369,7 +375,7 @@ def test_includes(self):
b"#include <stdlib.h>\n"
b"#include <capstone/platform.h>\n\n"
b"test_output",
"filename",
filename=Path("filename"),
)

def test_inlinetostaticinline(self):
Expand Down Expand Up @@ -542,11 +548,6 @@ def test_stifeaturebits(self):
b"ARCH_getFeatureBits(Inst->csh->mode, ARCH::FLAG)",
)

def test_stifeaturebits(self):
patch = SubtargetInfoParam(0)
syntax = b"void function(MCSubtargetInfo &STI);"
self.check_patching_result(patch, syntax, b"()")

def test_streamoperation(self):
patch = StreamOperations(0)
syntax = b"{ OS << 'a'; }"
Expand All @@ -568,6 +569,9 @@ def test_streamoperation(self):
b'SStream_concat0(OS, "cccc");',
)

syntax = b"{ int y = 1; int x = 1; OS << x; }"
self.check_patching_result(patch, syntax, b"printInt32(OS, x);")

def test_templatedeclaration(self):
patch = TemplateDeclaration(0, self.template_collector)
syntax = b"template<A, B> void tfunction();"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def get_main_capture_name(self) -> str:
return "array_bit_cast"

def get_patch(self, captures: [(Node, str)], src: bytes, **kwargs) -> bytes:
arr_name: bytes = captures[1][0].text
c1 = captures[1][0]
c4 = captures[4][0]
arr_name: bytes = get_text(src, c1.start_byte, c1.end_byte)
array_type: Node = captures[3][0]
cast_target: bytes = captures[4][0].text.strip(b"()")
array_templ_args: bytes = (
array_type.named_children[0]
.named_children[1]
.named_children[1]
.text.strip(b"<>")
)
cast_target: bytes = get_text(src, c4.start_byte, c4.end_byte).strip(b"()")
named_child = array_type.named_children[0].named_children[1].named_children[1]
array_templ_args: bytes = get_text(
src, named_child.start_byte, named_child.end_byte
).strip(b"<>")
arr_type = array_templ_args.split(b",")[0]
arr_len = array_templ_args.split(b",")[1]
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,23 @@ def get_patch(self, captures: [(Node, str)], src: bytes, **kwargs) -> bytes:
(declaration (
(primitive_type) @typ
(init_declarator
(identifier) @ident (#eq? @ident "{last_op_text.decode('utf8')}")
(identifier) @ident (#eq? @ident "{last_op_text.decode("utf8")}")
)
)) @decl
"""
query = kwargs["ts_cpp_lang"].query(queue_str)
query.end_byte_for_pattern(last_op.start_byte)
root_node = kwargs["tree"].root_node
query_result = list(
filter(
lambda x: "typ" in x[1],
query.matches(root_node, end_byte=last_op.start_byte),
query.matches(root_node),
)
)
if len(query_result) == 0:
res += b"SStream_concat0(" + s_name + b", " + last_op_text + b");"
else:
cap = query_result[-1]
typ = get_text_from_node(src, cap[1]["typ"])
typ = get_text_from_node(src, query_result[0][1]["typ"][-1])
match typ:
case b"int":
res += b"printInt32(" + s_name + b", " + last_op_text + b");"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from tree_sitter import Node, Query


def query_captures_22_3(query: Query, node: Node) -> list[tuple[Node, str]]:
    """Run ``query`` on ``node`` and return the captures in tree-sitter v0.22.3 format.

    ``query.captures(node)`` is consumed as a mapping from capture name to a
    list of nodes. This helper flattens it into the older shape: a list of
    tuples whose first element is the captured node and whose second
    element is the capture name.

    Ordering of the result:

    * the nodes of every capture name are sorted by their start point,
      because the mapping no longer guarantees that order;
    * the sorted groups are then interleaved round-robin. Round ``i`` emits
      the i-th node of every capture name that still has one, visiting the
      names in the mapping's iteration order.

    :param query: The query to execute.
    :param node: The node the query is run on.
    :return: The captures as ``(node, capture_name)`` tuples.
    """
    # Sort each group by position. Capture names without any node are
    # dropped up front; they contribute nothing to the result.
    captures_sorted = {
        name: sorted(nodes, key=lambda n: n.start_point)
        for name, nodes in query.captures(node).items()
        if nodes
    }

    result: list[tuple[Node, str]] = []
    rounds = max(map(len, captures_sorted.values()), default=0)
    for i in range(rounds):
        for name, nodes in captures_sorted.items():
            # Groups shorter than the current round are already exhausted.
            if i < len(nodes):
                result.append((nodes[i], name))
    return result
Loading