Skip to content

Commit 31502e3

Browse files
mbaumansimeonschaub
authored andcommitted
Dot-broadcasting for short-circuiting ops .&& and .|| (JuliaLang#39594)
I have long wanted a proper fix for issue JuliaLang#5187. It was the very first Julia issue I filed. This is a shot at such a fix. This PR: * Enables parsing for `.&&` and `.||`. They are parsed into `Expr(:call, :.&&, ...)` expressions at the same precedence as their respective `&&` and `||`: ```julia-repl julia> Meta.show_sexpr(:(a .&& b)) (:call, :.&&, :a, :b) ``` * Unlike all other dotted operators `.op` (like `.+`), the `op`-alone part (`var"&&"`) is not an exported name from Base. As such, this effectively lowers to `broadcasted((x,y)->x && y, ...)`, but instead of using an anonymous function I've named it `Base.andand` and `Base.oror`: ```julia-repl julia> Meta.@lower a .&& b :($(Expr(:thunk, CodeInfo( @ none within `top-level scope' 1 ─ %1 = Base.broadcasted(Base.andand, a, b) │ %2 = Base.materialize(%1) └── return %2 )))) ``` * I've used a named function to enable short-circuiting behavior _within the broadcast kernel itself_. In the case that the second argument is a part of the same fused broadcast kernel, it will only evaluate if required: ```julia-repl julia> mutable struct F5187; x; end julia> (f::F5187)(x) = (f.x += x) julia> (iseven.(1:4) .|| (F5187(0)).(ones(4))) 4-element Vector{Real}: 1.0 true 2.0 true ``` Co-authored-by: Simeon Schaub <[email protected]>
1 parent eab1f65 commit 31502e3

File tree

5 files changed

+63
-14
lines changed

5 files changed

+63
-14
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ New language features
77
* `(; a, b) = x` can now be used to destructure properties `a` and `b` of `x`. This syntax is equivalent to `a = getproperty(x, :a)`
88
and similarly for `b`. ([#39285])
99
* Implicit multiplication by juxtaposition is now allowed for radical symbols (e.g., `x√y` and `x∛y`). ([#40173])
10+
* The short-circuiting operators `&&` and `||` can now be dotted to participate in broadcast fusion
11+
as `.&&` and `.||`. ([#39594])
1012

1113
Language changes
1214
----------------

base/broadcast.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using .Base.Cartesian
1111
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
1212
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
1313
import .Base: copy, copyto!, axes
14-
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction
14+
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction, andand, oror
1515

1616
## Computing the result's axes: deprecated name
1717
const broadcast_axes = axes
@@ -179,6 +179,21 @@ function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Arg
179179
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
180180
end
181181

182+
struct AndAnd end
183+
andand = AndAnd()
184+
broadcasted(::AndAnd, a, b) = broadcasted((a, b) -> a && b, a, b)
185+
function broadcasted(::AndAnd, a, bc::Broadcasted)
186+
bcf = flatten(bc)
187+
broadcasted((a, args...) -> a && bcf.f(args...), a, bcf.args...)
188+
end
189+
struct OrOr end
190+
const oror = OrOr()
191+
broadcasted(::OrOr, a, b) = broadcasted((a, b) -> a || b, a, b)
192+
function broadcasted(::OrOr, a, bc::Broadcasted)
193+
bcf = flatten(bc)
194+
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
195+
end
196+
182197
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
183198
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)
184199

@@ -1257,15 +1272,9 @@ function __dot__(x::Expr)
12571272
tmp = x.head === :(<:) ? :.<: : :.>:
12581273
Expr(:call, tmp, dotargs...)
12591274
else
1260-
if x.head === :&& || x.head === :||
1261-
error("""
1262-
Using `&&` and `||` is disallowed in `@.` expressions.
1263-
Use `&` or `|` for elementwise logical operations.
1264-
""")
1265-
end
1266-
head = string(x.head)
1267-
if last(head) == '=' && first(head) != '.'
1268-
Expr(Symbol('.',head), dotargs...)
1275+
head = String(x.head)::String
1276+
if last(head) == '=' && first(head) != '.' || head == "&&" || head == "||"
1277+
Expr(Symbol('.', head), dotargs...)
12691278
else
12701279
Expr(x.head, dotargs...)
12711280
end

src/julia-parser.scm

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
(define prec-pair (add-dots '(=>)))
1212
(define prec-conditional '(?))
1313
(define prec-arrow (add-dots '(← → ↔ ↚ ↛ ↞ ↠ ↢ ↣ ↦ ↤ ↮ ⇎ ⇍ ⇏ ⇐ ⇒ ⇔ ⇴ ⇶ ⇷ ⇸ ⇹ ⇺ ⇻ ⇼ ⇽ ⇾ ⇿ ⟵ ⟶ ⟷ ⟹ ⟺ ⟻ ⟼ ⟽ ⟾ ⟿ ⤀ ⤁ ⤂ ⤃ ⤄ ⤅ ⤆ ⤇ ⤌ ⤍ ⤎ ⤏ ⤐ ⤑ ⤔ ⤕ ⤖ ⤗ ⤘ ⤝ ⤞ ⤟ ⤠ ⥄ ⥅ ⥆ ⥇ ⥈ ⥊ ⥋ ⥎ ⥐ ⥒ ⥓ ⥖ ⥗ ⥚ ⥛ ⥞ ⥟ ⥢ ⥤ ⥦ ⥧ ⥨ ⥩ ⥪ ⥫ ⥬ ⥭ ⥰ ⧴ ⬱ ⬰ ⬲ ⬳ ⬴ ⬵ ⬶ ⬷ ⬸ ⬹ ⬺ ⬻ ⬼ ⬽ ⬾ ⬿ ⭀ ⭁ ⭂ ⭃ ⭄ ⭇ ⭈ ⭉ ⭊ ⭋ ⭌ ← → ⇜ ⇝ ↜ ↝ ↩ ↪ ↫ ↬ ↼ ↽ ⇀ ⇁ ⇄ ⇆ ⇇ ⇉ ⇋ ⇌ ⇚ ⇛ ⇠ ⇢ ↷ ↶ ↺ ↻ --> <-- <-->)))
14-
(define prec-lazy-or '(|\|\||))
15-
(define prec-lazy-and '(&&))
14+
(define prec-lazy-or (add-dots '(|\|\||)))
15+
(define prec-lazy-and (add-dots '(&&)))
1616
(define prec-comparison
1717
(append! '(in isa)
1818
(add-dots '(> < >= ≥ <= ≤ == === ≡ != ≠ !== ≢ ∈ ∉ ∋ ∌ ⊆ ⊈ ⊂ ⊄ ⊊ ∝ ∊ ∍ ∥ ∦ ∷ ∺ ∻ ∽ ∾ ≁ ≃ ≂ ≄ ≅ ≆ ≇ ≈ ≉ ≊ ≋ ≌ ≍ ≎ ≐ ≑ ≒ ≓ ≖ ≗ ≘ ≙ ≚ ≛ ≜ ≝ ≞ ≟ ≣ ≦ ≧ ≨ ≩ ≪ ≫ ≬ ≭ ≮ ≯ ≰ ≱ ≲ ≳ ≴ ≵ ≶ ≷ ≸ ≹ ≺ ≻ ≼ ≽ ≾ ≿ ⊀ ⊁ ⊃ ⊅ ⊇ ⊉ ⊋ ⊏ ⊐ ⊑ ⊒ ⊜ ⊩ ⊬ ⊮ ⊰ ⊱ ⊲ ⊳ ⊴ ⊵ ⊶ ⊷ ⋍ ⋐ ⋑ ⋕ ⋖ ⋗ ⋘ ⋙ ⋚ ⋛ ⋜ ⋝ ⋞ ⋟ ⋠ ⋡ ⋢ ⋣ ⋤ ⋥ ⋦ ⋧ ⋨ ⋩ ⋪ ⋫ ⋬ ⋭ ⋲ ⋳ ⋴ ⋵ ⋶ ⋷ ⋸ ⋹ ⋺ ⋻ ⋼ ⋽ ⋾ ⋿ ⟈ ⟉ ⟒ ⦷ ⧀ ⧁ ⧡ ⧣ ⧤ ⧥ ⩦ ⩧ ⩪ ⩫ ⩬ ⩭ ⩮ ⩯ ⩰ ⩱ ⩲ ⩳ ⩵ ⩶ ⩷ ⩸ ⩹ ⩺ ⩻ ⩼ ⩽ ⩾ ⩿ ⪀ ⪁ ⪂ ⪃ ⪄ ⪅ ⪆ ⪇ ⪈ ⪉ ⪊ ⪋ ⪌ ⪍ ⪎ ⪏ ⪐ ⪑ ⪒ ⪓ ⪔ ⪕ ⪖ ⪗ ⪘ ⪙ ⪚ ⪛ ⪜ ⪝ ⪞ ⪟ ⪠ ⪡ ⪢ ⪣ ⪤ ⪥ ⪦ ⪧ ⪨ ⪩ ⪪ ⪫ ⪬ ⪭ ⪮ ⪯ ⪰ ⪱ ⪲ ⪳ ⪴ ⪵ ⪶ ⪷ ⪸ ⪹ ⪺ ⪻ ⪼ ⪽ ⪾ ⪿ ⫀ ⫁ ⫂ ⫃ ⫄ ⫅ ⫆ ⫇ ⫈ ⫉ ⫊ ⫋ ⫌ ⫍ ⫎ ⫏ ⫐ ⫑ ⫒ ⫓ ⫔ ⫕ ⫖ ⫗ ⫘ ⫙ ⫷ ⫸ ⫹ ⫺ ⊢ ⊣ ⟂ <: >:))))
@@ -111,8 +111,8 @@
111111

112112
; operators that are special forms, not function names
113113
(define syntactic-operators
114-
(append! (add-dots '(= += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
115-
'(:= $= && |\|\|| |.| ... ->)))
114+
(append! (add-dots '(&& |\|\|| = += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
115+
'(:= $= |.| ... ->)))
116116
(define syntactic-unary-operators '($ & |::|))
117117

118118
(define syntactic-op? (Set syntactic-operators))

src/julia-syntax.scm

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,6 +1804,10 @@
18041804
e))))
18051805
((and (pair? e) (eq? (car e) 'comparison))
18061806
(dot-to-fuse (expand-compare-chain (cdr e)) top))
1807+
((and (pair? e) (eq? (car e) '.&&))
1808+
(make-fuse '(top andand) (cdr e)))
1809+
((and (pair? e) (eq? (car e) '|.\|\||))
1810+
(make-fuse '(top oror) (cdr e)))
18071811
(else e)))
18081812
(let ((e (dot-to-fuse rhs #t)) ; an expression '(fuse func args) if expr is a dot call
18091813
(lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place
@@ -2125,6 +2129,11 @@
21252129
;; e = (|.| f x)
21262130
(expand-fuse-broadcast '() e)))
21272131

2132+
'.&&
2133+
(lambda (e) (expand-fuse-broadcast '() e))
2134+
'|.\|\||
2135+
(lambda (e) (expand-fuse-broadcast '() e))
2136+
21282137
'.=
21292138
(lambda (e)
21302139
(expand-fuse-broadcast (cadr e) (caddr e)))

test/broadcast.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,35 @@ p0 = copy(p)
951951
@test repr(.!) == "Base.Broadcast.BroadcastFunction(!)"
952952
@test eval(:(.+)) == Base.BroadcastFunction(+)
953953

954+
@testset "Issue #5187: Broadcasting of short-circuiting ops" begin
955+
ex = Meta.parse("A .< 1 .|| A .> 2")
956+
@test ex == :((A .< 1) .|| (A .> 2))
957+
@test ex.head == :.||
958+
ex = Meta.parse("A .< 1 .&& A .> 2")
959+
@test ex == :((A .< 1) .&& (A .> 2))
960+
@test ex.head == :.&&
961+
962+
A = -1:4
963+
@test (A .< 1 .|| A .> 2) == [true, true, false, false, true, true]
964+
@test (A .>= 1 .&& A .<= 2) == [false, false, true, true, false, false]
965+
966+
mutable struct F5187; x; end
967+
(f::F5187)(x) = (f.x += x)
968+
@test (iseven.(1:4) .&& (F5187(0)).(ones(4))) == [false, 1, false, 2]
969+
@test (iseven.(1:4) .|| (F5187(0)).(ones(4))) == [1, true, 2, true]
970+
r = 1:4; o = ones(4); f = F5187(0);
971+
@test (@. iseven(r) && f(o)) == [false, 1, false, 2]
972+
@test (@. iseven(r) || f(o)) == [3, true, 4, true]
973+
974+
@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
975+
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
976+
r = 1:8; o = ones(8); f1 = F5187(0); f2 = F5187(0)
977+
@test (@. iseven(r) && iseven(f1(o)) && f2(o)) == [false,false,false,1,false,false,false,2]
978+
@test (@. iseven(r) || iseven(f1(o)) || f2(o)) == [3,true,true,true,4,true,true,true]
979+
@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
980+
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
981+
end
982+
954983
@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
955984
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
956985
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},

0 commit comments

Comments
 (0)