Skip to content

Commit 1129f9b

Browse files
authored
switch DiffBase dependency over to DiffRules/DiffTests/DiffResults, remove RealInterface dependency (#262)
1 parent 887d282 commit 1129f9b

23 files changed

+326
-364
lines changed

REQUIRE

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
julia 0.6.0
22
Compat 0.31.0
33
StaticArrays 0.5.0
4-
DiffBase 0.3.2 0.4.0
4+
DiffResults 0.0.1
5+
DiffRules 0.0.1
56
NaNMath 0.2.2
67
SpecialFunctions 0.1.0
7-
RealInterface 0.0.2
88
CommonSubexpressions 0.0.1

appveyor.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
environment:
22
matrix:
3-
- JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x64/julia-latest-win64.exe"
3+
- JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x86/0.6/julia-0.6-latest-win32.exe"
4+
- JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x64/0.6/julia-0.6-latest-win64.exe"
45

56
branches:
67
only:

benchmarks/benchmarks.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ForwardDiff, DiffBase
1+
using ForwardDiff, DiffTests, DiffResults
22
using BenchmarkTools
33

44
include(joinpath(dirname(dirname(@__FILE__)), "test", "utils.jl"))
@@ -16,44 +16,44 @@ const gradient_group = addgroup!(SUITE, "gradient")
1616
const jacobian_group = addgroup!(SUITE, "jacobian")
1717
const hessian_group = addgroup!(SUITE, "hessian")
1818

19-
for f in (DiffBase.NUMBER_TO_NUMBER_FUNCS..., DiffBase.NUMBER_TO_ARRAY_FUNCS...)
19+
for f in (DiffTests.NUMBER_TO_NUMBER_FUNCS..., DiffTests.NUMBER_TO_ARRAY_FUNCS...)
2020
x = 1.0
2121
y = f(x)
2222

2323
value_group[name(f)] = @benchmarkable $(f)($x)
2424

25-
out = isa(y, Number) ? DiffBase.DiffResult(y, y) : DiffBase.DiffResult(similar(y), similar(y))
25+
out = isa(y, Number) ? DiffResults.DiffResult(y, y) : DiffResults.DiffResult(similar(y), similar(y))
2626
derivative_group[name(f)] = @benchmarkable ForwardDiff.derivative!($out, $f, $x)
2727
end
2828

29-
for f in (DiffBase.VECTOR_TO_NUMBER_FUNCS..., DiffBase.MATRIX_TO_NUMBER_FUNCS...)
29+
for f in (DiffTests.VECTOR_TO_NUMBER_FUNCS..., DiffTests.MATRIX_TO_NUMBER_FUNCS...)
3030
fval = addgroup!(value_group, name(f))
3131
fgrad = addgroup!(gradient_group, name(f))
3232
fhess = addgroup!(hessian_group, name(f))
33-
arrs = in(f, DiffBase.VECTOR_TO_NUMBER_FUNCS) ? vecs : mats
33+
arrs = in(f, DiffTests.VECTOR_TO_NUMBER_FUNCS) ? vecs : mats
3434
for x in arrs
3535
y = f(x)
3636

3737
fval[length(x)] = @benchmarkable $(f)($x)
3838

39-
gout = DiffBase.DiffResult(y, similar(x, typeof(y)))
39+
gout = DiffResults.DiffResult(y, similar(x, typeof(y)))
4040
gcfg = ForwardDiff.GradientConfig(nothing, x)
4141
fgrad[length(x)] = @benchmarkable ForwardDiff.gradient!($gout, $f, $x, $gcfg)
4242

43-
hout = DiffBase.DiffResult(y, similar(x, typeof(y)), similar(x, typeof(y), length(x), length(x)))
43+
hout = DiffResults.DiffResult(y, similar(x, typeof(y)), similar(x, typeof(y), length(x), length(x)))
4444
hcfg = ForwardDiff.HessianConfig(nothing, hout, x)
4545
fhess[length(x)] = @benchmarkable ForwardDiff.hessian!($hout, $f, $x, $hcfg)
4646
end
4747
end
4848

49-
for f in DiffBase.ARRAY_TO_ARRAY_FUNCS
49+
for f in DiffTests.ARRAY_TO_ARRAY_FUNCS
5050
fval = addgroup!(value_group, name(f))
5151
fjac = addgroup!(jacobian_group, name(f))
5252
for x in mats
5353
y = f(x)
5454
fval[length(x)] = @benchmarkable $(f)($x)
5555

56-
out = DiffBase.JacobianResult(y, x)
56+
out = DiffResults.JacobianResult(y, x)
5757
cfg = ForwardDiff.JacobianConfig(nothing, y, x)
5858
fjac[length(x)] = @benchmarkable ForwardDiff.jacobian!($out, $f, $y, $x, $cfg)
5959
end

docs/src/dev/contributing.md

+12-16
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@ If you're new GitHub, here's an outline of the workflow you should use:
1414
## Adding New Derivative Definitions
1515

1616
In general, new derivative implementations for `Dual` are automatically defined via simple
17-
symbolic rules. ForwardDiff accomplishes this by looping over the [the function names listed
18-
in the RealInterface package](https://github.com/jrevels/RealInterface.jl), and for every
19-
function (and relevant arity), it attempts to generate a `Dual` definition by applying the
20-
[symbolic rules provided by the DiffBase package](https://github.com/JuliaDiff/DiffBase.jl/blob/master/src/rules.jl).
21-
Conveniently, these auto-generated definitions are also automatically tested.
22-
23-
Thus, in order to add a new derivative implementation for `Dual`, you should do the
24-
following:
25-
26-
1. Make sure the name of the function is appropriately listed in the RealInterface package
27-
2. Define the appropriate derivative rule(s) in DiffBase
28-
3. Check that calling the function on `Dual` instances delivers the desired result.
29-
30-
Depending on the arity of your function and its category in RealInterface, ForwardDiff's
31-
auto-definition mechanism might need to be expanded to include it. If this is the case,
32-
ForwardDiff's maintainers can help you out.
17+
symbolic rules. ForwardDiff accomplishes this by looping over the rules provided by
18+
[the DiffRules package](https://github.com/JuliaDiff/DiffRules.jl) and using them to
19+
auto-generate `Dual` definitions. Conveniently, these auto-generated definitions are also
20+
automatically tested.
21+
22+
Thus, in order to add a new derivative implementation for `Dual`, you should define the
23+
appropriate derivative rule(s) in DiffRules, and then check that calling the function on
24+
`Dual` instances delivers the desired result.
25+
26+
Depending on your function, ForwardDiff's auto-definition mechanism might need to be
27+
expanded to support it. If this is the case, file an issue/PR so that ForwardDiff's
28+
maintainers can help you out.

docs/src/user/advanced.md

+6-10
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@ this task!**
1212

1313
In the course of calculating higher-order derivatives, ForwardDiff ends up calculating all
1414
the lower-order derivatives and primal value `f(x)`. To retrieve these results in one fell
15-
swoop, you can utilize the DiffResult API provided by the DiffBase package. To learn how to
16-
use this functionality, please consult the [relevant DiffBase
17-
documentation](http://www.juliadiff.org/DiffBase.jl/stable/diffresultapi.html).
15+
swoop, you can utilize the [DiffResults](https://github.com/JuliaDiff/DiffResults.jl) API.
1816

19-
Note that running `using ForwardDiff` will automatically bring the `DiffBase` module
20-
into scope, and that all mutating ForwardDiff API methods support the DiffResult API.
21-
In other words, API methods of the form `ForwardDiff.method!(out, args...)` will
22-
work appropriately if `isa(out, DiffResult)`.
17+
All mutating ForwardDiff API methods support the DiffResults API. In other words, API
18+
methods of the form `ForwardDiff.method!(out, args...)` will work appropriately if
19+
`isa(out, DiffResults.DiffResult)`.
2320

2421
## Configuring Chunk Size
2522

@@ -141,9 +138,8 @@ decrease performance (~5%-10% on our benchmarks).
141138

142139
In order to preserve performance in the majority of use cases, ForwardDiff disables this
143140
check by default. If your code is affected by this `NaN` behvaior, you can enable
144-
ForwardDiff's `NaN`-safe mode by setting `NANSAFE_MODE_ENABLED` to `true` in
145-
ForwardDiff's source. This constant is located in `src/ForwardDiff.jl` in the
146-
package's directory.
141+
ForwardDiff's `NaN`-safe mode by setting the `NANSAFE_MODE_ENABLED` constant to `true` in
142+
ForwardDiff's source.
147143

148144
In the future, we plan on allowing users and downstream library authors to dynamically
149145
enable [NaN`-safe mode via the `AbstractConfig`

docs/src/user/upgrade.md

+8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ out = ForwardDiff.hessian!(out, f, x) # re-alias output!
8080
v = DiffBase.value(out)
8181
g = DiffBase.gradient(out)
8282
h = DiffBase.hessian(out)
83+
84+
# ForwardDiff v0.7 & above
85+
using DiffResults
86+
out = DiffResults.HessianResult(x)
87+
out = ForwardDiff.hessian!(out, f, x) # re-alias output!
88+
v = DiffResults.value(out)
89+
g = DiffResults.gradient(out)
90+
h = DiffResults.hessian(out)
8391
```
8492

8593
## Higher-Order Differentiation

src/ForwardDiff.jl

+8-42
Original file line numberDiff line numberDiff line change
@@ -2,63 +2,29 @@ __precompile__()
22

33
module ForwardDiff
44

5-
using DiffBase
6-
using DiffBase: DiffResult, MutableDiffResult, ImmutableDiffResult
5+
using DiffRules, DiffResults
6+
using DiffResults: DiffResult, MutableDiffResult, ImmutableDiffResult
77
using StaticArrays
88
using Compat
99

1010
import NaNMath
1111
import SpecialFunctions
12-
import RealInterface
1312
import CommonSubexpressions
1413

15-
#############################
16-
# types/functions/constants #
17-
#############################
18-
19-
const NANSAFE_MODE_ENABLED = false
20-
21-
const REAL_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Irrational{:e}, Irrational{})
22-
23-
const DEFAULT_CHUNK_THRESHOLD = 10
24-
25-
struct Chunk{N} end
26-
27-
function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
28-
N = pickchunksize(input_length, threshold)
29-
return Chunk{N}()
30-
end
31-
32-
function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
33-
return Chunk(length(x), threshold)
34-
end
35-
36-
# Constrained to `N <= threshold`, minimize (in order of priority):
37-
# 1. the number of chunks that need to be computed
38-
# 2. the number of "left over" perturbations in the final chunk
39-
function pickchunksize(input_length, threshold = DEFAULT_CHUNK_THRESHOLD)
40-
if input_length <= threshold
41-
return input_length
42-
else
43-
nchunks = round(Int, input_length / DEFAULT_CHUNK_THRESHOLD, RoundUp)
44-
return round(Int, input_length / nchunks, RoundUp)
45-
end
46-
end
47-
48-
############
49-
# includes #
50-
############
51-
14+
include("prelude.jl")
5215
include("partials.jl")
5316
include("dual.jl")
5417
include("config.jl")
55-
include("utils.jl")
18+
include("apiutils.jl")
5619
include("derivative.jl")
5720
include("gradient.jl")
5821
include("jacobian.jl")
5922
include("hessian.jl")
6023
include("deprecated.jl")
6124

62-
export DiffBase
25+
# This is a deprecation binding and should be removed in the next minor release.
26+
const DiffBase = DiffResults
27+
28+
export DiffBase, DiffResults
6329

6430
end # module

src/utils.jl renamed to src/apiutils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
# value extraction #
33
####################
44

5-
@inline extract_value!(out::DiffResult, ydual) = DiffBase.value!(value, out, ydual)
5+
@inline extract_value!(out::DiffResult, ydual) = DiffResults.value!(value, out, ydual)
66
@inline extract_value!(out, ydual) = out
77

88
@inline function extract_value!(out, y, ydual)
99
map!(value, y, ydual)
1010
copy_value!(out, y)
1111
end
1212

13-
@inline copy_value!(out::DiffResult, y) = DiffBase.value!(out, y)
13+
@inline copy_value!(out::DiffResult, y) = DiffResults.value!(out, y)
1414
@inline copy_value!(out, y) = out
1515

1616
###################################

src/config.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ function HessianConfig(f::F,
233233
x::AbstractArray{V},
234234
chunk::Chunk = Chunk(x),
235235
tag::Tag = Tag(F, Dual{Void,V,0})) where {F,V}
236-
jacobian_config = JacobianConfig(nothing, DiffBase.gradient(result), x, chunk)
236+
jacobian_config = JacobianConfig(nothing, DiffResults.gradient(result), x, chunk)
237237
gradient_config = GradientConfig(f, jacobian_config.duals[2], chunk, tag)
238238
return HessianConfig(jacobian_config, gradient_config)
239239
end

src/derivative.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,4 @@ end
8383
#----------#
8484

8585
extract_derivative!(result::AbstractArray, y::AbstractArray) = map!(extract_derivative, result, y)
86-
extract_derivative!(result::DiffResult, y) = DiffBase.derivative!(extract_derivative, result, y)
86+
extract_derivative!(result::DiffResult, y) = DiffResults.derivative!(extract_derivative, result, y)

0 commit comments

Comments
 (0)