Skip to content

Commit e7d54a5

Browse files
BUG/TST (string dtype): fix and update tests for Stata IO (#60130)
1 parent 0db1f53 commit e7d54a5

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

pandas/io/stata.py

+5
Original file line number | Diff line number | Diff line change
@@ -569,7 +569,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
569569
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
570570
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
571571
elif is_string_dtype(data[col].dtype):
572+
# TODO could avoid converting string dtype to object here,
573+
# and instead handle string dtype directly in _encode_strings
572574
data[col] = data[col].astype("object")
575+
# generate_table checks for None values
576+
data.loc[data[col].isna(), col] = None
573577

574578
dtype = data[col].dtype
575579
empty_df = data.shape[0] == 0
@@ -2725,6 +2729,7 @@ def _encode_strings(self) -> None:
27252729
continue
27262730
column = self.data[col]
27272731
dtype = column.dtype
2732+
# TODO could also handle string dtype here specifically
27282733
if dtype.type is np.object_:
27292734
inferred_dtype = infer_dtype(column, skipna=True)
27302735
if not ((inferred_dtype == "string") or len(column) == 0):

pandas/tests/io/test_stata.py

+25-26
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
import numpy as np
1212
import pytest
1313

14-
from pandas._config import using_string_dtype
15-
1614
import pandas.util._test_decorators as td
1715

1816
import pandas as pd
@@ -435,9 +433,8 @@ def test_write_dta6(self, datapath, temp_file):
435433
check_index_type=False,
436434
)
437435

438-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
439436
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
440-
def test_read_write_dta10(self, version, temp_file):
437+
def test_read_write_dta10(self, version, temp_file, using_infer_string):
441438
original = DataFrame(
442439
data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]],
443440
columns=["string", "object", "integer", "floating", "datetime"],
@@ -451,9 +448,11 @@ def test_read_write_dta10(self, version, temp_file):
451448
original.to_stata(path, convert_dates={"datetime": "tc"}, version=version)
452449
written_and_read_again = self.read_dta(path)
453450

454-
expected = original[:]
451+
expected = original.copy()
455452
# "tc" convert_dates means we store in ms
456453
expected["datetime"] = expected["datetime"].astype("M8[ms]")
454+
if using_infer_string:
455+
expected["object"] = expected["object"].astype("str")
457456

458457
tm.assert_frame_equal(
459458
written_and_read_again.set_index("index"),
@@ -1276,7 +1275,6 @@ def test_categorical_ordering(self, file, datapath):
12761275
assert parsed[col].cat.ordered
12771276
assert not parsed_unordered[col].cat.ordered
12781277

1279-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
12801278
@pytest.mark.filterwarnings("ignore::UserWarning")
12811279
@pytest.mark.parametrize(
12821280
"file",
@@ -1340,6 +1338,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame:
13401338
if cat.categories.dtype == object:
13411339
categories = pd.Index._with_infer(cat.categories._values)
13421340
cat = cat.set_categories(categories)
1341+
elif cat.categories.dtype == "string" and len(cat.categories) == 0:
1342+
# if the read categories are empty, they come back as object dtype
1343+
categories = cat.categories.astype(object)
1344+
cat = cat.set_categories(categories)
13431345
from_frame[col] = cat
13441346
return from_frame
13451347

@@ -1369,7 +1371,6 @@ def test_iterator(self, datapath):
13691371
from_chunks = pd.concat(itr)
13701372
tm.assert_frame_equal(parsed, from_chunks)
13711373

1372-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
13731374
@pytest.mark.filterwarnings("ignore::UserWarning")
13741375
@pytest.mark.parametrize(
13751376
"file",
@@ -1674,12 +1675,11 @@ def test_inf(self, infval, temp_file):
16741675
path = temp_file
16751676
df.to_stata(path)
16761677

1677-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
16781678
def test_path_pathlib(self):
16791679
df = DataFrame(
16801680
1.1 * np.arange(120).reshape((30, 4)),
1681-
columns=pd.Index(list("ABCD"), dtype=object),
1682-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1681+
columns=pd.Index(list("ABCD")),
1682+
index=pd.Index([f"i-{i}" for i in range(30)]),
16831683
)
16841684
df.index.name = "index"
16851685
reader = lambda x: read_stata(x).set_index("index")
@@ -1699,13 +1699,12 @@ def test_value_labels_iterator(self, write_index, temp_file):
16991699
value_labels = dta_iter.value_labels()
17001700
assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}}
17011701

1702-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
17031702
def test_set_index(self, temp_file):
17041703
# GH 17328
17051704
df = DataFrame(
17061705
1.1 * np.arange(120).reshape((30, 4)),
1707-
columns=pd.Index(list("ABCD"), dtype=object),
1708-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1706+
columns=pd.Index(list("ABCD")),
1707+
index=pd.Index([f"i-{i}" for i in range(30)]),
17091708
)
17101709
df.index.name = "index"
17111710
path = temp_file
@@ -1733,9 +1732,9 @@ def test_date_parsing_ignores_format_details(self, column, datapath):
17331732
formatted = df.loc[0, column + "_fmt"]
17341733
assert unformatted == formatted
17351734

1736-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
1735+
# @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
17371736
@pytest.mark.parametrize("byteorder", ["little", "big"])
1738-
def test_writer_117(self, byteorder, temp_file):
1737+
def test_writer_117(self, byteorder, temp_file, using_infer_string):
17391738
original = DataFrame(
17401739
data=[
17411740
[
@@ -1802,6 +1801,9 @@ def test_writer_117(self, byteorder, temp_file):
18021801
expected = original[:]
18031802
# "tc" for convert_dates means we store with "ms" resolution
18041803
expected["datetime"] = expected["datetime"].astype("M8[ms]")
1804+
if using_infer_string:
1805+
# object dtype (with only strings/None) comes back as string dtype
1806+
expected["object"] = expected["object"].astype("str")
18051807

18061808
tm.assert_frame_equal(
18071809
written_and_read_again.set_index("index"),
@@ -1845,15 +1847,14 @@ def test_invalid_date_conversion(self, temp_file):
18451847
with pytest.raises(ValueError, match=msg):
18461848
original.to_stata(path, convert_dates={"wrong_name": "tc"})
18471849

1848-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
18491850
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
18501851
def test_nonfile_writing(self, version, temp_file):
18511852
# GH 21041
18521853
bio = io.BytesIO()
18531854
df = DataFrame(
18541855
1.1 * np.arange(120).reshape((30, 4)),
1855-
columns=pd.Index(list("ABCD"), dtype=object),
1856-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1856+
columns=pd.Index(list("ABCD")),
1857+
index=pd.Index([f"i-{i}" for i in range(30)]),
18571858
)
18581859
df.index.name = "index"
18591860
path = temp_file
@@ -1864,13 +1865,12 @@ def test_nonfile_writing(self, version, temp_file):
18641865
reread = read_stata(path, index_col="index")
18651866
tm.assert_frame_equal(df, reread)
18661867

1867-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
18681868
def test_gzip_writing(self, temp_file):
18691869
# writing version 117 requires seek and cannot be used with gzip
18701870
df = DataFrame(
18711871
1.1 * np.arange(120).reshape((30, 4)),
1872-
columns=pd.Index(list("ABCD"), dtype=object),
1873-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1872+
columns=pd.Index(list("ABCD")),
1873+
index=pd.Index([f"i-{i}" for i in range(30)]),
18741874
)
18751875
df.index.name = "index"
18761876
path = temp_file
@@ -1907,8 +1907,7 @@ def test_unicode_dta_118_119(self, file, datapath):
19071907

19081908
tm.assert_frame_equal(unicode_df, expected)
19091909

1910-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
1911-
def test_mixed_string_strl(self, temp_file):
1910+
def test_mixed_string_strl(self, temp_file, using_infer_string):
19121911
# GH 23633
19131912
output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}]
19141913
output = DataFrame(output)
@@ -1925,6 +1924,8 @@ def test_mixed_string_strl(self, temp_file):
19251924
output.to_stata(path, write_index=False, convert_strl=["mixed"], version=117)
19261925
reread = read_stata(path)
19271926
expected = output.fillna("")
1927+
if using_infer_string:
1928+
expected["mixed"] = expected["mixed"].astype("str")
19281929
tm.assert_frame_equal(reread, expected)
19291930

19301931
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@@ -2000,7 +2001,6 @@ def test_stata_119(self, datapath):
20002001
reader._ensure_open()
20012002
assert reader._nvar == 32999
20022003

2003-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
20042004
@pytest.mark.parametrize("version", [118, 119, None])
20052005
@pytest.mark.parametrize("byteorder", ["little", "big"])
20062006
def test_utf8_writer(self, version, byteorder, temp_file):
@@ -2348,13 +2348,12 @@ def test_iterator_errors(datapath, chunksize):
23482348
pass
23492349

23502350

2351-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
23522351
def test_iterator_value_labels(temp_file):
23532352
# GH 31544
23542353
values = ["c_label", "b_label"] + ["a_label"] * 500
23552354
df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
23562355
df.to_stata(temp_file, write_index=False)
2357-
expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
2356+
expected = pd.Index(["a_label", "b_label", "c_label"])
23582357
with read_stata(temp_file, chunksize=100) as reader:
23592358
for j, chunk in enumerate(reader):
23602359
for i in range(2):

0 commit comments

Comments
 (0)