Skip to content

feat(fields): implement Param and CallableParam for handling unmapped parameters that can be referenced during build #650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/examples/fields/test_example_10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import dataclass

from polyfactory.decorators import post_generated
from polyfactory.factories import DataclassFactory
from polyfactory.fields import Param


@dataclass
class Person:
name: str
age_next_year: int


class PersonFactoryWithParamValueSpecifiedInFactory(DataclassFactory[Person]):
"""In this factory, the next_years_age_from_calculator must be passed at build time."""

next_years_age_from_calculator = Param[int](lambda age: age + 1, is_callable=True, age=20)

@post_generated
@classmethod
def age_next_year(cls, next_years_age_from_calculator: int) -> int:
return next_years_age_from_calculator


def test_factory__in_factory() -> None:
person = PersonFactoryWithParamValueSpecifiedInFactory.build()

assert isinstance(person, Person)
assert not hasattr(person, "next_years_age_from_calculator")
assert person.age_next_year == 21


class PersonFactoryWithParamValueSetAtBuild(DataclassFactory[Person]):
"""In this factory, the next_years_age_from_calculator must be passed at build time."""

next_years_age_from_calculator = Param[int](is_callable=True, age=20)

@post_generated
@classmethod
def age_next_year(cls, next_years_age_from_calculator: int) -> int:
return next_years_age_from_calculator


def test_factory__build_time() -> None:
person = PersonFactoryWithParamValueSpecifiedInFactory.build(next_years_age_from_calculator=lambda age: age + 1)

assert isinstance(person, Person)
assert not hasattr(person, "next_years_age_from_calculator")
assert person.age_next_year == 21
52 changes: 52 additions & 0 deletions docs/examples/fields/test_example_9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import List

from polyfactory.decorators import post_generated
from polyfactory.factories import DataclassFactory
from polyfactory.fields import Param


@dataclass
class Pet:
name: str
sound: str


class PetFactoryWithParamValueSetAtBuild(DataclassFactory[Pet]):
"""In this factory, the name_choices must be passed at build time."""

name_choices = Param[List[str]]()

@post_generated
@classmethod
def name(cls, name_choices: List[str]) -> str:
return cls.__random__.choice(name_choices)


def test_factory__build_time() -> None:
names = ["Ralph", "Roxy"]
pet = PetFactoryWithParamValueSetAtBuild.build(name_choices=names)

assert isinstance(pet, Pet)
assert not hasattr(pet, "name_choices")
assert pet.name in names


class PetFactoryWithParamSpecififiedInFactory(DataclassFactory[Pet]):
"""In this factory, the name_choices are specified in the
factory and do not need to be passed at build time."""

name_choices = Param[List[str]](["Ralph", "Roxy"])

@post_generated
@classmethod
def name(cls, name_choices: List[str]) -> str:
return cls.__random__.choice(name_choices)


def test_factory__in_factory() -> None:
pet = PetFactoryWithParamSpecififiedInFactory.build()

assert isinstance(pet, Pet)
assert not hasattr(pet, "name_choices")
assert pet.name in ["Ralph", "Roxy"]
22 changes: 22 additions & 0 deletions docs/usage/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ The signature for use is: ``cb: Callable, *args, **defaults`` it can receive an
callable should be: ``name: str, values: dict[str, Any], *args, **defaults``. The already generated values are mapped by
name in the values dictionary.


The ``Param`` Field
-------------------

The :class:`Param <polyfactory.fields.Param>` class denotes a parameter that can be referenced by other fields at build but whose value is not set on the final object. This is useful for passing values needed by other factory fields but that are not part of object being built.

A Param type can be either a constant or a callable. If a callable is used, it will be executed at the beginning of build and its return value will be used as the value for the field. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the Param constructor will also not be mapped into the final object.

The Param type allows for flexibility in that it can either accept a value at the definition of the factory, or its value can be set at build time. If a value is provided at build time, it will take precedence over the value provided at the definition of the factory (if any).

If neither a value is provided at the definition of the factory nor at build time, an exception will be raised. Likewise, a Param cannot have the same name as any other model field.

.. literalinclude:: /examples/fields/test_example_9.py
:caption: Using the ``Param`` field with a constant
:language: python

.. literalinclude:: /examples/fields/test_example_10.py
:caption: Using the ``Param`` field with a callable
:language: python



Factories as Fields
-------------------

Expand Down
4 changes: 4 additions & 0 deletions polyfactory/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class MissingBuildKwargException(FactoryException):

class MissingDependencyException(FactoryException, ImportError):
"""Missing dependency exception - used when a dependency is not installed"""


class MissingParamException(FactoryException):
"""Missing parameter exception - used when a required Param is not provided"""
61 changes: 56 additions & 5 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import EnumMeta
from functools import partial
from functools import lru_cache, partial
from importlib import import_module
from ipaddress import (
IPv4Address,
Expand Down Expand Up @@ -52,9 +52,14 @@
MIN_COLLECTION_LENGTH,
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.exceptions import (
ConfigurationException,
MissingBuildKwargException,
MissingParamException,
ParameterException,
)
from polyfactory.field_meta import Null
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.fields import Fixture, Ignore, Param, PostGenerated, Require, Use
from polyfactory.utils.helpers import (
flatten_annotation,
get_collection_type,
Expand Down Expand Up @@ -230,6 +235,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901, PL
raise ConfigurationException(
msg,
)
cls._check_overlapping_param_names()
if cls.__check_model__:
cls._check_declared_fields_exist_in_model()
else:
Expand Down Expand Up @@ -1024,6 +1030,44 @@ def _check_declared_fields_exist_in_model(cls) -> None:
if isinstance(field_value, (Use, PostGenerated, Ignore, Require)):
raise ConfigurationException(error_message)

@classmethod
def _handle_factory_params(cls, params: dict[str, Param[Any]], **kwargs: Any) -> dict[str, Any]:
"""Get the factory parameters.

:param params: A dict of field name to Param instances.
:param kwargs: Any build kwargs.

:returns: A dict of fieldname mapped to realized Param values.
"""

try:
return {name: param.to_value(kwargs.get(name, Null)) for name, param in params.items()}
except MissingParamException as e:
msg = "Missing required kwargs"
raise MissingBuildKwargException(msg) from e

@classmethod
@lru_cache(maxsize=None)
def get_factory_params(cls) -> dict[str, Param[Any]]:
"""Get the factory parameters.

:returns: A dict of field name to Param instances.
"""
return {name: item for name, item in cls.__dict__.items() if isinstance(item, Param)}

@classmethod
def _check_overlapping_param_names(cls) -> None:
"""Checks if there are overlapping param names with model fields.


:raises: ConfigurationException
"""
model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()}
overlapping_params = set(cls.get_factory_params().keys()) & model_fields_names
if overlapping_params:
msg = f"Factory Params {', '.join(overlapping_params)} overlap with model fields"
raise ConfigurationException(msg)

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
"""Process the given kwargs and generate values for the factory's model.
Expand All @@ -1035,6 +1079,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
"""
result, generate_post, _build_context = cls._get_initial_variables(kwargs)

params = cls.get_factory_params()
result.update(cls._handle_factory_params(params, **kwargs))

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
Expand Down Expand Up @@ -1071,7 +1118,7 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
for field_name, post_generator in generate_post.items():
result[field_name] = post_generator.to_value(field_name, result)

return result
return {key: value for key, value in result.items() if key not in params}

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
Expand All @@ -1085,6 +1132,9 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
"""
result, generate_post, _build_context = cls._get_initial_variables(kwargs)

params = cls.get_factory_params()
result.update(cls._handle_factory_params(params, **kwargs))

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

Expand Down Expand Up @@ -1120,7 +1170,8 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
for resolved in resolve_kwargs_coverage(result):
for field_name, post_generator in generate_post.items():
resolved[field_name] = post_generator.to_value(field_name, resolved)
yield resolved

yield {key: value for key, value in resolved.items() if key not in params}

@classmethod
def build(cls, *_: Any, **kwargs: Any) -> T:
Expand Down
82 changes: 81 additions & 1 deletion polyfactory/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from typing_extensions import ParamSpec

from polyfactory.exceptions import ParameterException
from polyfactory.exceptions import MissingParamException, ParameterException
from polyfactory.field_meta import Null
from polyfactory.utils import deprecation
from polyfactory.utils.predicates import is_safe_subclass

T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


Expand Down Expand Up @@ -118,3 +120,81 @@ def to_value(self) -> Any:
if self.size is not None:
return factory.batch(self.size, **self.kwargs)
return factory.build(**self.kwargs)


class Param(Generic[T]):
"""A constant parameter that can be used by other fields but will not be
passed to the final object.

If a value for the parameter is not passed in the field's definition, it must
be passed at build time. Otherwise, a MissingParamException will be raised.
"""

__slots__ = ("is_callable", "kwargs", "param")

def __init__(
self, param: T | Callable[..., T] | type[Null] = Null, is_callable: bool = False, **kwargs: Any
) -> None:
"""Designate a parameter.

:param param: A constant or an unpassed value that can be referenced later
"""
if param is not Null and is_callable and not callable(param):
msg = "If an object is passed to param, a callable must be passed when is_callable is True"
raise ParameterException(msg)
if not is_callable and kwargs:
msg = "kwargs can only be used with callable parameters"
raise ParameterException(msg)

self.param = param
self.is_callable = is_callable
self.kwargs = kwargs

def to_value(self, from_build: T | Callable[..., T] | type[Null] = Null, **kwargs: Any) -> T:
"""Determines the value to use at build time

If a value was passed to the constructor, it will be used. Otherwise, the value
passed at build time will be used. If no value was passed at build time, a
MissingParamException will be raised.

:param args: from_build: The value passed at build time (if any).
:returns: The value
:raises: MissingParamException
"""
# If no param is passed at initialization, a value must be passed now
if self.param is Null:
# from_build was passed, so determine the value based on whether or
# not we're supposed to call a callable
if from_build is not Null:
return (
cast("T", from_build)
if not self.is_callable
else cast("Callable[..., T]", from_build)(**{**self.kwargs, **kwargs})
)

# Otherwise, raise an exception
msg = (
"Expected a parameter value to be passed at build time"
if not self.is_callable
else "Expected a callable to be passed at build time"
)
raise MissingParamException(msg)
# A param was passed at initialization
if self.is_callable:
# In this case, we are going to call the callable, but we can still
# override if are passed a callable at build
if from_build is not Null:
if callable(from_build):
return cast("Callable[..., T]", from_build)(**{**self.kwargs, **kwargs})

# If we were passed a value at build that isn't a callable, raise
# an exception
msg = "The value passed at build time is not callable"
raise TypeError(msg)

# Otherwise, return the value passed at initialization
return cast("Callable[..., T]", self.param)(**{**self.kwargs, **kwargs})

# Inthis case, we are not using a callable, so return either the value
# passed at build time or initialization
return cast("T", self.param) if from_build is Null else cast("T", from_build)
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,8 @@ class Place(Base):
id: Any = Column(Integer(), primary_key=True)
numeric_field: Any = Column(Location, nullable=False)

factory = SQLAlchemyFactory.create_factory(Place)
with pytest.raises(
ParameterException,
match="Unsupported type engine: Location()",
):
factory.build()
SQLAlchemyFactory.create_factory(Place)
Loading
Loading