From 482ba5b282e5d7aa1aac3ade4587815f677cd5c5 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Date: Wed, 30 Oct 2024 08:55:32 +0100
Subject: [PATCH] BUG/TST (string dtype): fix and update tests for Stata IO

---
 pandas/io/stata.py            |  5 ++++
 pandas/tests/io/test_stata.py | 51 +++++++++++++++++------------------
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 04bd1e32603f4..722e2c79c4e6a 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -569,7 +569,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
             if getattr(data[col].dtype, "numpy_dtype", None) is not None:
                 data[col] = data[col].astype(data[col].dtype.numpy_dtype)
             elif is_string_dtype(data[col].dtype):
+                # TODO could avoid converting string dtype to object here,
+                # but handle string dtype in _encode_strings
                 data[col] = data[col].astype("object")
+                # generate_table checks for None values
+                data.loc[data[col].isna(), col] = None
 
         dtype = data[col].dtype
         empty_df = data.shape[0] == 0
@@ -2725,6 +2729,7 @@ def _encode_strings(self) -> None:
                 continue
             column = self.data[col]
             dtype = column.dtype
+            # TODO could also handle string dtype here specifically
             if dtype.type is np.object_:
                 inferred_dtype = infer_dtype(column, skipna=True)
                 if not ((inferred_dtype == "string") or len(column) == 0):
diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py
index 9f5085ff2ad28..4b5369d61bed6 100644
--- a/pandas/tests/io/test_stata.py
+++ b/pandas/tests/io/test_stata.py
@@ -11,8 +11,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 import pandas.util._test_decorators as td
 
 import pandas as pd
@@ -435,9 +433,8 @@ def test_write_dta6(self, datapath, temp_file):
             check_index_type=False,
         )
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     @pytest.mark.parametrize("version", [114, 117, 118, 119, None])
-    def test_read_write_dta10(self, version, temp_file):
+    def test_read_write_dta10(self, version, temp_file, using_infer_string):
         original = DataFrame(
             data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]],
             columns=["string", "object", "integer", "floating", "datetime"],
@@ -451,9 +448,11 @@ def test_read_write_dta10(self, version, temp_file):
         original.to_stata(path, convert_dates={"datetime": "tc"}, version=version)
         written_and_read_again = self.read_dta(path)
 
-        expected = original[:]
+        expected = original.copy()
         # "tc" convert_dates means we store in ms
         expected["datetime"] = expected["datetime"].astype("M8[ms]")
+        if using_infer_string:
+            expected["object"] = expected["object"].astype("str")
 
         tm.assert_frame_equal(
             written_and_read_again.set_index("index"),
@@ -1276,7 +1275,6 @@ def test_categorical_ordering(self, file, datapath):
             assert parsed[col].cat.ordered
             assert not parsed_unordered[col].cat.ordered
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
     @pytest.mark.filterwarnings("ignore::UserWarning")
     @pytest.mark.parametrize(
         "file",
@@ -1340,6 +1338,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame:
                 if cat.categories.dtype == object:
                     categories = pd.Index._with_infer(cat.categories._values)
                     cat = cat.set_categories(categories)
+                elif cat.categories.dtype == "string" and len(cat.categories) == 0:
+                    # if the read categories are empty, it comes back as object dtype
+                    categories = cat.categories.astype(object)
+                    cat = cat.set_categories(categories)
                 from_frame[col] = cat
         return from_frame
 
@@ -1369,7 +1371,6 @@ def test_iterator(self, datapath):
             from_chunks = pd.concat(itr)
         tm.assert_frame_equal(parsed, from_chunks)
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
     @pytest.mark.filterwarnings("ignore::UserWarning")
     @pytest.mark.parametrize(
         "file",
@@ -1674,12 +1675,11 @@ def test_inf(self, infval, temp_file):
             path = temp_file
             df.to_stata(path)
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     def test_path_pathlib(self):
         df = DataFrame(
             1.1 * np.arange(120).reshape((30, 4)),
-            columns=pd.Index(list("ABCD"), dtype=object),
-            index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
+            columns=pd.Index(list("ABCD")),
+            index=pd.Index([f"i-{i}" for i in range(30)]),
         )
         df.index.name = "index"
         reader = lambda x: read_stata(x).set_index("index")
@@ -1699,13 +1699,12 @@ def test_value_labels_iterator(self, write_index, temp_file):
             value_labels = dta_iter.value_labels()
         assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}}
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     def test_set_index(self, temp_file):
         # GH 17328
         df = DataFrame(
             1.1 * np.arange(120).reshape((30, 4)),
-            columns=pd.Index(list("ABCD"), dtype=object),
-            index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
+            columns=pd.Index(list("ABCD")),
+            index=pd.Index([f"i-{i}" for i in range(30)]),
         )
         df.index.name = "index"
         path = temp_file
@@ -1733,9 +1732,9 @@ def test_date_parsing_ignores_format_details(self, column, datapath):
         formatted = df.loc[0, column + "_fmt"]
         assert unformatted == formatted
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
+    # @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     @pytest.mark.parametrize("byteorder", ["little", "big"])
-    def test_writer_117(self, byteorder, temp_file):
+    def test_writer_117(self, byteorder, temp_file, using_infer_string):
         original = DataFrame(
             data=[
                 [
@@ -1802,6 +1801,9 @@ def test_writer_117(self, byteorder, temp_file):
         expected = original[:]
         # "tc" for convert_dates means we store with "ms" resolution
         expected["datetime"] = expected["datetime"].astype("M8[ms]")
+        if using_infer_string:
+            # object dtype (with only strings/None) comes back as string dtype
+            expected["object"] = expected["object"].astype("str")
 
         tm.assert_frame_equal(
             written_and_read_again.set_index("index"),
@@ -1845,15 +1847,14 @@ def test_invalid_date_conversion(self, temp_file):
         with pytest.raises(ValueError, match=msg):
             original.to_stata(path, convert_dates={"wrong_name": "tc"})
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     @pytest.mark.parametrize("version", [114, 117, 118, 119, None])
     def test_nonfile_writing(self, version, temp_file):
         # GH 21041
         bio = io.BytesIO()
         df = DataFrame(
             1.1 * np.arange(120).reshape((30, 4)),
-            columns=pd.Index(list("ABCD"), dtype=object),
-            index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
+            columns=pd.Index(list("ABCD")),
+            index=pd.Index([f"i-{i}" for i in range(30)]),
         )
         df.index.name = "index"
         path = temp_file
@@ -1864,13 +1865,12 @@ def test_nonfile_writing(self, version, temp_file):
         reread = read_stata(path, index_col="index")
         tm.assert_frame_equal(df, reread)
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     def test_gzip_writing(self, temp_file):
         # writing version 117 requires seek and cannot be used with gzip
         df = DataFrame(
             1.1 * np.arange(120).reshape((30, 4)),
-            columns=pd.Index(list("ABCD"), dtype=object),
-            index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
+            columns=pd.Index(list("ABCD")),
+            index=pd.Index([f"i-{i}" for i in range(30)]),
         )
         df.index.name = "index"
         path = temp_file
@@ -1907,8 +1907,7 @@ def test_unicode_dta_118_119(self, file, datapath):
 
         tm.assert_frame_equal(unicode_df, expected)
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
-    def test_mixed_string_strl(self, temp_file):
+    def test_mixed_string_strl(self, temp_file, using_infer_string):
         # GH 23633
         output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}]
         output = DataFrame(output)
@@ -1925,6 +1924,8 @@ def test_mixed_string_strl(self, temp_file):
         output.to_stata(path, write_index=False, convert_strl=["mixed"], version=117)
         reread = read_stata(path)
         expected = output.fillna("")
+        if using_infer_string:
+            expected["mixed"] = expected["mixed"].astype("str")
         tm.assert_frame_equal(reread, expected)
 
     @pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@@ -2000,7 +2001,6 @@ def test_stata_119(self, datapath):
                 reader._ensure_open()
                 assert reader._nvar == 32999
 
-    @pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
     @pytest.mark.parametrize("version", [118, 119, None])
     @pytest.mark.parametrize("byteorder", ["little", "big"])
     def test_utf8_writer(self, version, byteorder, temp_file):
@@ -2348,13 +2348,12 @@ def test_iterator_errors(datapath, chunksize):
             pass
 
 
-@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
 def test_iterator_value_labels(temp_file):
     # GH 31544
     values = ["c_label", "b_label"] + ["a_label"] * 500
     df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
     df.to_stata(temp_file, write_index=False)
-    expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
+    expected = pd.Index(["a_label", "b_label", "c_label"])
     with read_stata(temp_file, chunksize=100) as reader:
         for j, chunk in enumerate(reader):
             for i in range(2):