Skip to content

Commit 969ee0a

Browse files
kylebeggsJoshuaLampertgithub-actions[bot]mikeingold
authored
add AD (Enzyme) support via MeshIntegralsEnzymeExt (#152)
* add Enzyme as a potential differentiation method for the jacobian * refactor check for enzyme support * add FP to _default_diff_method * add `using Enzyme` to benchmarks.jl * update CoordRefSystems.jl compat Co-authored-by: Joshua Lampert <[email protected]> * add Enzyme to Benchmark Project.toml * fix Meshes compat in Benchmark Project.toml Co-authored-by: Joshua Lampert <[email protected]> * use import Enzyme, not using Enzyme Co-authored-by: Joshua Lampert <[email protected]> * fix typo in Benchmarks Project.toml * remove Meshes version check in combinations.jl * Apply format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/Project.toml Co-authored-by: Joshua Lampert <[email protected]> * Bump compat of Enzyme to v0.13.19 * test supports_autoenzyme to combinations; test both backends for wrong dims * Restore recently-updated FiniteDifference constructors * Add docstrings, formatting * Formatting * Add test for two-arg jacobian * Use rest of MeshIntegrals namespace * Disambiguate use of jacobian * fix test * use `import Enzyme` Co-authored-by: Joshua Lampert <[email protected]> * use `import Enzyme` Co-authored-by: Joshua Lampert <[email protected]> * remove unneeded MeshIntegrals.jl Co-authored-by: Joshua Lampert <[email protected]> --------- Co-authored-by: Joshua Lampert <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Michael Ingold <[email protected]>
1 parent c7c0a47 commit 969ee0a

15 files changed

+175
-68
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ docs/site/
2222
# committed for packages, but should be committed for applications that require a static
2323
# environment.
2424
Manifest.toml
25+
26+
# development related
27+
.vscode
28+
dev

Project.toml

+9-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,20 @@ Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
1212
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1313
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1414

15+
[weakdeps]
16+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
17+
18+
[extensions]
19+
MeshIntegralsEnzymeExt = "Enzyme"
20+
1521
[compat]
1622
CliffordNumbers = "0.1.9"
17-
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
23+
CoordRefSystems = "0.15, 0.16"
24+
Enzyme = "0.13.19"
1825
FastGaussQuadrature = "1"
1926
HCubature = "1.5"
2027
LinearAlgebra = "1"
21-
Meshes = "0.50, 0.51, 0.52"
28+
Meshes = "0.51.20, 0.52"
2229
QuadGK = "2.1.1"
2330
Unitful = "1.19"
2431
julia = "1.9"

benchmark/Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
34
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
56
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
67

78
[compat]
89
BenchmarkTools = "1.5"
10+
Enzyme = "0.13.19"
911
LinearAlgebra = "1"
10-
Meshes = "0.50, 0.51, 0.52"
12+
Meshes = "0.51.20, 0.52"
1113
Unitful = "1.19"
1214
julia = "1.9"

benchmark/benchmarks.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using LinearAlgebra
33
using Meshes
44
using MeshIntegrals
55
using Unitful
6+
import Enzyme
67

78
const SUITE = BenchmarkGroup()
89

ext/MeshIntegralsEnzymeExt.jl

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module MeshIntegralsEnzymeExt
2+
3+
using MeshIntegrals: MeshIntegrals, AutoEnzyme
4+
using Meshes: Meshes
5+
using Enzyme: Enzyme
6+
7+
function MeshIntegrals.jacobian(
8+
geometry::Meshes.Geometry,
9+
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
10+
::AutoEnzyme
11+
) where {T <: AbstractFloat}
12+
Dim = Meshes.paramdim(geometry)
13+
if Dim != length(ts)
14+
throw(ArgumentError("ts must have same number of dimensions as geometry."))
15+
end
16+
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
17+
end
18+
19+
end

src/MeshIntegrals.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import QuadGK
1010
import Unitful
1111

1212
include("differentiation.jl")
13-
export DifferentiationMethod, FiniteDifference, jacobian
13+
export DifferentiationMethod, FiniteDifference, AutoEnzyme, jacobian
1414

1515
include("utils.jl")
1616

src/differentiation.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ A category of types used to specify the desired method for calculating derivativ
99
Derivatives are used to form Jacobian matrices when calculating the differential
1010
element size throughout the integration region.
1111
12-
See also [`FiniteDifference`](@ref).
12+
See also [`FiniteDifference`](@ref), [`AutoEnzyme`](@ref).
1313
"""
1414
abstract type DifferentiationMethod end
1515

@@ -27,8 +27,14 @@ end
2727
FiniteDifference{T}() where {T <: AbstractFloat} = FiniteDifference{T}(T(1e-6))
2828
FiniteDifference() = FiniteDifference{Float64}()
2929

30+
"""
31+
AutoEnzyme()
32+
33+
Use to specify use of the Enzyme.jl for calculating derivatives.
34+
"""
35+
struct AutoEnzyme <: DifferentiationMethod end
36+
3037
# Future Support:
31-
# struct AutoEnzyme <: DifferentiationMethod end
3238
# struct AutoZygote <: DifferentiationMethod end
3339

3440
################################################################################
@@ -52,7 +58,7 @@ function jacobian(
5258
geometry::G,
5359
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}}
5460
) where {G <: Geometry, T <: AbstractFloat}
55-
return jacobian(geometry, ts, _default_diff_method(G))
61+
return jacobian(geometry, ts, _default_diff_method(G, T))
5662
end
5763

5864
function jacobian(
@@ -68,7 +74,7 @@ function jacobian(
6874
# Get the partial derivative along the n'th axis via finite difference
6975
# approximation, where ts is the current parametric position
7076
function ∂ₙr(ts, n, ε)
71-
# Build left/right parametric coordinates with non-allocating iterators
77+
# Build left/right parametric coordinates with non-allocating iterators
7278
left = Iterators.map(((i, t),) -> i == n ? t - ε : t, enumerate(ts))
7379
right = Iterators.map(((i, t),) -> i == n ? t + ε : t, enumerate(ts))
7480
# Select orientation of finite-diff
@@ -107,7 +113,7 @@ possible and finite difference approximations otherwise.
107113
function differential(
108114
geometry::G,
109115
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
110-
diff_method::DifferentiationMethod = _default_diff_method(G)
116+
diff_method::DifferentiationMethod = _default_diff_method(G, T)
111117
) where {G <: Geometry, T <: AbstractFloat}
112118
J = Iterators.map(_KVector, jacobian(geometry, ts, diff_method))
113119
return LinearAlgebra.norm(foldl(, J))

src/integral.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
################################################################################
44

55
"""
6-
integral(f, geometry[, rule]; diff_method=_default_method(geometry), FP=Float64)
6+
integral(f, geometry[, rule]; diff_method=_default_diff_method(geometry, FP), FP=Float64)
77
88
Numerically integrate a given function `f(::Point)` over the domain defined by
99
a `geometry` using a particular numerical integration `rule` with floating point
@@ -16,7 +16,7 @@ precision of type `FP`.
1616
`GaussKronrod()` in 1D and `HAdaptiveCubature()` else)
1717
1818
# Keyword Arguments
19-
- `diff_method::DifferentiationMethod = _default_method(geometry)`: the method to
19+
- `diff_method::DifferentiationMethod = _default_diff_method(geometry, FP)`: the method to
2020
use for calculating Jacobians that are used to calculate differential elements
2121
- `FP = Float64`: the floating point precision desired.
2222
"""
@@ -42,8 +42,10 @@ function _integral(
4242
geometry,
4343
rule::GaussKronrod;
4444
FP::Type{T} = Float64,
45-
diff_method::DM = _default_diff_method(geometry)
45+
diff_method::DM = _default_diff_method(geometry, FP)
4646
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
47+
_check_diff_method_support(geometry, diff_method)
48+
4749
# Implementation depends on number of parametric dimensions over which to integrate
4850
N = Meshes.paramdim(geometry)
4951
if N == 1
@@ -70,8 +72,10 @@ function _integral(
7072
geometry,
7173
rule::GaussLegendre;
7274
FP::Type{T} = Float64,
73-
diff_method::DM = _default_diff_method(geometry)
75+
diff_method::DM = _default_diff_method(geometry, FP)
7476
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
77+
_check_diff_method_support(geometry, diff_method)
78+
7579
N = Meshes.paramdim(geometry)
7680

7781
# Get Gauss-Legendre nodes and weights of type FP for a region [-1,1]ᴺ
@@ -99,8 +103,10 @@ function _integral(
99103
geometry,
100104
rule::HAdaptiveCubature;
101105
FP::Type{T} = Float64,
102-
diff_method::DM = _default_diff_method(geometry)
106+
diff_method::DM = _default_diff_method(geometry, FP)
103107
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
108+
_check_diff_method_support(geometry, diff_method)
109+
104110
N = Meshes.paramdim(geometry)
105111

106112
integrand(ts) = f(geometry(ts...)) * differential(geometry, ts, diff_method)

src/specializations/BezierCurve.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ function integral(
3636
curve::Meshes.BezierCurve,
3737
rule::IntegrationRule;
3838
alg::Meshes.BezierEvalMethod = Meshes.Horner(),
39+
FP::Type{T} = Float64,
40+
diff_method::DM = _default_diff_method(curve, FP),
3941
kwargs...
40-
)
42+
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
43+
_check_diff_method_support(curve, diff_method)
44+
4145
# Generate a _ParametricGeometry whose parametric function auto-applies the alg kwarg
4246
param_curve = _ParametricGeometry(_parametric(curve, alg), Meshes.paramdim(curve))
4347

4448
# Integrate the _ParametricGeometry using the standard methods
45-
return _integral(f, param_curve, rule; kwargs...)
49+
return _integral(f, param_curve, rule; diff_method = diff_method, FP = FP, kwargs...)
4650
end
4751

4852
################################################################################

src/specializations/CylinderSurface.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,22 @@ function integral(
1212
f,
1313
cyl::Meshes.CylinderSurface,
1414
rule::I;
15+
FP::Type{T} = Float64,
16+
diff_method::DM = _default_diff_method(cyl, FP),
1517
kwargs...
16-
) where {I <: IntegrationRule}
18+
) where {I <: IntegrationRule, DM <: DifferentiationMethod, T <: AbstractFloat}
19+
_check_diff_method_support(cyl, diff_method)
20+
1721
# The generic method only parametrizes the sides
18-
sides = _integral(f, cyl, rule; kwargs...)
22+
sides = _integral(f, cyl, rule; diff_method = diff_method, FP = FP, kwargs...)
1923

2024
# Integrate the Disk at the top
2125
disk_top = Meshes.Disk(cyl.top, cyl.radius)
22-
top = _integral(f, disk_top, rule; kwargs...)
26+
top = _integral(f, disk_top, rule; diff_method = diff_method, FP = FP, kwargs...)
2327

2428
# Integrate the Disk at the bottom
2529
disk_bottom = Meshes.Disk(cyl.bot, cyl.radius)
26-
bottom = _integral(f, disk_bottom, rule; kwargs...)
30+
bottom = _integral(f, disk_bottom, rule; diff_method = diff_method, FP = FP, kwargs...)
2731

2832
return sides + top + bottom
2933
end

src/utils.jl

+41-6
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,50 @@ end
1212
# DifferentiationMethod
1313
################################################################################
1414

15-
# Return the default DifferentiationMethod instance for a particular geometry type
15+
"""
16+
supports_autoenzyme(geometry)
17+
18+
Return whether a geometry (or geometry type) has a parametric function that can be
19+
differentiated with Enzyme. See GitHub Issue #154 for more information.
20+
"""
21+
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
22+
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
23+
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
24+
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
25+
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
26+
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)
27+
28+
"""
29+
_check_diff_method_support(::Geometry, ::DifferentiationMethod) -> nothing
30+
31+
Throw an error if incompatible geometry-diff_method combination detected.
32+
"""
33+
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
34+
function _check_diff_method_support(geometry::Geometry, ::AutoEnzyme)
35+
if !supports_autoenzyme(geometry)
36+
throw(ArgumentError("AutoEnzyme not supported for this geometry."))
37+
end
38+
end
39+
40+
"""
41+
_default_diff_method(geometry, FP)
42+
43+
Return an instance of the default DifferentiationMethod for a particular geometry
44+
(or geometry type) and floating point type.
45+
"""
1646
function _default_diff_method(
17-
g::Type{G}
18-
) where {G <: Geometry}
19-
return FiniteDifference()
47+
g::Type{G}, FP::Type{T}
48+
) where {G <: Geometry, T <: AbstractFloat}
49+
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
50+
AutoEnzyme()
51+
else
52+
FiniteDifference()
53+
end
2054
end
2155

22-
# Return the default DifferentiationMethod instance for a particular geometry instance
23-
_default_diff_method(g::G) where {G <: Geometry} = _default_diff_method(G)
56+
function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
57+
_default_diff_method(G, T)
58+
end
2459

2560
################################################################################
2661
# Numerical Tools

test/Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
CoordRefSystems = "b46f11dc-f210-4604-bfba-323c1ec968cb"
4+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
@@ -12,10 +13,11 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1213

1314
[compat]
1415
Aqua = "0.7, 0.8"
15-
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
16+
CoordRefSystems = "0.15, 0.16"
17+
Enzyme = "0.13.19"
1618
ExplicitImports = "1.6.0"
1719
LinearAlgebra = "1"
18-
Meshes = "0.50, 0.51, 0.52"
20+
Meshes = "0.51.20, 0.52"
1921
SpecialFunctions = "2"
2022
TestItemRunner = "1"
2123
TestItems = "1"

0 commit comments

Comments
 (0)