Description
Previous ID | SR-12155 |
Radar | None |
Original Reporter | @dan-zheng |
Type | Sub-task |
Additional Detail from JIRA
Votes | 0 |
Component/s | |
Labels | Sub-task |
Assignee | None |
Priority | Medium |
Issue Description:
Action items
We should support:
- For the @differentiable attribute: referencing JVP/VJP functions declared in a superclass.
- For the @derivative attribute: referencing an original function declared in a superclass.
- Overriding JVP/VJP in a subclass without overriding the original function from the superclass.
Overview
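Currently, the @derivative attribute requires that the derivative function be declared in the same type context as the original function. The @differentiable attribute has the same restriction on the JVP/VJP functions it references, so a subclass cannot reference a VJP declared in its superclass: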
```swift
class Super {
  @differentiable(vjp: vjpFoo)
  func foo(_ x: Float) -> Float {
    return x
  }
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub : Super {
  @differentiable(vjp: vjpFoo)
  override func foo(_ x: Float) -> Float {
    return x * x
  }
}
```
```
tf-649.swift:12:24: error: 'vjpFoo' is not defined in the current type context
  @differentiable(vjp: vjpFoo)
                       ^
```
However, for class methods, it is reasonable to allow @derivative attributes to specify an original function declared in a superclass. This enables subclasses to "override" derivative functions without overriding the original function:
```swift
class Super {
  @differentiable
  func foo(_ x: Float) -> Float {
    return x
  }
  @derivative(of: foo)
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub : Super {
  // TF-649: Override JVP/VJP without overriding original function.
  @derivative(of: foo)
  final func vjpFoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}
```
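To make the requested dispatch behavior concrete, here is a rough plain-Swift model that does not use the differentiation feature at all. It assumes the derivative of `foo` behaves like a dynamically dispatched entry that a subclass can replace while inheriting `foo` unchanged; whether the compiler would actually implement it this way is an assumption. `derivativeOfFoo` and `gradientOfFoo` are hypothetical names standing in for what `@derivative(of: foo)` and a gradient operator would provide, and the clipping pullback in `Sub` is only an illustrative custom derivative.

```swift
// Hypothetical model, no `_Differentiation` import: the derivative of `foo`
// is an overridable method, so a subclass can replace it without touching `foo`.

class Super {
  // Original function: never overridden below.
  func foo(_ x: Float) -> Float {
    return x
  }

  // Hypothetical stand-in for the derivative registered via @derivative(of: foo).
  func derivativeOfFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub: Super {
  // Replaces only the derivative; `foo` is inherited as-is.
  override func derivativeOfFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    print("Override derivative!")
    // Illustrative custom derivative: clip the incoming cotangent at 0.5.
    return (foo(x), { v in min(v, 0.5) })
  }
}

// Hypothetical gradient helper: seed the pullback with 1.
func gradientOfFoo(at x: Float, on object: Super) -> Float {
  return object.derivativeOfFoo(x).pullback(1)
}

let base: Super = Super()
let derived: Super = Sub()
print(gradientOfFoo(at: 3, on: base))    // 1.0
print(gradientOfFoo(at: 3, on: derived)) // "Override derivative!", then 0.5
print(derived.foo(3))                    // 3.0: the original function is unchanged
```

Because the call goes through a `Super`-typed reference, the model picks `Sub`'s derivative by dynamic dispatch, which is the behavior the issue asks `@derivative` registrations in subclasses to have.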
Currently, the only workaround is for subclasses to override the original function, mark the override @differentiable, and register new JVP/VJPs for it.
Example:
```swift
class Super {
  @differentiable
  func foo(_ x: Float) -> Float {
    return x
  }
  @derivative(of: foo)
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub : Super {
  // Workaround: override the original function only so that a new derivative can be registered for it.
  @differentiable
  override func foo(_ x: Float) -> Float {
    return super.foo(x)
  }

  @derivative(of: foo)
  final func vjpFoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    print("Override derivative!")
    return (foo(x), { v in v })
  }
}
```