Skip to content

Commit 584442d

Browse files
devmotion and nalimilan authored
Add dims keyword argument to softmax and softmax! (#28)
* Add `dims` keyword argument to `softmax` and `softmax!` * Bump version * Apply suggestions from code review Co-authored-by: Milan Bouchet-Valat <[email protected]> * Fix test error * Update documentation * Fix * Extend docstring Co-authored-by: Milan Bouchet-Valat <[email protected]> Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent 93b0cc1 commit 584442d

File tree

6 files changed

+207
-53
lines changed

6 files changed

+207
-53
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/Manifest.toml

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,58 @@
11
# This file is machine-generated - editing it directly is not advised
22

3+
[[ANSIColoredPrinters]]
4+
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
5+
uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
6+
version = "0.0.1"
7+
8+
[[ArgTools]]
9+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
10+
11+
[[Artifacts]]
12+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
13+
314
[[Base64]]
415
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
516

17+
[[ChainRulesCore]]
18+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
19+
git-tree-sha1 = "30ee06de5ff870b45c78f529a6b093b3323256a3"
20+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
21+
version = "1.3.1"
22+
23+
[[Compat]]
24+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
25+
git-tree-sha1 = "4866e381721b30fac8dda4c8cb1d9db45c8d2994"
26+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
27+
version = "3.37.0"
28+
629
[[Dates]]
730
deps = ["Printf"]
831
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
932

33+
[[DelimitedFiles]]
34+
deps = ["Mmap"]
35+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
36+
37+
[[Distributed]]
38+
deps = ["Random", "Serialization", "Sockets"]
39+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
40+
1041
[[DocStringExtensions]]
1142
deps = ["LibGit2"]
1243
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
1344
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1445
version = "0.8.5"
1546

1647
[[Documenter]]
17-
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
18-
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
48+
deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
49+
git-tree-sha1 = "fe0bc46b27cd3413df55859152fd70e50744025f"
1950
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
20-
version = "0.27.3"
51+
version = "0.27.6"
52+
53+
[[Downloads]]
54+
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
55+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
2156

2257
[[IOCapture]]
2358
deps = ["Logging", "Random"]
@@ -36,14 +71,26 @@ version = "0.1.0"
3671

3772
[[JSON]]
3873
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
39-
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
74+
git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37"
4075
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
41-
version = "0.21.1"
76+
version = "0.21.2"
77+
78+
[[LibCURL]]
79+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
80+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
81+
82+
[[LibCURL_jll]]
83+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
84+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
4285

4386
[[LibGit2]]
4487
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
4588
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
4689

90+
[[LibSSH2_jll]]
91+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
92+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
93+
4794
[[Libdl]]
4895
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
4996

@@ -52,10 +99,10 @@ deps = ["Libdl"]
5299
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
53100

54101
[[LogExpFunctions]]
55-
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
102+
deps = ["ChainRulesCore", "DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
56103
path = ".."
57104
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
58-
version = "0.3.0"
105+
version = "0.3.3"
59106

60107
[[Logging]]
61108
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -64,17 +111,28 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
64111
deps = ["Base64"]
65112
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
66113

114+
[[MbedTLS_jll]]
115+
deps = ["Artifacts", "Libdl"]
116+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
117+
67118
[[Mmap]]
68119
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
69120

121+
[[MozillaCACerts_jll]]
122+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
123+
70124
[[NetworkOptions]]
71125
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
72126

73127
[[Parsers]]
74128
deps = ["Dates"]
75-
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
129+
git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779"
76130
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
77-
version = "1.1.0"
131+
version = "2.0.3"
132+
133+
[[Pkg]]
134+
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
135+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
78136

79137
[[Printf]]
80138
deps = ["Unicode"]
@@ -94,12 +152,48 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
94152
[[Serialization]]
95153
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
96154

155+
[[SharedArrays]]
156+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
157+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
158+
97159
[[Sockets]]
98160
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
99161

162+
[[SparseArrays]]
163+
deps = ["LinearAlgebra", "Random"]
164+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
165+
166+
[[Statistics]]
167+
deps = ["LinearAlgebra", "SparseArrays"]
168+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
169+
170+
[[TOML]]
171+
deps = ["Dates"]
172+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
173+
174+
[[Tar]]
175+
deps = ["ArgTools", "SHA"]
176+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
177+
100178
[[Test]]
101179
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
102180
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
103181

182+
[[UUIDs]]
183+
deps = ["Random", "SHA"]
184+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
185+
104186
[[Unicode]]
105187
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
188+
189+
[[Zlib_jll]]
190+
deps = ["Libdl"]
191+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
192+
193+
[[nghttp2_jll]]
194+
deps = ["Artifacts", "Libdl"]
195+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
196+
197+
[[p7zip_jll]]
198+
deps = ["Artifacts", "Libdl"]
199+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

src/basicfuns.jl

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,15 +235,35 @@ function logsubexp(x::Real, y::Real)
235235
end
236236

237237
"""
238-
$(SIGNATURES)
238+
softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real}=r; dims=:)
239+
240+
Overwrite `r` with the
241+
[softmax transformation](https://en.wikipedia.org/wiki/Softmax_function) of `x` over
242+
dimension `dims`.
239243
240-
Overwrite `r` with the `softmax` (or _normalized exponential_) transformation of `x`
244+
That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1 over the given
245+
dimensions.
241246
242-
That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
247+
See also: [`softmax`](@ref)
248+
"""
249+
softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real}=r; dims=:) =
250+
_softmax!(r, x, dims)
243251

244-
See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
245252
"""
246-
function softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real})
253+
softmax(x::AbstractArray{<:Real}; dims=:)
254+
255+
Return the
256+
[softmax transformation](https://en.wikipedia.org/wiki/Softmax_function) of `x` over
257+
dimension `dims`.
258+
259+
That is, return `exp.(x)`, normalized to sum to 1 over the given dimensions.
260+
261+
See also: [`softmax!`](@ref)
262+
"""
263+
softmax(x::AbstractArray{<:Real}; dims=:) =
264+
softmax!(similar(x, float(eltype(x))), x; dims=dims)
265+
266+
function _softmax!(r, x, ::Colon)
247267
length(r) == length(x) || throw(DimensionMismatch("inconsistent array lengths"))
248268
u = maximum(x)
249269
map!(r, x) do xi
@@ -253,18 +273,16 @@ function softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real})
253273
return r
254274
end
255275

256-
"""
257-
$(SIGNATURES)
258-
259-
Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
260-
applied to `x` *in place*.
261-
"""
262-
softmax!(x::AbstractArray{<:AbstractFloat}) = softmax!(x, x)
263-
264-
"""
265-
$(SIGNATURES)
266-
267-
Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
268-
applied to `x`.
269-
"""
270-
softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, float(eltype(x))), x)
276+
function _softmax!(r, x, dims)
277+
size(r) == size(x) || throw(DimensionMismatch("inconsistent array sizes"))
278+
u = maximum(x; dims=dims)
279+
r .= exp.(x .- u)
280+
if u isa Array{eltype(r)}
281+
# array can be reused
282+
sum!(u, r)
283+
r ./= u
284+
else
285+
r ./= sum(r; dims=dims)
286+
end
287+
return r
288+
end

src/chainrules.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,36 @@ function ChainRulesCore.rrule(::typeof(logsumexp), x::AbstractArray{<:Real}; dim
3434
return Ω, logsumexp_pullback
3535
end
3636

37-
function ChainRulesCore.frule(
38-
(_, _, Δx), ::typeof(softmax!), r::AbstractArray{<:Real}, x::AbstractArray{<:Real},
39-
)
40-
softmax!(r, x)
41-
_Δx = reshape(Δx, size(r))
42-
Δr = r .* (_Δx .- LinearAlgebra.dot(r, _Δx))
43-
return r, Δr
37+
# no rules for mutating functions currently:
38+
# https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html#Which-functions-need-rules?
39+
function ChainRulesCore.frule((_, Δx), ::typeof(softmax), x::AbstractArray{<:Real}; dims=:)
40+
Ω = softmax(x; dims=dims)
41+
ΔΩ = if dims === (:)
42+
Ω .* (Δx .- LinearAlgebra.dot(Ω, Δx))
43+
else
44+
ΩΔx = Ω .* Δx
45+
ΩΔx .- Ω .* sum(ΩΔx; dims=dims)
46+
end
47+
return Ω, ΔΩ
4448
end
45-
function ChainRulesCore.rrule(
46-
::typeof(softmax!), r::AbstractArray{<:Real}, x::AbstractArray{<:Real},
47-
)
48-
softmax!(r, x)
49+
function ChainRulesCore.rrule(::typeof(softmax), x::AbstractArray{<:Real}; dims=:)
50+
Ω = softmax(x; dims=dims)
51+
Ωcopy = copy(Ω)
4952
project_x = ChainRulesCore.ProjectTo(x)
50-
rcopy = copy(reshape(r, size(x)))
51-
function softmax!_pullback(r̄)
52-
_r̄ = reshape(r̄, size(rcopy))
53-
x̄ = ChainRulesCore.InplaceableThunk(
54-
Δ -> Δ .+= rcopy .* (_r̄ .- LinearAlgebra.dot(rcopy, _r̄)),
55-
ChainRulesCore.@thunk(project_x(rcopy .* (_r̄ .- LinearAlgebra.dot(rcopy, _r̄)))),
56-
)
57-
return ChainRulesCore.NoTangent(), ChainRulesCore.ZeroTangent(), x̄
53+
function softmax_pullback(Ω̄)
54+
x̄ = if dims === (:)
55+
ChainRulesCore.InplaceableThunk(
56+
Δ -> Δ .+= Ωcopy .* (Ω̄ .- LinearAlgebra.dot(Ωcopy, Ω̄)),
57+
ChainRulesCore.@thunk(project_x(Ωcopy .* (Ω̄ .- LinearAlgebra.dot(Ωcopy, Ω̄)))),
58+
)
59+
else
60+
ΩcopyΩ̄ = Ωcopy .* Ω̄
61+
ChainRulesCore.InplaceableThunk(
62+
Δ -> Δ .+= ΩcopyΩ̄ .- Ωcopy .* sum(ΩcopyΩ̄; dims=dims),
63+
ChainRulesCore.@thunk(project_x(ΩcopyΩ̄ .- Ωcopy .* sum(ΩcopyΩ̄; dims=dims))),
64+
)
65+
end
66+
return ChainRulesCore.NoTangent(), x̄
5867
end
59-
return r, softmax!_pullback
68+
return Ω, softmax_pullback
6069
end

test/basicfuns.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,23 @@ end
207207
softmax!(s, x)
208208
@test s ≈ r
209209

210+
fill!(s, zero(T))
211+
softmax!(s, x; dims=1)
212+
@test s ≈ r
213+
210214
s = Matrix{T}(undef, 1, 3)
211215
softmax!(s, x)
212216
@test s ≈ permutedims(r)
217+
218+
@test_throws DimensionMismatch softmax!(s, x; dims=1)
219+
220+
fill!(s, zero(T))
221+
softmax!(s, permutedims(x); dims=2)
222+
@test s ≈ permutedims(r)
223+
224+
fill!(s, zero(T))
225+
softmax!(s, permutedims(x); dims=1:2)
226+
@test s ≈ permutedims(r)
213227
end
214228
softmax!(x)
215229
@test x ≈ r
@@ -219,6 +233,21 @@ end
219233
s = softmax(x)
220234
@test s ≈ r
221235
@test eltype(s) === T
236+
237+
x = repeat(S[1, 2, 3], 1, 3)
238+
s = softmax(x; dims=1)
239+
@test s ≈ repeat(r, 1, 3)
240+
@test eltype(s) === T
241+
242+
x = repeat(S[1 2 3], 3, 1)
243+
s = softmax(x; dims=2)
244+
@test s ≈ repeat(permutedims(r), 3, 1)
245+
@test eltype(s) === T
246+
247+
x = S[1 2 3]
248+
s = softmax(x; dims=1:2)
249+
@test s ≈ permutedims(r)
250+
@test eltype(s) === T
222251
end
223252

224253
x = [1//2, 2//3, 3//4]

test/chainrules.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@
7676
end
7777

7878
for x in (randn(10), randn(10, 8))
79-
for r in (similar(x), similar(x, 1, size(x)...))
80-
test_frule(softmax!, r, x)
81-
test_rrule(softmax!, r, x)
79+
test_frule(softmax, x)
80+
test_rrule(softmax, x)
81+
82+
for dims in (1, 1:2, 2)
83+
all(d <= ndims(x) for d in dims) || continue
84+
test_frule(softmax, x; fkwargs=(dims=dims,))
85+
test_rrule(softmax, x; fkwargs=(dims=dims,))
8286
end
8387
end
8488
end

0 commit comments

Comments
 (0)