Skip to content

Commit 09ac3ba

Browse files
authored
[mypyc] Optimize str.startswith and str.endswith with tuple argument (#18678)
1 parent 0808624 commit 09ac3ba

File tree

7 files changed

+157
-9
lines changed

7 files changed

+157
-9
lines changed

mypyc/doc/str_operations.rst

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Methods
3131
* ``s.encode(encoding: str)``
3232
* ``s.encode(encoding: str, errors: str)``
3333
* ``s1.endswith(s2: str)``
34+
* ``s1.endswith(t: tuple[str, ...])``
3435
* ``s.join(x: Iterable)``
3536
* ``s.removeprefix(prefix: str)``
3637
* ``s.removesuffix(suffix: str)``
@@ -45,6 +46,7 @@ Methods
4546
* ``s.splitlines()``
4647
* ``s.splitlines(keepends: bool)``
4748
* ``s1.startswith(s2: str)``
49+
* ``s1.startswith(t: tuple[str, ...])``
4850

4951
.. note::
5052

mypyc/lib-rt/CPy.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,8 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
725725
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
726726
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
727727
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
728-
bool CPyStr_Startswith(PyObject *self, PyObject *subobj);
729-
bool CPyStr_Endswith(PyObject *self, PyObject *subobj);
728+
int CPyStr_Startswith(PyObject *self, PyObject *subobj);
729+
int CPyStr_Endswith(PyObject *self, PyObject *subobj);
730730
PyObject *CPyStr_Removeprefix(PyObject *self, PyObject *prefix);
731731
PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
732732
bool CPyStr_IsTrue(PyObject *obj);

mypyc/lib-rt/str_ops.c

+38-2
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,51 @@ PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
161161
return PyUnicode_Replace(str, old_substr, new_substr, temp_max_replace);
162162
}
163163

164-
bool CPyStr_Startswith(PyObject *self, PyObject *subobj) {
164+
// Implements str.startswith(x) where x is a str or a tuple of str.
//
// self:   the string being tested; its length is read with
//         PyUnicode_GET_LENGTH, so it must be a str object.
// subobj: a single prefix, or a tuple of candidate prefixes.
//
// Returns 1 if self starts with the prefix (or with any tuple item),
// 0 if it does not, and -1 with an exception set on error. Tuple items
// are tested in order, so a non-str item raises TypeError only when no
// earlier item has already matched. An empty tuple yields 0.
int CPyStr_Startswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for startswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, -1);
            if (result != 0) {
                // PyUnicode_Tailmatch returns 1 on a match and -1 on error.
                // Pass the value through so that an error is reported as -1
                // instead of being mistaken for a successful match.
                return result;
            }
        }
        return 0;
    }
    // Single-prefix case: the -1/0/1 result is returned unchanged.
    return PyUnicode_Tailmatch(self, subobj, start, end, -1);
}
169187

170-
bool CPyStr_Endswith(PyObject *self, PyObject *subobj) {
188+
// Implements str.endswith(x) where x is a str or a tuple of str.
//
// self:   the string being tested; its length is read with
//         PyUnicode_GET_LENGTH, so it must be a str object.
// subobj: a single suffix, or a tuple of candidate suffixes.
//
// Returns 1 if self ends with the suffix (or with any tuple item),
// 0 if it does not, and -1 with an exception set on error. Tuple items
// are tested in order, so a non-str item raises TypeError only when no
// earlier item has already matched. An empty tuple yields 0.
int CPyStr_Endswith(PyObject *self, PyObject *subobj) {
    Py_ssize_t start = 0;
    Py_ssize_t end = PyUnicode_GET_LENGTH(self);
    if (PyTuple_Check(subobj)) {
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(subobj); i++) {
            PyObject *substring = PyTuple_GET_ITEM(subobj, i);
            if (!PyUnicode_Check(substring)) {
                PyErr_Format(PyExc_TypeError,
                             "tuple for endswith must only contain str, "
                             "not %.100s",
                             Py_TYPE(substring)->tp_name);
                return -1;
            }
            int result = PyUnicode_Tailmatch(self, substring, start, end, 1);
            if (result != 0) {
                // PyUnicode_Tailmatch returns 1 on a match and -1 on error.
                // Pass the value through so that an error is reported as -1
                // instead of being mistaken for a successful match.
                return result;
            }
        }
        return 0;
    }
    // Single-suffix case: the -1/0/1 result is returned unchanged.
    return PyUnicode_Tailmatch(self, subobj, start, end, 1);
}
175211

mypyc/primitives/str_ops.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
object_rprimitive,
1616
pointer_rprimitive,
1717
str_rprimitive,
18+
tuple_rprimitive,
1819
)
1920
from mypyc.primitives.registry import (
2021
ERR_NEG_INT,
@@ -104,20 +105,42 @@
104105
method_op(
105106
name="startswith",
106107
arg_types=[str_rprimitive, str_rprimitive],
107-
return_type=bool_rprimitive,
108+
return_type=c_int_rprimitive,
108109
c_function_name="CPyStr_Startswith",
110+
truncated_type=bool_rprimitive,
109111
error_kind=ERR_NEVER,
110112
)
111113

114+
# str.startswith(tuple) (returns 1 on match, 0 on no match, -1 on error)
115+
method_op(
116+
name="startswith",
117+
arg_types=[str_rprimitive, tuple_rprimitive],
118+
return_type=c_int_rprimitive,
119+
c_function_name="CPyStr_Startswith",
120+
truncated_type=bool_rprimitive,
121+
error_kind=ERR_NEG_INT,
122+
)
123+
112124
# str.endswith(str)
113125
method_op(
114126
name="endswith",
115127
arg_types=[str_rprimitive, str_rprimitive],
116-
return_type=bool_rprimitive,
128+
return_type=c_int_rprimitive,
117129
c_function_name="CPyStr_Endswith",
130+
truncated_type=bool_rprimitive,
118131
error_kind=ERR_NEVER,
119132
)
120133

134+
# str.endswith(tuple) (returns 1 on match, 0 on no match, -1 on error)
135+
method_op(
136+
name="endswith",
137+
arg_types=[str_rprimitive, tuple_rprimitive],
138+
return_type=c_int_rprimitive,
139+
c_function_name="CPyStr_Endswith",
140+
truncated_type=bool_rprimitive,
141+
error_kind=ERR_NEG_INT,
142+
)
143+
121144
# str.removeprefix(str)
122145
method_op(
123146
name="removeprefix",

mypyc/test-data/fixtures/ir.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def strip (self, item: str) -> str: pass
109109
def join(self, x: Iterable[str]) -> str: pass
110110
def format(self, *args: Any, **kwargs: Any) -> str: ...
111111
def upper(self) -> str: ...
112-
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
113-
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
112+
def startswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
113+
def endswith(self, x: Union[str, Tuple[str, ...]], start: int=..., end: int=...) -> bool: ...
114114
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
115115
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
116116
def removeprefix(self, prefix: str, /) -> str: ...

mypyc/test-data/irbuild-str.test

+67
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,73 @@ L4:
137137
L5:
138138
unreachable
139139

140+
[case testStrStartswithEndswithTuple]
141+
from typing import Tuple
142+
143+
def do_startswith(s1: str, s2: Tuple[str, ...]) -> bool:
144+
return s1.startswith(s2)
145+
146+
def do_endswith(s1: str, s2: Tuple[str, ...]) -> bool:
147+
return s1.endswith(s2)
148+
149+
def do_tuple_literal_args(s1: str) -> None:
150+
x = s1.startswith(("a", "b"))
151+
y = s1.endswith(("a", "b"))
152+
[out]
153+
def do_startswith(s1, s2):
154+
s1 :: str
155+
s2 :: tuple
156+
r0 :: i32
157+
r1 :: bit
158+
r2 :: bool
159+
L0:
160+
r0 = CPyStr_Startswith(s1, s2)
161+
r1 = r0 >= 0 :: signed
162+
r2 = truncate r0: i32 to builtins.bool
163+
return r2
164+
def do_endswith(s1, s2):
165+
s1 :: str
166+
s2 :: tuple
167+
r0 :: i32
168+
r1 :: bit
169+
r2 :: bool
170+
L0:
171+
r0 = CPyStr_Endswith(s1, s2)
172+
r1 = r0 >= 0 :: signed
173+
r2 = truncate r0: i32 to builtins.bool
174+
return r2
175+
def do_tuple_literal_args(s1):
176+
s1, r0, r1 :: str
177+
r2 :: tuple[str, str]
178+
r3 :: object
179+
r4 :: i32
180+
r5 :: bit
181+
r6, x :: bool
182+
r7, r8 :: str
183+
r9 :: tuple[str, str]
184+
r10 :: object
185+
r11 :: i32
186+
r12 :: bit
187+
r13, y :: bool
188+
L0:
189+
r0 = 'a'
190+
r1 = 'b'
191+
r2 = (r0, r1)
192+
r3 = box(tuple[str, str], r2)
193+
r4 = CPyStr_Startswith(s1, r3)
194+
r5 = r4 >= 0 :: signed
195+
r6 = truncate r4: i32 to builtins.bool
196+
x = r6
197+
r7 = 'a'
198+
r8 = 'b'
199+
r9 = (r7, r8)
200+
r10 = box(tuple[str, str], r9)
201+
r11 = CPyStr_Endswith(s1, r10)
202+
r12 = r11 >= 0 :: signed
203+
r13 = truncate r11: i32 to builtins.bool
204+
y = r13
205+
return 1
206+
140207
[case testStrToBool]
141208
def is_true(x: str) -> bool:
142209
if x:

mypyc/test-data/run-strings.test

+21-1
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,20 @@ def eq(x: str) -> int:
2020
return 2
2121
def match(x: str, y: str) -> Tuple[bool, bool]:
2222
return (x.startswith(y), x.endswith(y))
23+
def match_tuple(x: str, y: Tuple[str, ...]) -> Tuple[bool, bool]:
24+
return (x.startswith(y), x.endswith(y))
25+
def match_tuple_literal_args(x: str, y: str, z: str) -> Tuple[bool, bool]:
26+
return (x.startswith((y, z)), x.endswith((y, z)))
2327
def remove_prefix_suffix(x: str, y: str) -> Tuple[str, str]:
2428
return (x.removeprefix(y), x.removesuffix(y))
2529

2630
[file driver.py]
27-
from native import f, g, tostr, booltostr, concat, eq, match, remove_prefix_suffix
31+
from native import (
32+
f, g, tostr, booltostr, concat, eq, match, match_tuple,
33+
match_tuple_literal_args, remove_prefix_suffix
34+
)
2835
import sys
36+
from testutil import assertRaises
2937

3038
assert f() == 'some string'
3139
assert f() is sys.intern('some string')
@@ -45,6 +53,18 @@ assert match('abc', '') == (True, True)
4553
assert match('abc', 'a') == (True, False)
4654
assert match('abc', 'c') == (False, True)
4755
assert match('', 'abc') == (False, False)
56+
assert match_tuple('abc', ('d', 'e')) == (False, False)
57+
assert match_tuple('abc', ('a', 'c')) == (True, True)
58+
assert match_tuple('abc', ('a',)) == (True, False)
59+
assert match_tuple('abc', ('c',)) == (False, True)
60+
assert match_tuple('abc', ('x', 'y', 'z')) == (False, False)
61+
assert match_tuple('abc', ('x', 'y', 'z', 'a', 'c')) == (True, True)
62+
with assertRaises(TypeError, "tuple for startswith must only contain str"):
63+
assert match_tuple('abc', (None,))
64+
with assertRaises(TypeError, "tuple for endswith must only contain str"):
65+
assert match_tuple('abc', ('a', None))
66+
assert match_tuple_literal_args('abc', 'z', 'a') == (True, False)
67+
assert match_tuple_literal_args('abc', 'z', 'c') == (False, True)
4868

4969
assert remove_prefix_suffix('', '') == ('', '')
5070
assert remove_prefix_suffix('abc', 'a') == ('bc', 'abc')

0 commit comments

Comments
 (0)