Skip to content

Commit 00d400b

Browse files
authored
Fix bug handling Adjoints in macros (#1698)
* Fix bug handling Adjoints in macros * Explicitly call broadcast in destructive_add! instead of relying on lowering. This is a work-around for a weird feature of Julia that doesn't propagate adjoints properly through `.` broadcasts. * Update parse_expr.jl * Make tests catch adjoint bug properly * Update macros.jl
1 parent a6a9c6e commit 00d400b

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/parse_expr.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,10 @@ function destructive_add!(ex::AbstractArray{<:GenericAffExpr}, c::Number,
285285
return result
286286
end
287287

288-
289-
destructive_add!(ex, c, x) = ex .+ c * x
288+
# For some reason, the broadcast syntax `ex .+ c * x` fails if `x` is an
289+
# `Adjoint`. But if we explicitly call `broadcast` it seems to work.
290+
# See JuMP PR #1698 for more discussion.
291+
destructive_add!(ex, c, x) = broadcast(+, ex, c * x)
290292

291293
destructive_add_with_reorder!(ex, arg) = destructive_add!(ex, 1.0, arg)
292294
# Special case because "Val{false}()" is used as the default empty expression.

test/macros.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,20 @@ end
270270
"DenseAxisArray, or SparseAxisArray.")
271271
@test_throws exception @variable(model, x[1:3], container=Oops)
272272
end
273+
274+
@testset "Adjoints" begin
275+
model = Model()
276+
@variable(model, x[1:2])
277+
obj = @objective(model, Min, x' * ones(2, 2) * x)
278+
@test JuMP.isequal_canonical(obj, x[1]^2 + 2 * x[1] * x[2] + x[2]^2)
279+
cref = @constraint(model, x' * ones(2, 2) * x <= 1)
280+
c = JuMP.constraint_object(cref)
281+
@test JuMP.isequal_canonical(c.func, x[1]^2 + 2 * x[1] * x[2] + x[2]^2)
282+
@test c.set == MOI.LessThan(1.0)
283+
@test JuMP.isequal_canonical(
284+
JuMP.destructive_add!(0.0, x', ones(2, 2)), x' * ones(2, 2)
285+
)
286+
end
273287
end
274288

275289
@testset "Macros for JuMPExtension.MyModel" begin

0 commit comments

Comments
 (0)