diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 3a4d1860..c9d71e21 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -330,6 +330,8 @@ def get_arg_options(self) -> dict[str, Any]: elif self.is_union: logger.debug("Parsing a Union type!") _arg_options["type"] = get_parsing_fn(self.type) + if any(utils.is_list(o) for o in utils.get_args(self.type)): + _arg_options["nargs"] = "*" elif self.is_enum: logger.debug(f"Adding an Enum attribute '{self.name}'") @@ -501,6 +503,13 @@ def postprocess(self, raw_parsed_value: Any) -> Any: else: return raw_parsed_value + elif self.is_union: + list_in = [utils.is_list(o) for o in utils.get_args(self.type)] + # if type is like Union[str, list[str]] and only a single value was passed, + if any(list_in) and (not all(list_in)) and (len(raw_parsed_value) == 1): + raw_parsed_value = raw_parsed_value[0] + return raw_parsed_value + elif self.is_subparser: return raw_parsed_value diff --git a/test/test_union.py b/test/test_union.py index 4c0027dd..e90a39ac 100644 --- a/test/test_union.py +++ b/test/test_union.py @@ -32,3 +32,25 @@ class Foo2(TestSetup): foo = Foo2.setup("--x 2") assert foo.x == 2 and type(foo.x) is int + + +def test_union_type_with_list(): + @dataclass + class Foo(TestSetup): + x: Union[str, list[str]] + + foo = Foo.setup("--x bob") + assert foo.x == "bob" + + foo = Foo.setup("--x bob alice") + assert foo.x == ["bob", "alice"] + + @dataclass + class Foo(TestSetup): + x: Union[list[int], list[str]] + + foo = Foo.setup("--x bob alice") + assert foo.x == ["bob", "alice"] + + foo = Foo.setup("--x 1 2") + assert foo.x == [1, 2]