
[SR-12155] Allow @derivative attribute to specify original function from superclass #53512

Open
@dan-zheng

Description

Previous ID SR-12155
Radar None
Original Reporter @dan-zheng
Type Sub-task
Additional Detail from JIRA
Votes 0
Component/s
Labels Sub-task
Assignee None
Priority Medium


Issue Description:

Action items

We should support:

  • For @differentiable attribute: referencing JVP/VJP from superclass.

  • For @derivative attribute: referencing original function from superclass.

  • Overriding JVP/VJP in subclass without overriding original function from superclass.

Overview

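Currently, the @derivative attribute requires the derivative function to be declared in the same type context as the original function. The @differentiable attribute has the same restriction for its JVP/VJP references: in the following example, Sub cannot reference the VJP declared in Super.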

class Super {
  @differentiable(vjp: vjpFoo)
  func foo(_ x: Float) -> Float {
    return x
  }
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub : Super {
  @differentiable(vjp: vjpFoo)
  override func foo(_ x: Float) -> Float {
    return x * x
  }
}
tf-649.swift:12:24: error: 'vjpFoo' is not defined in the current type context
  @differentiable(vjp: vjpFoo)
                       ^
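The error above is a name-lookup restriction: the VJP reference in Sub is resolved only among Sub's own members. As a hypothetical illustration (a toy model, not how the Swift type checker actually performs lookup), the two rules can be contrasted in plain Swift. Lookup confined to the current type reproduces the error, while lookup that continues up the superclass chain finds vjpFoo in Super, which is what the action items ask for. TypeContext and both lookup functions are invented for this sketch.

```swift
// Toy model of a class's type context: its own member names plus an
// optional superclass. Hypothetical; not the Swift compiler's data structures.
final class TypeContext {
  let name: String
  let members: Set<String>
  let superclass: TypeContext?

  init(name: String, members: Set<String>, superclass: TypeContext? = nil) {
    self.name = name
    self.members = members
    self.superclass = superclass
  }
}

// Current rule (as described in this issue): the reference resolves only
// against the type that declares the attribute.
func lookupInCurrentTypeOnly(_ member: String, in type: TypeContext) -> TypeContext? {
  return type.members.contains(member) ? type : nil
}

// Requested rule for class methods: if the member is not found in the
// current type, keep searching up the superclass chain.
func lookupIncludingSuperclasses(_ member: String, in type: TypeContext) -> TypeContext? {
  var current: TypeContext? = type
  while let context = current {
    if context.members.contains(member) {
      return context
    }
    current = context.superclass
  }
  return nil
}

// Mirrors the example above: Super declares foo and vjpFoo; Sub overrides foo only.
let superType = TypeContext(name: "Super", members: ["foo", "vjpFoo"])
let subType = TypeContext(name: "Sub", members: ["foo"], superclass: superType)

print(lookupInCurrentTypeOnly("vjpFoo", in: subType)?.name ?? "not found")      // not found
print(lookupIncludingSuperclasses("vjpFoo", in: subType)?.name ?? "not found")  // Super
```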

However, for class methods, it is reasonable to allow @derivative attributes to specify an original function declared in a superclass. This would let subclasses "override" derivative functions without overriding the original function.

class Super {
  @differentiable
  func foo(_ x: Float) -> Float {
    return x
  }
  @derivative(of: foo)
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub : Super {
  // TF-649: Override JVP/VJP without overriding original function.
  @derivative(of: foo)
  final func vjpFoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}
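The requested behavior can be approximated in plain Swift, without AutoDiff, by treating the derivative as an overridable companion of foo. This is only an analogy under that assumption: vjpFoo here is an ordinary overridable method rather than a registered derivative, and derivativeOfFoo and customVJPCalls are hypothetical names. The sketch shows the intended outcome: a Super-typed reference to a Sub instance uses Sub's VJP, while foo itself stays inherited.

```swift
// Plain-Swift model (no AutoDiff) of the requested semantics. Hypothetical:
// the derivative is an overridable companion method, so dynamic dispatch
// selects Sub's VJP even though Sub inherits foo unchanged.
class Super {
  func foo(_ x: Float) -> Float {
    return x
  }

  // Stands in for a VJP registered with @derivative(of: foo).
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

class Sub: Super {
  // Counts how often the subclass's derivative is used (in place of print).
  private(set) var customVJPCalls = 0

  // Replaces only the derivative; foo itself is inherited from Super.
  override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    customVJPCalls += 1
    return (foo(x), { v in v })
  }
}

// Reverse-mode derivative of foo at x, using whichever VJP the dynamic type provides.
func derivativeOfFoo(on object: Super, at x: Float) -> Float {
  let (_, pullback) = object.vjpFoo(x)
  return pullback(1)
}

let sub = Sub()
let d = derivativeOfFoo(on: sub, at: 3)  // dispatches to Sub.vjpFoo
print(d, sub.customVJPCalls)  // 1.0 1
```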

Currently, the only way to do this is for the subclass to override the original function itself and register the new JVP/VJP for that override.

Example:

class Super {
  @differentiable
  func foo(_ x: Float) -> Float {
    return x
  }

  @derivative(of: foo)
  final func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (foo(x), { v in v })
  }
}

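class Sub : Super {
  // Workaround: re-declare the original function only to attach a new derivative.
  @differentiable
  override func foo(_ x: Float) -> Float {
    return super.foo(x)
  }

  @derivative(of: foo)
  final func vjpFoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    print("Override derivative!")
    return (foo(x), { v in v })
  }
}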

Metadata


Assignees

No one assigned

    Labels

    AutoDiff
    attributes (Feature: Declaration and type attributes)
    class (Feature → type declarations: Class declarations)
    compiler (The Swift compiler itself)
    feature (A feature request or implementation)
    methods (Feature → functions: methods (member functions))
    parser (Area → compiler: The legacy C++ parser)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
