Skip to content

Commit 3407fe6

Browse files
authored
Support the importlib.resources files API in rewritten files (#9173)
1 parent e84ba80 commit 3407fe6

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

changelog/9169.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support for the ``files`` API from ``importlib.resources`` within rewritten files.

src/_pytest/assertion/rewrite.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, config: Config) -> None:
6464
except ValueError:
6565
self.fnpats = ["test_*.py", "*_test.py"]
6666
self.session: Optional[Session] = None
67-
self._rewritten_names: Set[str] = set()
67+
self._rewritten_names: Dict[str, Path] = {}
6868
self._must_rewrite: Set[str] = set()
6969
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
7070
# which might result in infinite recursion (#3506)
@@ -134,7 +134,7 @@ def exec_module(self, module: types.ModuleType) -> None:
134134
fn = Path(module.__spec__.origin)
135135
state = self.config.stash[assertstate_key]
136136

137-
self._rewritten_names.add(module.__name__)
137+
self._rewritten_names[module.__name__] = fn
138138

139139
# The requested module looks like a test file, so rewrite it. This is
140140
# the most magical part of the process: load the source, rewrite the
@@ -276,6 +276,14 @@ def get_data(self, pathname: Union[str, bytes]) -> bytes:
276276
with open(pathname, "rb") as f:
277277
return f.read()
278278

279+
if sys.version_info >= (3, 9):
280+
281+
def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore
282+
from types import SimpleNamespace
283+
from importlib.readers import FileReader
284+
285+
return FileReader(SimpleNamespace(path=self._rewritten_names[name]))
286+
279287

280288
def _write_pyc_fp(
281289
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType

testing/test_assertrewrite.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,35 @@ def test_zipfile(self, pytester: Pytester) -> None:
795795
)
796796
assert pytester.runpytest().ret == ExitCode.NO_TESTS_COLLECTED
797797

798+
@pytest.mark.skipif(
799+
sys.version_info < (3, 9),
800+
reason="importlib.resources.files was introduced in 3.9",
801+
)
802+
def test_load_resource_via_files_with_rewrite(self, pytester: Pytester) -> None:
803+
example = pytester.path.joinpath("demo") / "example"
804+
init = pytester.path.joinpath("demo") / "__init__.py"
805+
pytester.makepyfile(
806+
**{
807+
"demo/__init__.py": """
808+
from importlib.resources import files
809+
810+
def load():
811+
return files(__name__)
812+
""",
813+
"test_load": f"""
814+
pytest_plugins = ["demo"]
815+
816+
def test_load():
817+
from demo import load
818+
found = {{str(i) for i in load().iterdir() if i.name != "__pycache__"}}
819+
assert found == {{{str(example)!r}, {str(init)!r}}}
820+
""",
821+
}
822+
)
823+
example.mkdir()
824+
825+
assert pytester.runpytest("-vv").ret == ExitCode.OK
826+
798827
def test_readonly(self, pytester: Pytester) -> None:
799828
sub = pytester.mkdir("testing")
800829
sub.joinpath("test_readonly.py").write_bytes(

0 commit comments

Comments
 (0)