
Enhancing DifferentiateWith Interface #806

@yebai


Current Interface

The existing DifferentiateWith(f, backend) interface in DifferentiationInterface.jl presents a significant limitation: it inherently supports only single-argument functions. This design makes it cumbersome to:

  • Differentiate functions with multiple arguments.
  • Pass additional context or non-differentiable arguments (constants, pre-allocated caches) to the differentiation backend.
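The following minimal plain-Julia sketch (no packages) shows the workaround that a single-argument interface forces today. The active arguments are packed into one vector, and the constant and the cache are captured in a closure. The function `f` and the values are illustrative only.

```julia
# Illustrative multi-argument function: x and y are active, z is a constant scale,
# and buf is a pre-allocated scratch vector.
function f(x, y, z, buf)
    buf .= x .* y          # reuse the cache instead of allocating
    return z * sum(buf)
end

x = [1.0, 2.0]
y = [3.0, 4.0]
z = 2.0
buf = similar(x)

# With a single-argument interface, the active arguments must be packed into one
# container and everything else captured in a closure.
n = length(x)
packed = vcat(x, y)
g(v) = f(v[1:n], v[n+1:end], z, buf)

g(packed)  # 2.0 * (1*3 + 2*4) == 22.0
```

Because `buf` is hidden inside the closure, the backend cannot tell that it is a cache. With a dual-number forward-mode backend, for example, writing dual numbers into the `Float64` buffer would fail. Passing this kind of information explicitly is the goal of the proposal.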

Proposed Interface

To address these limitations, we propose a more expressive interface for DifferentiateWith:

Tfunc_sig = Tuple{typeof(f), T_arg1, T_arg2, ..., T_argN}
DifferentiateWith(Tfunc_sig, backend_to_use::AbstractADType)

Here, Tfunc_sig is a Tuple type encoding the function signature. Its first parameter is the type of the function f (for a named function, typeof(f) is a singleton type, so it identifies f uniquely). The remaining parameters T_arg1, T_arg2, ..., T_argN are the types of f's arguments.
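The following plain-Julia sketch shows how such a signature type can be built and then taken apart again. The function `f` and the helper `matches` are illustrative and are not part of any package. Reading back the singleton instance through `.instance` relies on `typeof(f)` being a singleton type, which holds for named functions.

```julia
f(x, y) = sum(x) * y

x = [1.0, 2.0]
y = 3.0

# Proposed layout: first parameter is the function's type, the rest are argument types.
Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y)}

# Recover the pieces from the Tuple type.
params = fieldtypes(Tfunc_sig)
Tfunc = first(params)        # typeof(f)
Targs = Base.tail(params)    # (Vector{Float64}, Float64)

# For a named function, the singleton type carries the function itself.
f_recovered = Tfunc.instance

# Check whether concrete arguments match the declared argument types.
matches(args...) = args isa Tuple{Targs...}
```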

Argument Type Wrappers:

To provide more context to the backend about how each argument should be treated, we can introduce wrapper types:

  • Default (no wrapper): the argument is assumed to be "active" (i.e., the function is differentiated with respect to it).
  • Constant{T}: Indicates that an argument of type T is a constant and should not be differentiated.
  • Cache{T}: Signals that an argument of type T is a pre-allocated cache that the backend can utilise.
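Because these wrappers appear only inside the signature type, they can be empty parametric structs. The following minimal sketch rests on that assumption. The `Constant` and `Cache` definitions here are local stand-ins, not DifferentiationInterface's own types, and `ArgRole`, `role`, and `unwrap` are hypothetical helpers.

```julia
# Local stand-ins for the proposed wrappers; used only at the type level.
struct Constant{T} end
struct Cache{T} end

@enum ArgRole Active ConstantArg CacheArg

# Classify a declared argument type; unwrapped types fall through to Active.
role(::Type{<:Constant}) = ConstantArg
role(::Type{<:Cache}) = CacheArg
role(::Type) = Active

# Strip a wrapper to get the underlying argument type.
unwrap(::Type{Constant{T}}) where {T} = T
unwrap(::Type{Cache{T}}) where {T} = T
unwrap(::Type{T}) where {T} = T
```

Julia's dispatch picks the more specific `Type{<:Constant}` / `Type{<:Cache}` methods when they apply, so no explicit "Active" wrapper is needed.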

Example Usage:

Consider a function f(x, y, z, c) where x and y are active arguments, z is a constant, and c is a cache. The Tfunc_sig would be constructed as:

Targtypes = (typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)})
Tfunc_sig = Tuple{typeof(f), Targtypes...}
# or more explicitly:
# Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}

dw = DifferentiateWith(Tfunc_sig, backend)
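The following hypothetical sketch shows how a wrapper could store the signature as a type parameter and derive which positional arguments are active. `SigDifferentiateWith`, `AbstractBackend`, and `DummyBackend` are made-up names. They stand in for the proposed `DifferentiateWith` and for `AbstractADType`, and no real backend is called.

```julia
abstract type AbstractBackend end           # stand-in for ADTypes' AbstractADType
struct DummyBackend <: AbstractBackend end  # hypothetical backend, does nothing

# Local stand-ins for the proposed wrapper types.
struct Constant{T} end
struct Cache{T} end

is_active(::Type{<:Constant}) = false
is_active(::Type{<:Cache}) = false
is_active(::Type) = true

# Hypothetical wrapper: the signature lives in the type parameter S.
struct SigDifferentiateWith{S<:Tuple, B<:AbstractBackend}
    backend::B
end
SigDifferentiateWith(::Type{S}, backend::B) where {S<:Tuple, B<:AbstractBackend} =
    SigDifferentiateWith{S, B}(backend)

# Positions (1-based, excluding the function itself) of the active arguments.
function active_indices(::SigDifferentiateWith{S}) where {S}
    argtypes = Base.tail(fieldtypes(S))     # drop typeof(f)
    return Tuple(i for (i, T) in enumerate(argtypes) if is_active(T))
end

f(x, y, z, c) = (c .= x .* y; z * sum(c))
x, y, z, c = [1.0, 2.0], [3.0, 4.0], 2.0, zeros(2)

Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}
dw = SigDifferentiateWith(Tfunc_sig, DummyBackend())
```

Keeping the signature in the type means the active/inactive split is computed from types alone. For this example, `active_indices(dw)` gives `(1, 2)`.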

Internal Handling:

With this richer Tfunc_sig, DifferentiateWith can manage multi-argument functions internally. For backends that fundamentally operate on single-argument functions, DifferentiateWith can automatically pack the active arguments into a single argument (e.g., a tuple) before invoking the backend's pushforward or pullback implementation, and unpack them afterwards. This keeps the backend APIs simple while offering users a multi-argument interface.
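The packing/unpacking step can be illustrated with a self-contained sketch. Here `fd_gradient` is a toy central-finite-difference routine that stands in for a single-argument backend; it is not a real AD backend. `multiarg_gradient` is a hypothetical helper that flattens the active arguments into one vector and passes the other arguments through untouched. It then splits the resulting gradient back into one piece per argument.

```julia
# Toy single-argument "backend": central finite differences on a vector input.
# It stands in for a real AD backend and is NOT DifferentiationInterface code.
function fd_gradient(g, v::AbstractVector{<:Real}; h = 1e-6)
    grad = similar(v, Float64)
    for i in eachindex(v)
        e = zeros(length(v))
        e[i] = h
        grad[i] = (g(v + e) - g(v - e)) / (2h)
    end
    return grad
end

# Hypothetical multi-argument wrapper: the arguments at positions `active` are
# flattened into one vector; all other arguments (constants, caches) are passed
# through unchanged on every call.
function multiarg_gradient(f, args::Tuple, active::Tuple)
    lens = [length(args[i]) for i in active]
    offsets = cumsum([0; lens])              # piece k is offsets[k]+1:offsets[k+1]
    packed = vcat(map(i -> args[i], active)...)

    # Rebuild the full argument tuple from a packed vector.
    function unpack(v)
        return ntuple(length(args)) do j
            k = findfirst(==(j), active)
            k === nothing ? args[j] : v[offsets[k]+1:offsets[k+1]]
        end
    end

    g(v) = f(unpack(v)...)                   # single-argument function for the backend
    grad = fd_gradient(g, packed)
    # Split the flat gradient back into one piece per active argument.
    return ntuple(k -> grad[offsets[k]+1:offsets[k+1]], length(active))
end

f(x, y, z, c) = (c .= x .* y; z * sum(c))
x, y, z, c = [1.0, 2.0], [3.0, 4.0], 2.0, zeros(2)

gx, gy = multiarg_gradient(f, (x, y, z, c), (1, 2))
```

For this bilinear `f`, the exact gradients are `z .* y` and `z .* x`, so `gx` and `gy` come out approximately `[6.0, 8.0]` and `[2.0, 4.0]`. The cache `c` is reused across calls rather than differentiated.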

Labels: core (Related to the core utilities of the package)