diff --git a/polyfactory/__init__.py b/polyfactory/__init__.py index 0a2269ea..6c1a904c 100644 --- a/polyfactory/__init__.py +++ b/polyfactory/__init__.py @@ -1,14 +1,16 @@ from .exceptions import ConfigurationException from .factories import BaseFactory -from .fields import Fixture, Ignore, PostGenerated, Require, Use +from .fields import AlwaysNone, Fixture, Ignore, NeverNone, PostGenerated, Require, Use from .persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol __all__ = ( + "AlwaysNone", "AsyncPersistenceProtocol", "BaseFactory", "ConfigurationException", "Fixture", "Ignore", + "NeverNone", "PostGenerated", "Require", "SyncPersistenceProtocol", diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 09058037..d18d5ef3 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -54,7 +54,7 @@ ) from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException from polyfactory.field_meta import Null -from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.fields import AlwaysNone, Fixture, Ignore, NeverNone, PostGenerated, Require, Use from polyfactory.utils.helpers import ( flatten_annotation, get_collection_type, @@ -334,6 +334,7 @@ def _handle_factory_field( # noqa: PLR0911 if isinstance(field_value, Fixture): return field_value.to_value() + # if a raw lambda is passed, invoke it if callable(field_value): return field_value() @@ -946,8 +947,16 @@ def should_set_none_value(cls, field_meta: FieldMeta) -> bool: :returns: A boolean determining whether 'None' should be set for the given field_meta. """ + field_value = hasattr(cls, field_meta.name) and getattr(cls, field_meta.name) + never_none = field_value and isinstance(field_value, NeverNone) + always_none = field_value and isinstance(field_value, AlwaysNone) + + if always_none: + return True + return ( cls.__allow_none_optionals__ + and not never_none and is_optional(field_meta.annotation) and create_random_boolean(random=cls.__random__) ) @@ -1021,13 +1030,16 @@ def _check_declared_fields_exist_in_model(cls) -> None: f"{field_name} is declared on the factory {cls.__name__}" f" but it is not part of the model {cls.__model__.__name__}" ) - if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): + if isinstance(field_value, (Use, PostGenerated, Ignore, Require, NeverNone, AlwaysNone)): raise ConfigurationException(error_message) @classmethod def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: """Process the given kwargs and generate values for the factory's model. + If you need to deeply customize field values, you'll want to override this method. This is where values are + generated and assigned for the fields on the model. + :param kwargs: Any build kwargs. :returns: A dictionary of build results. @@ -1038,8 +1050,15 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: for field_meta in cls.get_model_fields(): field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): - if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): - field_value = getattr(cls, field_meta.name) + has_field_value = hasattr(cls, field_meta.name) + field_value = has_field_value and getattr(cls, field_meta.name) + + # NeverNone & AlwaysNone should be treated as a normally-generated field, since this changes logic + # within get_field_value. + excluded_field_value = has_field_value and isinstance(field_value, (NeverNone, AlwaysNone)) + + # TODO why do we need the BaseFactory check here, only dunder methods which are ignored would trigger this? # noqa: FIX002 + if has_field_value and not hasattr(BaseFactory, field_meta.name) and not excluded_field_value: if isinstance(field_value, Ignore): continue diff --git a/polyfactory/fields.py b/polyfactory/fields.py index 6598e450..00e85f52 100644 --- a/polyfactory/fields.py +++ b/polyfactory/fields.py @@ -22,8 +22,24 @@ class Require: """A factory field that marks an attribute as a required build-time kwarg.""" +class NeverNone: + """A factory field that marks as always generated, even if it's an optional.""" + + +class AlwaysNone: + """A factory field that marks as never generated, setting the value to None, regardless of if it's an optional field + + This is distinct from Ignore() which does not set a value for a field at all. If Ignore() is used and a default value + for a field is not set on the underlying model, then the field will not be set at all. + """ + + class Ignore: - """A factory field that marks an attribute as ignored.""" + """A factory field that marks an attribute as ignored. This prevents the factory generating any value for this field. + + If you are using this on a pydantic model this will cause the field to be omitted from the resulting pydantic model + if there is no default value set for the pydantic field. + """ class Use(Generic[P, T]): diff --git a/tests/test_nones.py b/tests/test_nones.py new file mode 100644 index 00000000..43d5ca9a --- /dev/null +++ b/tests/test_nones.py @@ -0,0 +1,28 @@ +from typing import Optional + +from pydantic import BaseModel + +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.fields import AlwaysNone, NeverNone + + +def test_never_none() -> None: + class MyModel(BaseModel): + name: Optional[str] + + class MyFactory(ModelFactory[MyModel]): + name = NeverNone() + + assert MyFactory.build().name is not None + + +def test_always_none() -> None: + class MyModel(BaseModel): + name: Optional[str] + + class MyFactory(ModelFactory[MyModel]): + name = AlwaysNone() + # NOTE `name = None` does not end up + + # field is still accessible even though there is + assert MyFactory.build().name is None diff --git a/tests/test_optional_model_field_inference.py b/tests/test_optional_model_field_inference.py index 7bb318fe..14f2282c 100644 --- a/tests/test_optional_model_field_inference.py +++ b/tests/test_optional_model_field_inference.py @@ -72,7 +72,7 @@ class TypedDictBase(TypedDict): (TypedDictFactory, TypedDictBase), ], ) -def test_modeL_inference_ok(base_factory: Type[BaseFactory], generic_arg: Type[Any]) -> None: +def test_model_inference_ok(base_factory: Type[BaseFactory], generic_arg: Type[Any]) -> None: class Foo(base_factory[generic_arg]): # type: ignore ...