Skip to content

Commit 2d8938d

Browse files
Fix bug where integral fails when Enzyme ext not loaded (#160)
* Add explicit finitedifference test, shorten line lengths * supports_autoenzyme always false unless ext is loaded * Use explicit package name * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Tweak supports_autoenzme any method, comments, style * Add tests without Enzyme ext loaded, and more comprehensive tests when loaded * Code style * Update utils.jl * Update utils.jl * Update testitem name for consistency * Avoid circular definition/verification by requiring explicit autoenzyme support flag in testing * Drop newly-redundant enzyme support checks * Add a supports_autoenzyme check * Use blacklist in SupportStatus to determine autoenzyme * Remove problematic testitem * Remove leftover autoenzyme kwarg, code style tweaks * Add a test * Update changelog --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 67fc779 commit 2d8938d

File tree

5 files changed

+91
-43
lines changed

5 files changed

+91
-43
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212

1313
- Implemented a more efficient internal parametric transformation for `Meshes.Tetrahedron`, resulting in about an 80% integral performance improvement.
1414

15+
### Fixed
16+
17+
- Fixed a bug where `integral` would default to `diff_method=AutoEnzyme()` even when the Enzyme extension isn't loaded.
18+
1519

1620
## [0.16.0] - 2024-12-14
1721

ext/MeshIntegralsEnzymeExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,12 @@ function MeshIntegrals.jacobian(
1616
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
1717
end
1818

19+
# Supports all geometries except for those that throw errors
20+
# See GitHub Issue #154 for more information
21+
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
22+
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
23+
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
24+
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
25+
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
26+
1927
end

src/utils.jl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,48 +13,63 @@ end
1313
################################################################################
1414

1515
"""
16-
supports_autoenzyme(geometry)
16+
supports_autoenzyme(geometry::Geometry)
17+
supports_autoenzyme(type::Type{<:Geometry})
1718
18-
Return whether a geometry (or geometry type) has a parametric function that can be
19+
Return whether a geometry or geometry type has a parametric function that can be
1920
differentiated with Enzyme. See GitHub Issue #154 for more information.
2021
"""
21-
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
22-
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
23-
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
24-
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
25-
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
22+
function supports_autoenzyme end
23+
24+
# Returns false for all geometries when Enzyme extension is not loaded
25+
supports_autoenzyme(::Type{<:Any}) = false
26+
27+
# If provided a geometry instance, re-run with the type as argument
2628
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)
2729

2830
"""
2931
_check_diff_method_support(::Geometry, ::DifferentiationMethod) -> nothing
3032
31-
Throw an error if incompatible geometry-diff_method combination detected.
33+
Throw an error if incompatible combination {geometry, diff_method} detected.
3234
"""
33-
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
35+
function _check_diff_method_support end
36+
37+
# If diff_method == Enzyme, then perform check
3438
function _check_diff_method_support(geometry::Geometry, ::AutoEnzyme)
3539
if !supports_autoenzyme(geometry)
3640
throw(ArgumentError("AutoEnzyme not supported for this geometry."))
3741
end
3842
end
3943

44+
# If diff_method != AutoEnzyme, then do nothing
45+
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
46+
4047
"""
4148
_default_diff_method(geometry, FP)
4249
4350
Return an instance of the default DifferentiationMethod for a particular geometry
4451
(or geometry type) and floating point type.
4552
"""
4653
function _default_diff_method(
47-
g::Type{G}, FP::Type{T}
48-
) where {G <: Geometry, T <: AbstractFloat}
49-
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
50-
AutoEnzyme()
54+
::Type{G},
55+
::Type{FP}
56+
) where {G <: Geometry, FP <: AbstractFloat}
57+
# Enzyme only works with these FP types
58+
uses_Enzyme_supported_FP_type = (FP <: Union{Float32, Float64})
59+
60+
if supports_autoenzyme(G) && uses_Enzyme_supported_FP_type
61+
return AutoEnzyme()
5162
else
52-
FiniteDifference()
63+
return FiniteDifference()
5364
end
5465
end
5566

56-
function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
57-
_default_diff_method(G, T)
67+
# If provided a geometry instance, re-run with the type as argument
68+
function _default_diff_method(
69+
::G,
70+
::Type{FP}
71+
) where {G <: Geometry, FP <: AbstractFloat}
72+
return _default_diff_method(G, FP)
5873
end
5974

6075
################################################################################

test/combinations.jl

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,12 @@ This file includes tests for:
4848
end
4949

5050
# Shortcut constructor for geometries with typical support structure
51-
function SupportStatus(g::Geometry, autoenzyme = MeshIntegrals.supports_autoenzyme(g))
52-
N = Meshes.paramdim(g)
51+
function SupportStatus(geometry::G;) where {G <: Geometry}
52+
# Check whether AutoEnzyme should be supported, i.e. not on blacklist
53+
unsupported_Gs = Union{BezierCurve, Cylinder, CylinderSurface, ParametrizedCurve}
54+
autoenzyme = !(G <: unsupported_Gs)
55+
56+
N = Meshes.paramdim(geometry)
5357
if N == 1
5458
# line/curve
5559
aliases = Bool.((1, 0, 0))
@@ -70,14 +74,17 @@ This file includes tests for:
7074
aliases = Bool.((0, 0, 0))
7175
rules = Bool.((0, 1, 1))
7276
return SupportStatus(aliases..., rules..., autoenzyme)
73-
end
74-
end
77+
end #if
78+
end # function
79+
80+
# Generate applicable tests for this geometry
81+
function runtests(testable::TestableGeometry; rtol = sqrt(eps()))
82+
# Determine support matrix for this geometry
83+
supports = SupportStatus(testable.geometry)
84+
85+
# Ensure consistency of SupportStatus with supports_autoenzyme
86+
@test MeshIntegrals.supports_autoenzyme(testable.geometry) == supports.autoenzyme
7587

76-
function runtests(
77-
testable::TestableGeometry,
78-
supports::SupportStatus = SupportStatus(testable.geometry);
79-
rtol = sqrt(eps())
80-
)
8188
# Test alias functions
8289
for alias in (lineintegral, surfaceintegral, volumeintegral)
8390
# if supports.alias
@@ -86,15 +93,14 @@ This file includes tests for:
8693
else
8794
@test_throws "not supported" alias(testable.integrand, testable.geometry)
8895
end
89-
end
96+
end # for
9097

98+
# Iteratively test all IntegrationRules
9199
iter_rules = (
92100
(supports.gausskronrod, GaussKronrod()),
93101
(supports.gausslegendre, GaussLegendre(100)),
94102
(supports.hadaptivecubature, HAdaptiveCubature())
95103
)
96-
97-
# Test rules
98104
for (supported, rule) in iter_rules
99105
if supported
100106
# Scalar integrand
@@ -116,19 +122,21 @@ This file includes tests for:
116122
end
117123
end # for
118124

125+
# Iteratively test all DifferentiationMethods
119126
iter_diff_methods = (
120-
(supports.autoenzyme, AutoEnzyme()),
127+
(true, FiniteDifference()),
128+
(supports.autoenzyme, AutoEnzyme())
121129
)
130+
for (supported, method) in iter_diff_methods
131+
# Aliases for improved code readability
132+
f = testable.integrand
133+
geometry = testable.geometry
134+
sol = testable.solution
122135

123-
for (supported, diff_method) in iter_diff_methods
124136
if supported
125-
@test integral(
126-
testable.integrand, testable.geometry; diff_method = diff_method)testable.solution rtol=rtol
127-
@test MeshIntegrals.supports_autoenzyme(testable.geometry) == true
137+
@test integral(f, geometry; diff_method = method)sol rtol=rtol
128138
else
129-
@test_throws "not supported" integral(
130-
testable.integrand, testable.geometry; diff_method = diff_method)
131-
@test MeshIntegrals.supports_autoenzyme(testable.geometry) == false
139+
@test_throws "not supported" integral(f, geometry; diff_method = method)
132140
end
133141
end # for
134142
end # function
@@ -465,7 +473,7 @@ end
465473
runtests(testable)
466474
end
467475

468-
@testitem "ParametrizedCurve" setup=[Combinations] begin
476+
@testitem "Meshes.ParametrizedCurve" setup=[Combinations] begin
469477
using CoordRefSystems: Polar
470478

471479
# Geometries

test/utils.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,25 @@ end
2525
@test _ones(Float32, 2) == (1.0f0, 1.0f0)
2626
end
2727

28-
@testitem "Differentiation" setup=[Utils] begin
29-
# _default_diff_method
30-
sphere = Sphere(Point(0, 0, 0), 1.0)
31-
@test _default_diff_method(Meshes.Sphere, Float64) isa AutoEnzyme
32-
@test _default_diff_method(sphere, Float64) isa AutoEnzyme
33-
@test _default_diff_method(sphere, BigFloat) isa FiniteDifference
28+
@testitem "Differentiation (EnzymeExt loaded)" setup=[Utils] begin
29+
# supports_autoenzyme(::Type{<:Any})
30+
@test MeshIntegrals.supports_autoenzyme(Nothing) == false
31+
32+
# _default_diff_method -- using type or instance, Enzyme-supported combination
33+
let sphere = Sphere(Point(0, 0, 0), 1.0)
34+
@test _default_diff_method(Meshes.Sphere, Float64) isa AutoEnzyme
35+
@test _default_diff_method(sphere, Float64) isa AutoEnzyme
36+
end
37+
38+
# _default_diff_method -- Enzyme-unsupported FP types
39+
@test _default_diff_method(Meshes.Sphere, Float16) isa FiniteDifference
40+
@test _default_diff_method(Meshes.Sphere, BigFloat) isa FiniteDifference
41+
42+
# _default_diff_method -- geometries that currently error with AutoEnzyme
43+
@test _default_diff_method(Meshes.BezierCurve, Float64) isa FiniteDifference
44+
@test _default_diff_method(Meshes.CylinderSurface, Float64) isa FiniteDifference
45+
@test _default_diff_method(Meshes.Cylinder, Float64) isa FiniteDifference
46+
@test _default_diff_method(Meshes.ParametrizedCurve, Float64) isa FiniteDifference
3447

3548
# FiniteDifference
3649
@test FiniteDifference().ε 1e-6

0 commit comments

Comments
 (0)