Skip to content

Commit aa2bc69

Browse files
committed
import_tools.discover: Correctly handle both modules and functions defined in __init__.py
1 parent eefb073 commit aa2bc69

File tree

3 files changed

+53
-23
lines changed

3 files changed

+53
-23
lines changed

domdf_python_tools/import_tools.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,19 @@
5151
from typing import Any, Callable, Dict, List, Optional, Type, overload
5252

5353
# 3rd party
54-
from typing_extensions import Literal
54+
from typing_extensions import Literal, TypedDict
5555

5656
# this package
5757
from domdf_python_tools.compat import importlib_metadata
5858

5959
__all__ = ["discover", "discover_entry_points", "discover_entry_points_by_name"]
6060

6161

62+
class _DiscoverKwargsType(TypedDict):
63+
match_func: Optional[Callable[[Any], bool]]
64+
exclude_side_effects: bool
65+
66+
6267
@overload
6368
def discover(
6469
package: ModuleType,
@@ -81,45 +86,66 @@ def discover(
8186
exclude_side_effects: bool = True,
8287
) -> List[Any]:
8388
"""
84-
Returns a list of objects in the given module,
85-
optionally filtered by ``match_func``.
89+
Returns a list of objects in the given package, optionally filtered by ``match_func``.
8690
8791
:param package: A Python package
8892
:param match_func: Function taking an object and returning :py:obj:`True` if the object is to be included in the output.
8993
:default match_func: :py:obj:`None`, which includes all objects.
9094
:param exclude_side_effects: Don't include objects that are only there because of an import side effect.
9195
92-
:return: List of matching objects.
93-
9496
.. versionchanged:: 1.0.0
9597
9698
Added the ``exclude_side_effects`` parameter.
99+
"""
97100

98-
.. TODO:: raise better exception when passing a module rather than a package.
99-
Or just return the contents of the module?
100-
""" # noqa D400
101+
kwargs: _DiscoverKwargsType = dict(exclude_side_effects=exclude_side_effects, match_func=match_func)
101102

102-
matching_objects = []
103+
matching_objects = _discover_in_module(package, **kwargs)
103104

104-
for _, module_name, _ in pkgutil.walk_packages(
105+
if hasattr(package, "__path__"):
105106
# https://github.com/python/mypy/issues/1422
106107
# Stalled PRs: https://github.com/python/mypy/pull/3527
107108
# https://github.com/python/mypy/pull/5212
108-
package.__path__, # type: ignore
109-
prefix=package.__name__ + '.',
110-
):
111-
module = __import__(module_name, fromlist=["__trash"], level=0)
109+
package_path = package.__path__ # type: ignore
110+
111+
for _, module_name, _ in pkgutil.walk_packages(package_path, prefix=f'{package.__name__}.'):
112+
module = __import__(module_name, fromlist=["__trash"], level=0)
113+
114+
matching_objects.extend(_discover_in_module(module, **kwargs))
115+
116+
return matching_objects
117+
118+
119+
def _discover_in_module(
120+
module: ModuleType,
121+
match_func: Optional[Callable[[Any], bool]] = None,
122+
exclude_side_effects: bool = True,
123+
) -> List[Any]:
124+
"""
125+
Returns a list of objects in the given module, optionally filtered by ``match_func``.
126+
127+
:param module: A Python module.
128+
:param match_func: Function taking an object and returning :py:obj:`True` if the object is to be included in the output.
129+
:default match_func: :py:obj:`None`, which includes all objects.
130+
:param exclude_side_effects: Don't include objects that are only there because of an import side effect.
131+
132+
.. versionadded:: 2.6.0
133+
134+
.. TODO:: make this public in that version
135+
"""
136+
137+
matching_objects = []
112138

113-
# Check all the functions in that module
114-
for _, imported_objects in inspect.getmembers(module, match_func):
139+
# Check all the functions in that module
140+
for _, imported_object in inspect.getmembers(module, match_func):
115141

116-
if exclude_side_effects:
117-
if not hasattr(imported_objects, "__module__"):
118-
continue
119-
if imported_objects.__module__ != module.__name__:
120-
continue
142+
if exclude_side_effects:
143+
if not hasattr(imported_object, "__module__"):
144+
continue
145+
if imported_object.__module__ != module.__name__:
146+
continue
121147

122-
matching_objects.append(imported_objects)
148+
matching_objects.append(imported_object)
123149

124150
return matching_objects
125151

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def foo_in_init() -> str:
2+
pass

tests/test_import_tools.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
def test_discover():
2121
# Alphabetical order regardless of order in the module.
2222
assert discover(discover_demo_module) == [
23+
discover_demo_module.foo_in_init,
2324
discover_demo_module.submodule_a.bar,
2425
discover_demo_module.submodule_a.foo,
2526
discover_demo_module.submodule_b.Alice,
@@ -32,6 +33,7 @@ def test_discover_function_only():
3233
assert discover(
3334
discover_demo_module, match_func=inspect.isfunction
3435
) == [
36+
discover_demo_module.foo_in_init,
3537
discover_demo_module.submodule_a.bar,
3638
discover_demo_module.submodule_a.foo,
3739
]
@@ -74,7 +76,7 @@ def does_not_raise():
7476
def raises_attribute_error(obj, **kwargs):
7577
return pytest.param(
7678
obj,
77-
pytest.raises(AttributeError, match=f"^'{type(obj).__name__}' object has no attribute '__path__'$"),
79+
pytest.raises(AttributeError, match=f"^'{type(obj).__name__}' object has no attribute '__name__'$"),
7880
**kwargs,
7981
)
8082

0 commit comments

Comments
 (0)