Skip to content

Commit 0cf7942

Browse files
authored
Merge pull request #37 from mcabbott/biwalk
Make `fmap(f, x, y)` useful
2 parents 4a834ce + 077d6e5 commit 0cf7942

File tree

7 files changed

+185
-82
lines changed

7 files changed

+185
-82
lines changed

.github/workflows/ci.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.5' # Replace this with the minimum Julia version that your package supports.
21-
# - '1' # automatically expands to the latest stable 1.x release of Julia
20+
- '1.0'
21+
- '1.6' # Replace this with the minimum Julia version that your package supports.
22+
- '1' # automatically expands to the latest stable 1.x release of Julia
2223
- 'nightly'
2324
os:
2425
- ubuntu-latest

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
name = "Functors"
22
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
33
authors = ["Mike J Innes <[email protected]>"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[compat]
7-
julia = "1"
87
Documenter = "0.27"
8+
julia = "1"
99

1010
[extras]
1111
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

src/Functors.jl

+12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ usually using the macro [@functor](@ref).
2323
"""
2424
functor
2525

26+
@static if VERSION >= v"1.5" # var"@functor" doesn't work on 1.0, temporarily disable
2627
"""
2728
@functor T
2829
@functor T (x,)
@@ -65,6 +66,7 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560)
6566
```
6667
"""
6768
var"@functor"
69+
end # VERSION
6870

6971
"""
7072
Functors.isleaf(x)
@@ -182,6 +184,16 @@ This function walks (maps) over `xs` calling the continuation `f'` to continue t
182184
julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
183185
Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
184186
```
187+
188+
The behaviour when the same node appears twice can be altered by giving a value
189+
to the `prune` keyword, which is then used in place of all but the first:
190+
191+
```jldoctest
192+
julia> twice = [1, 2];
193+
194+
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
195+
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
196+
```
185197
"""
186198
fmap
187199

src/functor.jl

+11-25
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ functor(T, x) = (), _ -> x
33
functor(x) = functor(typeof(x), x)
44

55
functor(::Type{<:Tuple}, x) = x, y -> y
6-
functor(::Type{<:NamedTuple}, x) = x, y -> y
6+
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity
77

88
functor(::Type{<:AbstractArray}, x) = x, y -> y
99
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
@@ -43,12 +43,11 @@ function _default_walk(f, x)
4343
re(map(f, func))
4444
end
4545

46-
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
47-
haskey(cache, x) && return cache[x]
48-
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
49-
cache[x] = y
46+
struct NoKeyword end
5047

51-
return y
48+
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
49+
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
50+
cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x)
5251
end
5352

5453
###
@@ -74,27 +73,16 @@ end
7473
### Vararg forms
7574
###
7675

77-
function fmap(f, x, dx...; cache = IdDict())
78-
haskey(cache, x) && return cache[x]
79-
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
76+
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
77+
haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
78+
cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...)
8079
end
8180

82-
function functor_tuple(f, x::Tuple, dx::Tuple)
83-
map(x, dx) do x, x̄
84-
_default_walk(f, x, x̄)
85-
end
86-
end
87-
functor_tuple(f, x, dx) = f(x, dx)
88-
functor_tuple(f, x, ::Nothing) = x
89-
90-
function _default_walk(f, x, dx)
81+
function _default_walk(f, x, ys...)
9182
func, re = functor(x)
92-
map(func, dx) do x, x̄
93-
# functor_tuple(f, x, x̄)
94-
f(x, x̄)
95-
end |> re
83+
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
84+
re(map(f, func, yfuncs...))
9685
end
97-
_default_walk(f, ::Nothing, ::Nothing) = nothing
9886

9987
###
10088
### FlexibleFunctors.jl
@@ -112,9 +100,7 @@ function makeflexiblefunctor(m::Module, T, pfield)
112100
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
113101
return func, re
114102
end
115-
116103
end
117-
118104
end
119105

120106
function flexiblefunctorm(T, pfield = :params)

test/basics.jl

+157-29
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,16 @@
1-
struct Foo
2-
x
3-
y
4-
end
1+
2+
using Functors: functor
3+
4+
struct Foo; x; y; end
55
@functor Foo
66

7-
struct Bar
8-
x
9-
end
7+
struct Bar; x; end
108
@functor Bar
119

12-
struct Baz
13-
x
14-
y
15-
z
16-
end
17-
@functor Baz (y,)
10+
struct OneChild3; x; y; z; end
11+
@functor OneChild3 (y,)
1812

19-
struct NoChildren
20-
x
21-
y
22-
end
13+
struct NoChildren2; x; y; end
2314

2415
@static if VERSION >= v"1.6"
2516
@testset "ComposedFunction" begin
@@ -31,6 +22,10 @@ end
3122
end
3223
end
3324

25+
###
26+
### Basic functionality
27+
###
28+
3429
@testset "Nested" begin
3530
model = Bar(Foo(1, [1, 2, 3]))
3631

@@ -53,20 +48,80 @@ end
5348
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
5449
end
5550

51+
@testset "Property list" begin
52+
model = OneChild3(1, 2, 3)
53+
model′ = fmap(x -> 2x, model)
54+
55+
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
56+
end
57+
58+
@testset "cache" begin
59+
shared = [1,2,3]
60+
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
61+
m1f = fmap(float, m1)
62+
@test m1f.x === m1f.y.y.x
63+
@test m1f.x !== m1f.y.x
64+
m1p = fmapstructure(identity, m1; prune = nothing)
65+
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))
66+
67+
# A non-leaf node can also be repeated:
68+
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
69+
@test m2.x === m2.y
70+
m2f = fmap(float, m2)
71+
@test m2f.x.x === m2f.y.x
72+
m2p = fmapstructure(identity, m2; prune = Bar(0))
73+
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))
74+
75+
# Repeated isbits types should not automatically be regarded as shared:
76+
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
77+
m3p = fmapstructure(identity, m3; prune = 0)
78+
@test m3p.y.y == 0
79+
@test_broken m3p.y.x == 1:3
80+
end
81+
82+
@testset "functor(typeof(x), y) from @functor" begin
83+
nt1, re1 = functor(Foo, (x=1, y=2, z=3))
84+
@test nt1 == (x = 1, y = 2)
85+
@test re1((x = 10, y = 20)) == Foo(10, 20)
86+
re1((y = 22, x = 11)) # gives Foo(22, 11), is that a bug?
87+
88+
nt2, re2 = functor(Foo, (z=33, x=1, y=2))
89+
@test nt2 == (x = 1, y = 2)
90+
@test re2((x = 10, y = 20)) == Foo(10, 20)
91+
92+
@test_throws Exception functor(Foo, (z=33, x=1)) # type NamedTuple has no field y
93+
94+
nt3, re3 = functor(OneChild3, (x=1, y=2, z=3))
95+
@test nt3 == (y = 2,)
96+
@test re3((y = 20,)) == OneChild3(1, 20, 3)
97+
re3(22) # gives OneChild3(1, 22, 3), is that a bug?
98+
end
99+
100+
@testset "functor(typeof(x), y) for Base types" begin
101+
nt11, re11 = functor(NamedTuple{(:x, :y)}, (x=1, y=2, z=3))
102+
@test nt11 == (x = 1, y = 2)
103+
@test re11((x = 10, y = 20)) == (x = 10, y = 20)
104+
re11((y = 22, x = 11))
105+
re11((11, 22)) # passes right through
106+
107+
nt12, re12 = functor(NamedTuple{(:x, :y)}, (z=33, x=1, y=2))
108+
@test nt12 == (x = 1, y = 2)
109+
@test re12((x = 10, y = 20)) == (x = 10, y = 20)
110+
111+
@test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1))
112+
end
113+
114+
###
115+
### Extras
116+
###
117+
56118
@testset "Walk" begin
57119
model = Foo((0, Bar([1, 2, 3])), [4, 5])
58120

59121
model′ = fmapstructure(identity, model)
60122
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
61123
end
62124

63-
@testset "Property list" begin
64-
model = Baz(1, 2, 3)
65-
model′ = fmap(x -> 2x, model)
66-
67-
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
68-
end
69-
70125
@testset "fcollect" begin
71126
m1 = [1, 2, 3]
72127
m2 = 1
@@ -78,7 +133,7 @@ end
78133

79134
m1 = [1, 2, 3]
80135
m2 = Bar(m1)
81-
m0 = NoChildren(:a, :b)
136+
m0 = NoChildren2(:a, :b)
82137
m3 = Foo(m2, m0)
83138
m4 = Bar(m3)
84139
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
@@ -89,6 +144,79 @@ end
89144
@test all(fcollect(m3) .=== [m3, m1, m2])
90145
end
91146

147+
###
148+
### Vararg forms
149+
###
150+
151+
@testset "fmap(f, x, y)" begin
152+
m1 = (x = [1,2], y = 3)
153+
n1 = (x = [4,5], y = 6)
154+
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)
155+
156+
# Reconstruction type comes from the first argument
157+
foo1 = Foo([7,8], 9)
158+
@test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
159+
@test fmap(+, foo1, n1) isa Foo
160+
@test fmap(+, foo1, n1).x == [11, 13]
161+
162+
# Mismatched trees should be an error
163+
m2 = (x = [1,2], y = (a = [3,4], b = 5))
164+
n2 = (x = [6,7], y = 8)
165+
@test_throws Exception fmap(firsttuple, m2, n2) # ERROR: type Int64 has no field a
166+
@test_throws Exception fmap(firsttuple, m2, n2)
167+
168+
# The cache uses IDs from the first argument
169+
shared = [1,2,3]
170+
m3 = (x = shared, y = [4,5,6], z = shared)
171+
n3 = (x = shared, y = shared, z = [7,8,9])
172+
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
173+
z3 = fmap(+, m3, n3)
174+
@test z3.x === z3.z
175+
176+
# Pruning of duplicates:
177+
@test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)
178+
179+
# More than two arguments:
180+
z4 = fmap(+, m3, n3, m3, n3)
181+
@test z4 == fmap(x -> 2x, z3)
182+
@test z4.x === z4.z
183+
184+
@test fmap(+, foo1, m1, n1) isa Foo
185+
@static if VERSION >= v"1.6" # fails on Julia 1.0
186+
@test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
187+
end
188+
end
189+
190+
@static if VERSION >= v"1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
191+
@testset "old test update.jl" begin
192+
struct M{F,T,S}
193+
σ::F
194+
W::T
195+
b::S
196+
end
197+
198+
@functor M
199+
200+
(m::M)(x) = m.σ.(m.W * x .+ m.b)
201+
202+
m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
203+
x = ones(Float32, 4, 2)
204+
m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
205+
= Functors.fmap(m, m̄) do x, y
206+
isnothing(x) && return y
207+
isnothing(y) && return x
208+
x .- 0.1f0 .* y
209+
end
210+
211+
@test.W fill(0.8f0, size(m.W))
212+
@test.b fill(-0.2f0, size(m.b))
213+
end
214+
end # VERSION
215+
216+
###
217+
### FlexibleFunctors.jl
218+
###
219+
92220
struct FFoo
93221
x
94222
y
@@ -102,13 +230,13 @@ struct FBar
102230
end
103231
@flexiblefunctor FBar p
104232

105-
struct FBaz
233+
struct FOneChild4
106234
x
107235
y
108236
z
109237
p
110238
end
111-
@flexiblefunctor FBaz p
239+
@flexiblefunctor FOneChild4 p
112240

113241
@testset "Flexible Nested" begin
114242
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
@@ -132,7 +260,7 @@ end
132260
end
133261

134262
@testset "Flexible Property list" begin
135-
model = FBaz(1, 2, 3, (:x, :z))
263+
model = FOneChild4(1, 2, 3, (:x, :z))
136264
model′ = fmap(x -> 2x, model)
137265

138266
@test (model′.x, model′.y, model′.z) == (2, 2, 6)
@@ -147,7 +275,7 @@ end
147275
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
148276
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])
149277

150-
m0 = NoChildren(:a, :b)
278+
m0 = NoChildren2(:a, :b)
151279
m1 = [1, 2, 3]
152280
m2 = FBar(m1, ())
153281
m3 = FFoo(m2, m0, (:x, :y,))

test/runtests.jl

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Zygote
55

66
include("basics.jl")
77
include("base.jl")
8-
include("update.jl")
98

109
if VERSION < v"1.6" # || VERSION > v"1.7-"
1110
@warn "skipping doctests, on Julia $VERSION"

0 commit comments

Comments
 (0)