diff --git a/docs/examples/fields/test_example_10.py b/docs/examples/fields/test_example_10.py new file mode 100644 index 00000000..fe8f7fbd --- /dev/null +++ b/docs/examples/fields/test_example_10.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass + +from polyfactory.decorators import post_generated +from polyfactory.factories import DataclassFactory +from polyfactory.fields import Param + + +@dataclass +class Person: + name: str + age_next_year: int + + +class PersonFactoryWithParamValueSpecifiedInFactory(DataclassFactory[Person]): + """In this factory, the next_years_age_from_calculator must be passed at build time.""" + + next_years_age_from_calculator = Param[int](lambda age: age + 1, is_callable=True, age=20) + + @post_generated + @classmethod + def age_next_year(cls, next_years_age_from_calculator: int) -> int: + return next_years_age_from_calculator + + +def test_factory__in_factory() -> None: + person = PersonFactoryWithParamValueSpecifiedInFactory.build() + + assert isinstance(person, Person) + assert not hasattr(person, "next_years_age_from_calculator") + assert person.age_next_year == 21 + + +class PersonFactoryWithParamValueSetAtBuild(DataclassFactory[Person]): + """In this factory, the next_years_age_from_calculator must be passed at build time.""" + + next_years_age_from_calculator = Param[int](is_callable=True, age=20) + + @post_generated + @classmethod + def age_next_year(cls, next_years_age_from_calculator: int) -> int: + return next_years_age_from_calculator + + +def test_factory__build_time() -> None: + person = PersonFactoryWithParamValueSpecifiedInFactory.build(next_years_age_from_calculator=lambda age: age + 1) + + assert isinstance(person, Person) + assert not hasattr(person, "next_years_age_from_calculator") + assert person.age_next_year == 21 diff --git a/docs/examples/fields/test_example_9.py b/docs/examples/fields/test_example_9.py new file mode 100644 index 00000000..fd7d7480 --- /dev/null +++ b/docs/examples/fields/test_example_9.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import List + +from polyfactory.decorators import post_generated +from polyfactory.factories import DataclassFactory +from polyfactory.fields import Param + + +@dataclass +class Pet: + name: str + sound: str + + +class PetFactoryWithParamValueSetAtBuild(DataclassFactory[Pet]): + """In this factory, the name_choices must be passed at build time.""" + + name_choices = Param[List[str]]() + + @post_generated + @classmethod + def name(cls, name_choices: List[str]) -> str: + return cls.__random__.choice(name_choices) + + +def test_factory__build_time() -> None: + names = ["Ralph", "Roxy"] + pet = PetFactoryWithParamValueSetAtBuild.build(name_choices=names) + + assert isinstance(pet, Pet) + assert not hasattr(pet, "name_choices") + assert pet.name in names + + +class PetFactoryWithParamSpecififiedInFactory(DataclassFactory[Pet]): + """In this factory, the name_choices are specified in the + factory and do not need to be passed at build time.""" + + name_choices = Param[List[str]](["Ralph", "Roxy"]) + + @post_generated + @classmethod + def name(cls, name_choices: List[str]) -> str: + return cls.__random__.choice(name_choices) + + +def test_factory__in_factory() -> None: + pet = PetFactoryWithParamSpecififiedInFactory.build() + + assert isinstance(pet, Pet) + assert not hasattr(pet, "name_choices") + assert pet.name in ["Ralph", "Roxy"] diff --git a/docs/usage/fields.rst b/docs/usage/fields.rst index d6162671..6abefe48 100644 --- a/docs/usage/fields.rst +++ b/docs/usage/fields.rst @@ -78,6 +78,28 @@ The signature for use is: ``cb: Callable, *args, **defaults`` it can receive an callable should be: ``name: str, values: dict[str, Any], *args, **defaults``. The already generated values are mapped by name in the values dictionary. + +The ``Param`` Field +------------------- + +The :class:`Param ` class denotes a parameter that can be referenced by other fields at build but whose value is not set on the final object. This is useful for passing values needed by other factory fields but that are not part of object being built. + +A Param type can be either a constant or a callable. If a callable is used, it will be executed at the beginning of build and its return value will be used as the value for the field. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the Param constructor will also not be mapped into the final object. + +The Param type allows for flexibility in that it can either accept a value at the definition of the factory, or its value can be set at build time. If a value is provided at build time, it will take precedence over the value provided at the definition of the factory (if any). + +If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. Likewise, a Param cannot have the same name as any other model field. + +.. literalinclude:: /examples/fields/test_example_9.py + :caption: Using the ``Param`` field with a constant + :language: python + +.. literalinclude:: /examples/fields/test_example_10.py + :caption: Using the ``Param`` field with a callable + :language: python + + + Factories as Fields ------------------- diff --git a/polyfactory/exceptions.py b/polyfactory/exceptions.py index 53f1271a..a4c22b75 100644 --- a/polyfactory/exceptions.py +++ b/polyfactory/exceptions.py @@ -16,3 +16,7 @@ class MissingBuildKwargException(FactoryException): class MissingDependencyException(FactoryException, ImportError): """Missing dependency exception - used when a dependency is not installed""" + + +class MissingParamException(FactoryException): + """Missing parameter exception - used when a required Param is not provided""" diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 09058037..52d983ca 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -8,7 +8,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import EnumMeta -from functools import partial +from functools import lru_cache, partial from importlib import import_module from ipaddress import ( IPv4Address, @@ -52,9 +52,14 @@ MIN_COLLECTION_LENGTH, RANDOMIZE_COLLECTION_LENGTH, ) -from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException +from polyfactory.exceptions import ( + ConfigurationException, + MissingBuildKwargException, + MissingParamException, + ParameterException, +) from polyfactory.field_meta import Null -from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.fields import Fixture, Ignore, Param, PostGenerated, Require, Use from polyfactory.utils.helpers import ( flatten_annotation, get_collection_type, @@ -230,6 +235,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901, PL raise ConfigurationException( msg, ) + cls._check_overlapping_param_names() if cls.__check_model__: cls._check_declared_fields_exist_in_model() else: @@ -1024,6 +1030,44 @@ def _check_declared_fields_exist_in_model(cls) -> None: if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): raise ConfigurationException(error_message) + @classmethod + def _handle_factory_params(cls, params: dict[str, Param[Any]], **kwargs: Any) -> dict[str, Any]: + """Get the factory parameters. + + :param params: A dict of field name to Param instances. + :param kwargs: Any build kwargs. + + :returns: A dict of fieldname mapped to realized Param values. + """ + + try: + return {name: param.to_value(kwargs.get(name, Null)) for name, param in params.items()} + except MissingParamException as e: + msg = "Missing required kwargs" + raise MissingBuildKwargException(msg) from e + + @classmethod + @lru_cache(maxsize=None) + def get_factory_params(cls) -> dict[str, Param[Any]]: + """Get the factory parameters. + + :returns: A dict of field name to Param instances. + """ + return {name: item for name, item in cls.__dict__.items() if isinstance(item, Param)} + + @classmethod + def _check_overlapping_param_names(cls) -> None: + """Checks if there are overlapping param names with model fields. + + + :raises: ConfigurationException + """ + model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} + overlapping_params = set(cls.get_factory_params().keys()) & model_fields_names + if overlapping_params: + msg = f"Factory Params {', '.join(overlapping_params)} overlap with model fields" + raise ConfigurationException(msg) + @classmethod def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. @@ -1035,6 +1079,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: """ result, generate_post, _build_context = cls._get_initial_variables(kwargs) + params = cls.get_factory_params() + result.update(cls._handle_factory_params(params, **kwargs)) + for field_meta in cls.get_model_fields(): field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): @@ -1071,7 +1118,7 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) - return result + return {key: value for key, value in result.items() if key not in params} @classmethod def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: @@ -1085,6 +1132,9 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: """ result, generate_post, _build_context = cls._get_initial_variables(kwargs) + params = cls.get_factory_params() + result.update(cls._handle_factory_params(params, **kwargs)) + for field_meta in cls.get_model_fields(): field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) @@ -1120,7 +1170,8 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: for resolved in resolve_kwargs_coverage(result): for field_name, post_generator in generate_post.items(): resolved[field_name] = post_generator.to_value(field_name, resolved) - yield resolved + + yield {key: value for key, value in resolved.items() if key not in params} @classmethod def build(cls, *_: Any, **kwargs: Any) -> T: diff --git a/polyfactory/fields.py b/polyfactory/fields.py index 6598e450..f68a3fdf 100644 --- a/polyfactory/fields.py +++ b/polyfactory/fields.py @@ -4,11 +4,13 @@ from typing_extensions import ParamSpec -from polyfactory.exceptions import ParameterException +from polyfactory.exceptions import MissingParamException, ParameterException +from polyfactory.field_meta import Null from polyfactory.utils import deprecation from polyfactory.utils.predicates import is_safe_subclass T = TypeVar("T") +U = TypeVar("U") P = ParamSpec("P") @@ -118,3 +120,81 @@ def to_value(self) -> Any: if self.size is not None: return factory.batch(self.size, **self.kwargs) return factory.build(**self.kwargs) + + +class Param(Generic[T]): + """A constant parameter that can be used by other fields but will not be + passed to the final object. + + If a value for the parameter is not passed in the field's definition, it must + be passed at build time. Otherwise, a MissingParamException will be raised. + """ + + __slots__ = ("is_callable", "kwargs", "param") + + def __init__( + self, param: T | Callable[..., T] | type[Null] = Null, is_callable: bool = False, **kwargs: Any + ) -> None: + """Designate a parameter. + + :param param: A constant or an unpassed value that can be referenced later + """ + if param is not Null and is_callable and not callable(param): + msg = "If an object is passed to param, a callable must be passed when is_callable is True" + raise ParameterException(msg) + if not is_callable and kwargs: + msg = "kwargs can only be used with callable parameters" + raise ParameterException(msg) + + self.param = param + self.is_callable = is_callable + self.kwargs = kwargs + + def to_value(self, from_build: T | Callable[..., T] | type[Null] = Null, **kwargs: Any) -> T: + """Determines the value to use at build time + + If a value was passed to the constructor, it will be used. Otherwise, the value + passed at build time will be used. If no value was passed at build time, a + MissingParamException will be raised. + + :param args: from_build: The value passed at build time (if any). + :returns: The value + :raises: MissingParamException + """ + # If no param is passed at initialization, a value must be passed now + if self.param is Null: + # from_build was passed, so determine the value based on whether or + # not we're supposed to call a callable + if from_build is not Null: + return ( + cast("T", from_build) + if not self.is_callable + else cast("Callable[..., T]", from_build)(**{**self.kwargs, **kwargs}) + ) + + # Otherwise, raise an exception + msg = ( + "Expected a parameter value to be passed at build time" + if not self.is_callable + else "Expected a callable to be passed at build time" + ) + raise MissingParamException(msg) + # A param was passed at initialization + if self.is_callable: + # In this case, we are going to call the callable, but we can still + # override if are passed a callable at build + if from_build is not Null: + if callable(from_build): + return cast("Callable[..., T]", from_build)(**{**self.kwargs, **kwargs}) + + # If we were passed a value at build that isn't a callable, raise + # an exception + msg = "The value passed at build time is not callable" + raise TypeError(msg) + + # Otherwise, return the value passed at initialization + return cast("Callable[..., T]", self.param)(**{**self.kwargs, **kwargs}) + + # Inthis case, we are not using a callable, so return either the value + # passed at build time or initialization + return cast("T", self.param) if from_build is Null else cast("T", from_build) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index fbd772ac..def90bb5 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -456,9 +456,8 @@ class Place(Base): id: Any = Column(Integer(), primary_key=True) numeric_field: Any = Column(Location, nullable=False) - factory = SQLAlchemyFactory.create_factory(Place) with pytest.raises( ParameterException, match="Unsupported type engine: Location()", ): - factory.build() + SQLAlchemyFactory.create_factory(Place) diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index e07aa9f6..800e7695 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -1,17 +1,23 @@ import random from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, Union import pytest from pydantic import BaseModel from polyfactory.decorators import post_generated -from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException +from polyfactory.exceptions import ( + ConfigurationException, + MissingBuildKwargException, + MissingParamException, + ParameterException, +) from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.fields import Ignore, PostGenerated, Require, Use +from polyfactory.field_meta import Null +from polyfactory.fields import Ignore, Param, PostGenerated, Require, Use def test_use() -> None: @@ -83,6 +89,192 @@ class MyFactory(ModelFactory): assert MyFactory.build().name is None +@pytest.mark.parametrize( + "value,is_callable,kwargs", + [ + (None, False, {}), + (1, False, {}), + ("foo", False, {}), + (lambda value: value, True, {}), + (lambda value1, value2: value1 + value2, True, {}), + (lambda: "foo", True, {}), + (lambda: "foo", True, {"value": 3}), + ], +) +def test_param_init(value: Any, is_callable: bool, kwargs: Dict[str, Any]) -> None: + param = Param(value, is_callable, **kwargs) # type: ignore + assert isinstance(param, Param) + assert param.param == value + assert param.is_callable == is_callable + assert param.kwargs == kwargs + + +@pytest.mark.parametrize( + "value,is_callable,kwargs", + [ + (None, True, {}), + (1, True, {}), + ("foo", True, {}), + (Null, False, {"value": 3}), + (1, False, {"value": 3}), + ], +) +def test_param_init_error(value: Any, is_callable: bool, kwargs: Dict[str, Any]) -> None: + with pytest.raises( + ParameterException, + ): + Param(value, is_callable, **kwargs) + + +@pytest.mark.parametrize( + "initval,is_cabllable,initkwargs,buildval,buildkwargs,outcome", + [ + (None, False, {}, Null, {}, None), + (1, False, {}, 2, {}, 2), + ("foo", False, {}, Null, {}, "foo"), + (lambda value: value, True, {}, lambda value: value + 1, {"value": 3}, 4), + (lambda value1, value2: value1 + value2, True, {"value1": 2}, Null, {"value2": 1}, 3), + (lambda: "foo", True, {}, Null, {}, "foo"), + ], +) +def test_param_to_value( + initval: Any, + is_cabllable: bool, + initkwargs: Dict[str, Any], + buildval: Any, + buildkwargs: Dict[str, Any], + outcome: Any, +) -> None: + assert Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs) == outcome + + +@pytest.mark.parametrize( + "initval,is_cabllable,initkwargs,buildval,buildkwargs,exc", + [ + (Null, False, {}, Null, {}, MissingParamException), + (Null, True, {}, 1, {}, TypeError), + ], +) +def test_param_to_value_exception( + initval: Any, + is_cabllable: bool, + initkwargs: Dict[str, Any], + buildval: Any, + buildkwargs: Dict[str, Any], + exc: Type[Exception], +) -> None: + with pytest.raises(exc): + Param(initval, is_cabllable, **initkwargs).to_value(buildval, **buildkwargs) + + +def test_param_from_factory() -> None: + value: int = 3 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int](value) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build() + assert result.description == "abc" + + +def test_param_from_kwargs() -> None: + value: int = 3 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int]() + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build(length=value) + assert result.description == "abc" + + +def test_param_from_kwargs_missing() -> None: + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int]() + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + with pytest.raises(MissingBuildKwargException): + MyFactory.build() + + +def test_callable_param_from_factory() -> None: + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param(lambda value: value, is_callable=True, value=3) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build() + assert result.description == "abc" + + +def test_callable_param_from_kwargs() -> None: + value1: int = 2 + value2: int = 1 + + class MyModel(BaseModel): + description: str + + class MyFactory(ModelFactory): + __model__ = MyModel + length = Param[int](is_callable=True, value1=value1, value2=value2) + + @post_generated + @classmethod + def description(cls, length: int) -> str: + return "abcd"[:length] + + result = MyFactory.build(length=lambda value1, value2: value1 + value2) + assert result.description == "abcd"[: value1 + value2] + + +def test_param_name_overlaps_model_field() -> None: + class MyModel(BaseModel): + name: str + other: int + + with pytest.raises(ConfigurationException) as exc: + + class MyFactory(ModelFactory): + __model__ = MyModel + name = Param[str]("foo") + other = 1 + + assert "name" in str(exc) + assert "other" not in str(exc) + + def test_post_generation() -> None: random_delta = timedelta(days=random.randint(0, 12), seconds=random.randint(13, 13000)) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py index d3861ff8..58e92129 100644 --- a/tests/test_type_coverage_generation.py +++ b/tests/test_type_coverage_generation.py @@ -12,10 +12,11 @@ from pydantic import BaseModel from polyfactory.decorators import post_generated -from polyfactory.exceptions import ParameterException +from polyfactory.exceptions import MissingBuildKwargException, ParameterException from polyfactory.factories.dataclass_factory import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.factories.typed_dict_factory import TypedDictFactory +from polyfactory.fields import Param from polyfactory.utils.types import NoneType from tests.test_pydantic_factory import IS_PYDANTIC_V1 @@ -39,6 +40,83 @@ class ProfileFactory(DataclassFactory[Profile]): assert isinstance(result, Profile) +def test_coverage_param__from_factory() -> None: + value = "Demo" + + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str](value) + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + results = list(ProfileFactory.coverage()) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + assert result.name == f"The {value}" + + +def test_coverage_param__from_kwargs() -> None: + value = "Demo" + + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str]() + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + results = list(ProfileFactory.coverage(last_name=value)) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + assert result.name == f"The {value}" + + +def test_coverage_param__from_kwargs__missing() -> None: + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + last_name = Param[str]() + + @post_generated + @classmethod + def name(cls, last_name: str) -> str: + return f"The {last_name}" + + with pytest.raises(MissingBuildKwargException): + list(ProfileFactory.coverage()) + + def test_coverage_tuple() -> None: @dataclass class Pair: