Skip to content

[backport 2.3.x] String dtype: use ObjectEngine for indexing for now correctness over performance (#60329) #60453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/_libs/index.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class MaskedUInt16Engine(MaskedIndexEngine): ...
class MaskedUInt8Engine(MaskedIndexEngine): ...
class MaskedBoolEngine(MaskedUInt8Engine): ...

class StringObjectEngine(ObjectEngine):
def __init__(self, values: object, na_value) -> None: ...

class BaseMultiIndexCodesEngine:
levels: list[np.ndarray]
offsets: np.ndarray # ndarray[uint64_t, ndim=1]
Expand Down
26 changes: 26 additions & 0 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,32 @@ cdef class ObjectEngine(IndexEngine):
return loc


cdef class StringObjectEngine(ObjectEngine):

cdef:
object na_value
bint uses_na

def __init__(self, ndarray values, na_value):
super().__init__(values)
self.na_value = na_value
self.uses_na = na_value is C_NA

cdef bint _checknull(self, object val):
if self.uses_na:
return val is C_NA
else:
return util.is_nan(val)

cdef _check_type(self, object val):
if isinstance(val, str):
return val
elif self._checknull(val):
return self.na_value
else:
raise KeyError(val)


cdef class DatetimeEngine(Int64Engine):

cdef:
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,8 @@ def _engine(
# error: Item "ExtensionArray" of "Union[ExtensionArray,
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
target_values = self._data._ndarray # type: ignore[union-attr]
elif is_string_dtype(self.dtype) and not is_object_dtype(self.dtype):
return libindex.StringObjectEngine(target_values, self.dtype.na_value) # type: ignore[union-attr]

# error: Argument 1 to "ExtensionEngine" has incompatible type
# "ndarray[Any, Any]"; expected "ExtensionArray"
Expand Down Expand Up @@ -6133,7 +6135,6 @@ def _should_fallback_to_positional(self) -> bool:
def get_indexer_non_unique(
self, target
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
target = ensure_index(target)
target = self._maybe_cast_listlike_indexer(target)

if not self._should_compare(target) and not self._should_partial_index(target):
Expand Down
104 changes: 93 additions & 11 deletions pandas/tests/indexes/string/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,51 @@
import pandas._testing as tm


def _isnan(val):
try:
return val is not pd.NA and np.isnan(val)
except TypeError:
return False


class TestGetLoc:
def test_get_loc(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
assert index.get_loc("b") == 1

def test_get_loc_raises(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError, match="d"):
index.get_loc("d")

def test_get_loc_invalid_value(self, any_string_dtype):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError, match="1"):
index.get_loc(1)

def test_get_loc_non_unique(self, any_string_dtype):
index = Index(["a", "b", "a"], dtype=any_string_dtype)
result = index.get_loc("a")
expected = np.array([True, False, True])
tm.assert_numpy_array_equal(result, expected)

def test_get_loc_non_missing(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", "c"], dtype=any_string_dtype)
with pytest.raises(KeyError):
index.get_loc(nulls_fixture)

def test_get_loc_missing(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", nulls_fixture], dtype=any_string_dtype)
if any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and nulls_fixture is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(nulls_fixture))
):
with pytest.raises(KeyError):
index.get_loc(nulls_fixture)
else:
assert index.get_loc(nulls_fixture) == 2


class TestGetIndexer:
@pytest.mark.parametrize(
"method,expected",
Expand Down Expand Up @@ -41,23 +86,60 @@ def test_get_indexer_strings_raises(self, any_string_dtype):
["a", "b", "c", "d"], method="pad", tolerance=[2, 2, 2, 2]
)

@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
def test_get_indexer_missing(self, any_string_dtype, null, using_infer_string):
# NaT and Decimal("NaN") from null_fixture are not supported for string dtype
index = Index(["a", "b", null], dtype=any_string_dtype)
result = index.get_indexer(["a", null, "c"])
if using_infer_string:
expected = np.array([0, 2, -1], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
expected = np.array([0, -1, -1], dtype=np.intp)
else:
expected = np.array([0, 2, -1], dtype=np.intp)

class TestGetIndexerNonUnique:
@pytest.mark.xfail(reason="TODO(infer_string)", strict=False)
def test_get_indexer_non_unique_nas(self, any_string_dtype, nulls_fixture):
index = Index(["a", "b", None], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
tm.assert_numpy_array_equal(result, expected)

expected_indexer = np.array([2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)

class TestGetIndexerNonUnique:
@pytest.mark.parametrize("null", [None, np.nan, float("nan"), pd.NA])
def test_get_indexer_non_unique_nas(
self, any_string_dtype, null, using_infer_string
):
index = Index(["a", "b", null], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique(["a", null])

if using_infer_string:
expected_indexer = np.array([0, 2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
expected_indexer = np.array([0, -1], dtype=np.intp)
expected_missing = np.array([1], dtype=np.intp)
else:
expected_indexer = np.array([0, 2], dtype=np.intp)
expected_missing = np.array([], dtype=np.intp)
tm.assert_numpy_array_equal(indexer, expected_indexer)
tm.assert_numpy_array_equal(missing, expected_missing)

# actually non-unique
index = Index(["a", None, "b", None], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique([nulls_fixture])

expected_indexer = np.array([1, 3], dtype=np.intp)
index = Index(["a", null, "b", null], dtype=any_string_dtype)
indexer, missing = index.get_indexer_non_unique(["a", null])

if using_infer_string:
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
elif any_string_dtype == "string" and (
(any_string_dtype.na_value is pd.NA and null is not pd.NA)
or (_isnan(any_string_dtype.na_value) and not _isnan(null))
):
pass
else:
expected_indexer = np.array([0, 1, 3], dtype=np.intp)
tm.assert_numpy_array_equal(indexer, expected_indexer)
tm.assert_numpy_array_equal(missing, expected_missing)

Expand Down
Loading