diff --git a/inject/__init__.py b/inject/__init__.py index 100784f..a952186 100644 --- a/inject/__init__.py +++ b/inject/__init__.py @@ -104,6 +104,14 @@ def my_config(binder): _INJECTOR_LOCK = threading.RLock() # Guards injector initialization. _BINDING_LOCK = threading.RLock() # Guards runtime bindings. +MARKER = object() # A marker object to explicitly set a parameter to be injected + + +def set_marker(new_marker: Any): + global MARKER + MARKER = new_marker + + Injectable = Union[object, Any] T = TypeVar('T', bound=Injectable) Binding = Union[Type[Injectable], Hashable] @@ -328,11 +336,17 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., if inspect.iscoroutinefunction(func): @wraps(func) async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: - provided_params = frozenset( - arg_names[:len(args)]) | frozenset(kwargs.keys()) + arg_name_tuple = arg_names[:len(args)] + provided_params = frozenset(arg_name_tuple) | frozenset(kwargs.keys()) for param, cls in params_to_provide.items(): if param not in provided_params: kwargs[param] = instance(cls) + elif param in kwargs and kwargs[param] is MARKER: + kwargs[param] = instance(cls) + elif param in arg_name_tuple: + idx = arg_name_tuple.index(param) + if args[idx] is MARKER: + args = args[:idx] + (instance(cls),) + args[idx + 1:] async_func = cast(Callable[..., Awaitable[T]], func) try: return await async_func(*args, **kwargs) @@ -343,11 +357,17 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: @wraps(func) def injection_wrapper(*args: Any, **kwargs: Any) -> T: - provided_params = frozenset( - arg_names[:len(args)]) | frozenset(kwargs.keys()) + arg_name_tuple = arg_names[:len(args)] + provided_params = frozenset(arg_name_tuple) | frozenset(kwargs.keys()) for param, cls in params_to_provide.items(): if param not in provided_params: kwargs[param] = instance(cls) + elif param in kwargs and kwargs[param] is MARKER: + kwargs[param] = instance(cls) + elif param in arg_name_tuple: + idx = arg_name_tuple.index(param) + if args[idx] is MARKER: + args = args[:idx] + (instance(cls),) + args[idx + 1:] sync_func = cast(Callable[..., T], func) try: return sync_func(*args, **kwargs) diff --git a/test/test_autoparams.py b/test/test_autoparams.py index dda87f4..9ea7a61 100644 --- a/test/test_autoparams.py +++ b/test/test_autoparams.py @@ -17,6 +17,8 @@ def test_func(val: int = None): inject.configure(lambda binder: binder.bind(int, 123)) assert test_func() == 123 + assert test_func(inject.MARKER) == 123 + assert test_func(val=inject.MARKER) == 123 assert test_func(val=321) == 321 def test_autoparams_multi(self): @@ -33,8 +35,11 @@ def config(binder): assert test_func() == (1, 2, 3) assert test_func(10) == (10, 2, 3) + assert test_func(10, inject.MARKER) == (10, 2, 3) + assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20, c=30) == (10, 20, 30) + assert test_func(a=inject.MARKER) == (1, 2, 3) assert test_func(a='a') == ('a', 2, 3) assert test_func(b='b') == (1, 'b', 3) assert test_func(c='c') == (1, 2, 'c') @@ -56,8 +61,10 @@ def config(binder): assert test_func() == (1, 2, 3) assert test_func(10) == (10, 2, 3) + assert test_func(10, inject.MARKER) == (10, 2, 3) assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20, c=30) == (10, 20, 30) + assert test_func(a=inject.MARKER) == (1, 2, 3) assert test_func(a='a') == ('a', 2, 3) assert test_func(b='b') == (1, 'b', 3) assert test_func(c='c') == (1, 2, 'c') @@ -78,9 +85,11 @@ def config(binder): assert test_func() == (1, 2, 3) assert test_func(10) == (10, 2, 3) + assert test_func(10, inject.MARKER) == (10, 2, 3) assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20, c=30) == (10, 20, 30) assert test_func(a='a') == ('a', 2, 3) + assert test_func(b=inject.MARKER) == (1, 2, 3) assert test_func(b='b') == (1, 'b', 3) assert test_func(c='c') == (1, 2, 'c') assert test_func(a=10, c=30) == (10, 2, 30) @@ -102,9 +111,11 @@ def config(binder): assert test.func() == (test, 1, 2, 3) assert test.func(10) == (test, 10, 2, 3) + assert test.func(10, inject.MARKER) == (test, 10, 2, 3) assert test.func(10, 20) == (test, 10, 20, 3) assert test.func(10, 20, c=30) == (test, 10, 20, 30) assert test.func(a='a') == (test, 'a', 2, 3) + assert test.func(b=inject.MARKER) == (test, 1, 2, 3) assert test.func(b='b') == (test, 1, 'b', 3) assert test.func(c='c') == (test, 1, 2, 'c') assert test.func(a=10, c=30) == (test, 10, 2, 30) @@ -127,9 +138,11 @@ def config(binder): assert Test.func() == (Test, 1, 2, 3) assert Test.func(10) == (Test, 10, 2, 3) + assert Test.func(10, inject.MARKER) == (Test, 10, 2, 3) assert Test.func(10, 20) == (Test, 10, 20, 3) assert Test.func(10, 20, c=30) == (Test, 10, 20, 30) assert Test.func(a='a') == (Test, 'a', 2, 3) + assert Test.func(b=inject.MARKER) == (Test, 1, 2, 3) assert Test.func(b='b') == (Test, 1, 'b', 3) assert Test.func(c='c') == (Test, 1, 2, 'c') assert Test.func(a=10, c=30) == (Test, 10, 2, 30) @@ -153,9 +166,11 @@ def config(binder): assert test.func() == (Test, 1, 2, 3) assert test.func(10) == (Test, 10, 2, 3) + assert test.func(10, inject.MARKER) == (Test, 10, 2, 3) assert test.func(10, 20) == (Test, 10, 20, 3) assert test.func(10, 20, c=30) == (Test, 10, 20, 30) assert test.func(a='a') == (Test, 'a', 2, 3) + assert test.func(b=inject.MARKER) == (Test, 1, 2, 3) assert test.func(b='b') == (Test, 1, 'b', 3) assert test.func(c='c') == (Test, 1, 2, 'c') assert test.func(a=10, c=30) == (Test, 10, 2, 30) diff --git a/test/test_params.py b/test/test_params.py index fc2e8dc..d079ea5 100644 --- a/test/test_params.py +++ b/test/test_params.py @@ -12,7 +12,9 @@ def test_func(val): inject.configure(lambda binder: binder.bind(int, 123)) assert test_func() == 123 + assert test_func(inject.MARKER) == 123 assert test_func(321) == 321 + assert test_func(val=inject.MARKER) == 123 assert test_func(val=42) == 42 def test_params_multi(self): @@ -29,8 +31,10 @@ def config(binder): assert test_func() == (1, 2, 3) assert test_func(10) == (10, 2, 3) + assert test_func(10, inject.MARKER) == (10, 2, 3) assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20, 30) == (10, 20, 30) + assert test_func(a=inject.MARKER) == (1, 2, 3) assert test_func(a='a') == ('a', 2, 3) assert test_func(b='b') == (1, 'b', 3) assert test_func(c='c') == (1, 2, 'c') @@ -52,12 +56,14 @@ def config(binder): assert test_func() == (1, 2, 3) assert test_func(10) == (10, 2, 3) + assert test_func(10, inject.MARKER) == (10, 2, 3) assert test_func(10, 20) == (10, 20, 3) assert test_func(10, 20, 30) == (10, 20, 30) assert test_func(a='a') == ('a', 2, 3) assert test_func(b='b') == (1, 'b', 3) assert test_func(c='c') == (1, 2, 'c') assert test_func(a=10, c=30) == (10, 2, 30) + assert test_func(a=10, b=inject.MARKER, c=30) == (10, 2, 30) assert test_func(c=30, b=20, a=10) == (10, 20, 30) assert test_func(10, b=20) == (10, 20, 3) @@ -78,10 +84,12 @@ def config(binder): assert test.func(10) == (test, 10, 2, 3) assert test.func(10, 20) == (test, 10, 20, 3) assert test.func(10, 20, 30) == (test, 10, 20, 30) + assert test.func(10, inject.MARKER, inject.MARKER) == (test, 10, 2, 3) assert test.func(a='a') == (test, 'a', 2, 3) assert test.func(b='b') == (test, 1, 'b', 3) assert test.func(c='c') == (test, 1, 2, 'c') assert test.func(a=10, c=30) == (test, 10, 2, 30) + assert test.func(a=10, b=inject.MARKER, c=30) == (test, 10, 2, 30) assert test.func(c=30, b=20, a=10) == (test, 10, 20, 30) assert test.func(10, b=20) == (test, 10, 20, 3) @@ -103,10 +111,12 @@ def config(binder): assert Test.func(10) == (Test, 10, 2, 3) assert Test.func(10, 20) == (Test, 10, 20, 3) assert Test.func(10, 20, 30) == (Test, 10, 20, 30) + assert Test.func(10, inject.MARKER, inject.MARKER) == (Test, 10, 2, 3) assert Test.func(a='a') == (Test, 'a', 2, 3) assert Test.func(b='b') == (Test, 1, 'b', 3) assert Test.func(c='c') == (Test, 1, 2, 'c') assert Test.func(a=10, c=30) == (Test, 10, 2, 30) + assert Test.func(a=10, b=inject.MARKER, c=30) == (Test, 10, 2, 30) assert Test.func(c=30, b=20, a=10) == (Test, 10, 20, 30) assert Test.func(10, b=20) == (Test, 10, 20, 3) @@ -129,10 +139,12 @@ def config(binder): assert test.func(10) == (Test, 10, 2, 3) assert test.func(10, 20) == (Test, 10, 20, 3) assert test.func(10, 20, 30) == (Test, 10, 20, 30) + assert test.func(10, inject.MARKER, inject.MARKER) == (Test, 10, 2, 3) assert test.func(a='a') == (Test, 'a', 2, 3) assert test.func(b='b') == (Test, 1, 'b', 3) assert test.func(c='c') == (Test, 1, 2, 'c') assert test.func(a=10, c=30) == (Test, 10, 2, 30) + assert test.func(a=10, b=inject.MARKER, c=30) == (Test, 10, 2, 30) assert test.func(c=30, b=20, a=10) == (Test, 10, 20, 30) assert test.func(10, b=20) == (Test, 10, 20, 3) @@ -145,5 +157,7 @@ async def test_func(val): assert inspect.iscoroutinefunction(test_func) assert self.run_async(test_func()) == 123 + assert self.run_async(test_func(inject.MARKER)) == 123 assert self.run_async(test_func(321)) == 321 + assert self.run_async(test_func(val=inject.MARKER)) == 123 assert self.run_async(test_func(val=42)) == 42