Description
Current Interface
The existing DifferentiateWith(f, backend) interface in DifferentiationInterface.jl presents a significant limitation: it inherently supports only single-argument functions. This design makes it cumbersome to:
- Differentiate functions with multiple arguments.
- Pass additional context or non-differentiable arguments (constants, pre-allocated caches) to the differentiation backend.
Proposed Interface
To address these limitations, we propose a more expressive interface for DifferentiateWith:
Tfunc_sig = Tuple{typeof(f), T_arg1, T_arg2, ..., T_argN}
DifferentiateWith(Tfunc_sig, backend_to_use::AbstractADType)
Here Tfunc_sig represents the function signature. The first element is the function f itself (or its type), and the subsequent elements T_arg1, T_arg2, ..., T_argN represent the types of the arguments to f.
Argument Type Wrappers:
To provide more context to the backend about how each argument should be treated, we can introduce wrapper types:
- Default: Arguments are assumed to be "active" (i.e., to be differentiated with respect to).
- Constant{T}: Indicates that an argument of type T is a constant and should not be differentiated.
- Cache{T}: Signals that an argument of type T is a pre-allocated cache that the backend can utilise.
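As a minimal sketch of how such wrappers could be told apart, the snippet below defines local stand-in marker types and resolves the role of each signature entry by dispatch. The names role and unwrap are hypothetical helpers introduced for illustration, and the Constant and Cache structs here are standalone placeholders, not the types exported by any package.

```julia
# Hypothetical marker types: they only tag a type parameter and hold no data.
struct Constant{T} end
struct Cache{T} end

# Role of one signature entry, resolved by dispatch on the type itself.
role(::Type{<:Constant}) = :constant
role(::Type{<:Cache}) = :cache
role(::Type) = :active  # default: differentiate with respect to this argument

# Underlying argument type with the wrapper stripped.
unwrap(::Type{Constant{T}}) where {T} = T
unwrap(::Type{Cache{T}}) where {T} = T
unwrap(::Type{T}) where {T} = T

roles = map(role, (Vector{Float64}, Float64, Constant{Int}, Cache{Vector{Float64}}))
inner = unwrap(Cache{Vector{Float64}})
```

Because unwrapped types fall through to the generic method, existing single-argument signatures would not need any annotation under this scheme.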
Example Usage:
Consider a function f(x, y, z, c) where x and y are active arguments, z is a constant, and c is a cache. The Tfunc_sig would be constructed as:
Targtypes = (typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)})
Tfunc_sig = Tuple{typeof(f), Targtypes...}
# or more explicitly:
# Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}
dw = DifferentiateWith(Tfunc_sig, backend)
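To make the construction concrete, here is a small runnable sketch with placeholder values. It checks that the splatted and explicit spellings denote the same type, and shows one possible way to recover the active positions from the signature. The body of f, the sample values, and the stand-in Constant and Cache structs are all assumptions made for illustration; the proposed DifferentiateWith constructor itself is not called.

```julia
# Stand-in marker types (assumed for this sketch).
struct Constant{T} end
struct Cache{T} end

# Placeholder function; the cache c is accepted but unused here.
f(x, y, z, c) = sum(x) * y + z

x, y, z, c = [1.0, 2.0], 3.0, 2, zeros(2)

Targtypes = (typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)})
Tfunc_sig = Tuple{typeof(f), Targtypes...}
Tfunc_sig_explicit = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}

# Both spellings denote the same tuple type.
same = Tfunc_sig === Tfunc_sig_explicit

# The first field is the function type; the remaining fields are argument entries.
sig_fields = fieldtypes(Tfunc_sig)
arg_entries = sig_fields[2:end]

# Positions that carry no wrapper are treated as active.
active_positions = [i for (i, T) in enumerate(arg_entries) if !(T <: Union{Constant, Cache})]
```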
Internal Handling:
With this richer Tfunc_sig, DifferentiateWith can internally manage functions with multiple arguments. For backends that fundamentally operate on single-argument functions (e.g., by packing arguments into a tuple), DifferentiateWith can perform this packing/unpacking automatically before invoking the backend's pushforward or pullback implementations. This keeps the backend APIs simpler while providing a user-friendly multi-argument interface.
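The packing step described above can be sketched as follows. This only illustrates the idea of reducing a multi-argument call to a function of a single tuple of active arguments; it is not how DifferentiateWith is implemented, and pack_active and unpack_active are hypothetical names. No AD backend is invoked.

```julia
# Hypothetical helper: wrap f so that it takes one tuple holding only the
# active arguments. Non-active slots (constants, caches) keep the values
# supplied in `args`.
function pack_active(f, args::Tuple, active::Vector{Int})
    function single_arg(packed::Tuple)
        full = Any[args...]              # start from the reference arguments
        for (k, i) in enumerate(active)
            full[i] = packed[k]          # overwrite the active slots
        end
        return f(full...)
    end
    return single_arg
end

# Hypothetical helper: extract the active arguments as a tuple.
unpack_active(args::Tuple, active::Vector{Int}) = Tuple(args[i] for i in active)

# Placeholder function and values, matching the example signature shape.
f(x, y, z, c) = sum(x) * y + z
args = ([1.0, 2.0], 3.0, 2, zeros(2))
active = [1, 2]

g = pack_active(f, args, active)
packed = unpack_active(args, active)
agrees = g(packed) == f(args...)
```

A single-argument backend could then be handed g and packed, with the resulting derivative unpacked back into per-argument pieces by position.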