Skip to content

Commit 9e5b423

Browse files
committed
RFC: Dot-broadcasting for short-circuiting ops .&& .||
I have long wanted a proper fix for issue #5187. It was the very first Julia issue I filed. This is a shot at such a fix. This PR: * Enables parsing for `.&&` and `.||`. They are parsed into `Expr(:call, :.&&, ...)` expressions at the same precedence as their respective `&&` and `||`: ```julia-repl julia> Meta.show_sexpr(:(a .&& b)) (:call, :.&&, :a, :b) ``` * Unlike all other dotted operators `.op` (like `.+`), the `op`-alone part (`var"&&"`) is not an exported name from Base. As such, this effectively lowers to `broadcasted((x,y)->x && y, ...)`, but instead of using an anonymous function I've named it `Base.andand` and `Base.oror`: ```julia-repl julia> Meta.@lower a .&& b :($(Expr(:thunk, CodeInfo( @ none within `top-level scope' 1 ─ %1 = Base.broadcasted(Base.andand, a, b) │ %2 = Base.materialize(%1) └── return %2 )))) ``` * I've used a named function to enable short-circuiting behavior _within the broadcast kernel itself_. In the case that the second argument is a part of the same fused broadcast kernel, it will only evaluate if required: ```julia-repl julia> mutable struct F5187; x; end julia> (f::F5187)(x) = (f.x += x) julia> (iseven.(1:4) .|| (F5187(0)).(ones(4))) 4-element Vector{Real}: 1.0 true 2.0 true ``` * This also enables support for standalone `.&&` and `.||` as `BroadcastFunction`s, but of course they are not able to short-circuit when used as functions themselves. Request for feedback -------------------- * [ ] A good bikeshed could be had over the names themselves. We could actually use `var"&&"` and `var"||"` for these names — and it'd _almost_ simplify the implementation, but in order to do so we'd have to actually _export_ them from Base, too. It seems like it might just be confusing. * [ ] Is this the implementation we want? This uses `Broadcast.flatten` to create the lazy function needed for short-circuiting the second argument. This could alternatively be done directly within the parser — perhaps by resurrecting the old 0.5 broadcast parsing behavior. Someone else would have to do that work if they wanted it. * [ ] Do we want to support the stand-alone `.&&` and `.||` `BroadcastFunction`s if they cannot possibly short circuit?
1 parent df7334a commit 9e5b423

File tree

5 files changed

+75
-17
lines changed

5 files changed

+75
-17
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ New language features
66

77
* `(; a, b) = x` can now be used to destructure properties `a` and `b` of `x`. This syntax is equivalent to `a = getproperty(x, :a)`
88
and similarly for `b`. ([#39285])
9+
* The short-circuiting operators `&&` and `||` can now be dotted to participate in broadcast fusion
10+
as `.&&` and `.||`. ([#39593])
911

1012
Language changes
1113
----------------

base/broadcast.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using .Base.Cartesian
1111
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
1212
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
1313
import .Base: copy, copyto!, axes
14-
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction
14+
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction, andand, oror
1515

1616
## Computing the result's axes: deprecated name
1717
const broadcast_axes = axes
@@ -179,6 +179,17 @@ function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Arg
179179
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
180180
end
181181

182+
andand(a, b) = a && b
183+
function broadcasted(::typeof(andand), a, bc::Broadcasted)
184+
bcf = flatten(bc)
185+
broadcasted((a, args...) -> a && bcf.f(args...), a, bcf.args...)
186+
end
187+
oror(a, b) = a || b
188+
function broadcasted(::typeof(oror), a, bc::Broadcasted)
189+
bcf = flatten(bc)
190+
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
191+
end
192+
182193
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
183194
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)
184195

@@ -1250,15 +1261,11 @@ function __dot__(x::Expr)
12501261
tmp = x.head === :(<:) ? :.<: : :.>:
12511262
Expr(:call, tmp, dotargs...)
12521263
else
1253-
if x.head === :&& || x.head === :||
1254-
error("""
1255-
Using `&&` and `||` is disallowed in `@.` expressions.
1256-
Use `&` or `|` for elementwise logical operations.
1257-
""")
1258-
end
12591264
head = string(x.head)
12601265
if last(head) == '=' && first(head) != '.'
12611266
Expr(Symbol('.',head), dotargs...)
1267+
elseif head == "&&" || head == "||"
1268+
Expr(:call, Symbol('.', head), dotargs...)
12621269
else
12631270
Expr(x.head, dotargs...)
12641271
end

src/julia-parser.scm

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
(define prec-pair (add-dots '(=>)))
1212
(define prec-conditional '(?))
1313
(define prec-arrow (add-dots '(← → ↔ ↚ ↛ ↞ ↠ ↢ ↣ ↦ ↤ ↮ ⇎ ⇍ ⇏ ⇐ ⇒ ⇔ ⇴ ⇶ ⇷ ⇸ ⇹ ⇺ ⇻ ⇼ ⇽ ⇾ ⇿ ⟵ ⟶ ⟷ ⟹ ⟺ ⟻ ⟼ ⟽ ⟾ ⟿ ⤀ ⤁ ⤂ ⤃ ⤄ ⤅ ⤆ ⤇ ⤌ ⤍ ⤎ ⤏ ⤐ ⤑ ⤔ ⤕ ⤖ ⤗ ⤘ ⤝ ⤞ ⤟ ⤠ ⥄ ⥅ ⥆ ⥇ ⥈ ⥊ ⥋ ⥎ ⥐ ⥒ ⥓ ⥖ ⥗ ⥚ ⥛ ⥞ ⥟ ⥢ ⥤ ⥦ ⥧ ⥨ ⥩ ⥪ ⥫ ⥬ ⥭ ⥰ ⧴ ⬱ ⬰ ⬲ ⬳ ⬴ ⬵ ⬶ ⬷ ⬸ ⬹ ⬺ ⬻ ⬼ ⬽ ⬾ ⬿ ⭀ ⭁ ⭂ ⭃ ⭄ ⭇ ⭈ ⭉ ⭊ ⭋ ⭌ ← → ⇜ ⇝ ↜ ↝ ↩ ↪ ↫ ↬ ↼ ↽ ⇀ ⇁ ⇄ ⇆ ⇇ ⇉ ⇋ ⇌ ⇚ ⇛ ⇠ ⇢ ↷ ↶ ↺ ↻ --> <-- <-->)))
14-
(define prec-lazy-or '(|\|\||))
15-
(define prec-lazy-and '(&&))
14+
(define prec-lazy-or (add-dots '(|\|\||)))
15+
(define prec-lazy-and (add-dots '(&&)))
1616
(define prec-comparison
1717
(append! '(in isa)
1818
(add-dots '(> < >= ≥ <= ≤ == === ≡ != ≠ !== ≢ ∈ ∉ ∋ ∌ ⊆ ⊈ ⊂ ⊄ ⊊ ∝ ∊ ∍ ∥ ∦ ∷ ∺ ∻ ∽ ∾ ≁ ≃ ≂ ≄ ≅ ≆ ≇ ≈ ≉ ≊ ≋ ≌ ≍ ≎ ≐ ≑ ≒ ≓ ≖ ≗ ≘ ≙ ≚ ≛ ≜ ≝ ≞ ≟ ≣ ≦ ≧ ≨ ≩ ≪ ≫ ≬ ≭ ≮ ≯ ≰ ≱ ≲ ≳ ≴ ≵ ≶ ≷ ≸ ≹ ≺ ≻ ≼ ≽ ≾ ≿ ⊀ ⊁ ⊃ ⊅ ⊇ ⊉ ⊋ ⊏ ⊐ ⊑ ⊒ ⊜ ⊩ ⊬ ⊮ ⊰ ⊱ ⊲ ⊳ ⊴ ⊵ ⊶ ⊷ ⋍ ⋐ ⋑ ⋕ ⋖ ⋗ ⋘ ⋙ ⋚ ⋛ ⋜ ⋝ ⋞ ⋟ ⋠ ⋡ ⋢ ⋣ ⋤ ⋥ ⋦ ⋧ ⋨ ⋩ ⋪ ⋫ ⋬ ⋭ ⋲ ⋳ ⋴ ⋵ ⋶ ⋷ ⋸ ⋹ ⋺ ⋻ ⋼ ⋽ ⋾ ⋿ ⟈ ⟉ ⟒ ⦷ ⧀ ⧁ ⧡ ⧣ ⧤ ⧥ ⩦ ⩧ ⩪ ⩫ ⩬ ⩭ ⩮ ⩯ ⩰ ⩱ ⩲ ⩳ ⩵ ⩶ ⩷ ⩸ ⩹ ⩺ ⩻ ⩼ ⩽ ⩾ ⩿ ⪀ ⪁ ⪂ ⪃ ⪄ ⪅ ⪆ ⪇ ⪈ ⪉ ⪊ ⪋ ⪌ ⪍ ⪎ ⪏ ⪐ ⪑ ⪒ ⪓ ⪔ ⪕ ⪖ ⪗ ⪘ ⪙ ⪚ ⪛ ⪜ ⪝ ⪞ ⪟ ⪠ ⪡ ⪢ ⪣ ⪤ ⪥ ⪦ ⪧ ⪨ ⪩ ⪪ ⪫ ⪬ ⪭ ⪮ ⪯ ⪰ ⪱ ⪲ ⪳ ⪴ ⪵ ⪶ ⪷ ⪸ ⪹ ⪺ ⪻ ⪼ ⪽ ⪾ ⪿ ⫀ ⫁ ⫂ ⫃ ⫄ ⫅ ⫆ ⫇ ⫈ ⫉ ⫊ ⫋ ⫌ ⫍ ⫎ ⫏ ⫐ ⫑ ⫒ ⫓ ⫔ ⫕ ⫖ ⫗ ⫘ ⫙ ⫷ ⫸ ⫹ ⫺ ⊢ ⊣ ⟂ <: >:))))
@@ -110,7 +110,7 @@
110110
; operators that are special forms, not function names
111111
(define syntactic-operators
112112
(append! (add-dots '(= += -= *= /= //= |\\=| ^= ÷= %= <<= >>= >>>= |\|=| &= ⊻=))
113-
'(:= $= && |\|\|| |.| ... ->)))
113+
'(&& |\|\|| := $= |.| ... ->)))
114114
(define syntactic-unary-operators '($ & |::|))
115115

116116
(define syntactic-op? (Set syntactic-operators))
@@ -809,8 +809,8 @@
809809
(else ex))))
810810

811811
(define (parse-arrow s) (parse-RtoL s parse-or is-prec-arrow? (eq? t '-->) parse-arrow))
812-
(define (parse-or s) (parse-RtoL s parse-and is-prec-lazy-or? #t parse-or))
813-
(define (parse-and s) (parse-RtoL s parse-comparison is-prec-lazy-and? #t parse-and))
812+
(define (parse-or s) (parse-RtoL s parse-and is-prec-lazy-or? (eq? t '|\|\||) parse-or))
813+
(define (parse-and s) (parse-RtoL s parse-comparison is-prec-lazy-and? (eq? t '&&) parse-and))
814814

815815
(define (parse-comparison s)
816816
(let loop ((ex (parse-pipe< s))

src/julia-syntax.scm

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,10 +1807,10 @@
18071807
(make-fuse f x)))
18081808
(let ((f (cadr e)))
18091809
(cond ((dotop-named? f)
1810-
(make-fuse- (undotop f) (cddr e)))
1810+
(make-fuse- (unshortcircuit (undotop f)) (cddr e)))
18111811
;; (.+)(a, b) is parsed as (call (|.| +) a b), but we still want it to fuse
18121812
((and (length= f 2) (eq? (car f) '|.|))
1813-
(make-fuse- (cadr f) (cddr e)))
1813+
(make-fuse- (unshortcircuit (cadr f)) (cddr e)))
18141814
(else
18151815
e))))
18161816
e)))
@@ -2055,6 +2055,12 @@
20552055
(cons (car e)
20562056
(map expand-forms (cdr e)))))))
20572057

2058+
(define (unshortcircuit op)
2059+
(case op
2060+
((&&) `(top andand))
2061+
((|\|\||) `(top oror))
2062+
(else op)))
2063+
20582064
;; table mapping expression head to a function expanding that form
20592065
(define expand-table
20602066
(table
@@ -2109,10 +2115,15 @@
21092115
(lambda (e)
21102116
(if (length= e 2)
21112117
;; e = (|.| op)
2112-
`(call (top BroadcastFunction) ,(cadr e))
2118+
`(call (top BroadcastFunction) ,(unshortcircuit (cadr e)))
21132119
;; e = (|.| f x)
21142120
(expand-fuse-broadcast '() e)))
21152121

2122+
'.&&
2123+
(lambda (e) (expand-fuse-broadcast '() `(|.| ,(unshortcircuit (undotop (car e))) (tuple ,@(cdr e)))))
2124+
'|.\|\||
2125+
(lambda (e) (expand-fuse-broadcast '() `(|.| ,(unshortcircuit (undotop (car e))) (tuple ,@(cdr e)))))
2126+
21162127
'.=
21172128
(lambda (e)
21182129
(expand-fuse-broadcast (cadr e) (caddr e)))
@@ -2293,10 +2304,10 @@
22932304
(if (length> e 2)
22942305
(let ((f (cadr e)))
22952306
(cond ((dotop-named? f)
2296-
(expand-fuse-broadcast '() `(|.| ,(undotop f) (tuple ,@(cddr e)))))
2307+
(expand-fuse-broadcast '() `(|.| ,(unshortcircuit (undotop f)) (tuple ,@(cddr e)))))
22972308
;; "(.op)(...)"
22982309
((and (length= f 2) (eq? (car f) '|.|))
2299-
(expand-fuse-broadcast '() `(|.| ,(cadr f) (tuple ,@(cddr e)))))
2310+
(expand-fuse-broadcast '() `(|.| ,(unshortcircuit (cadr f)) (tuple ,@(cddr e)))))
23002311
((eq? f 'ccall)
23012312
(if (not (length> e 4)) (error "too few arguments to ccall"))
23022313
(let* ((cconv (cadddr e))

test/broadcast.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,44 @@ p0 = copy(p)
951951
@test repr(.!) == "Base.Broadcast.BroadcastFunction(!)"
952952
@test eval(:(.+)) == Base.BroadcastFunction(+)
953953

954+
@testset "Issue #5187: Broadcasting of short-circuiting ops" begin
955+
ex = Meta.parse("A .< 1 .|| A .> 2")
956+
@test ex == :((A .< 1) .|| (A .> 2))
957+
@test ex.head == Symbol(".||")
958+
ex = Meta.parse("A .< 1 .&& A .> 2")
959+
@test ex == :((A .< 1) .&& (A .> 2))
960+
@test ex.head == Symbol(".&&")
961+
962+
A = -1:4
963+
@test (A .< 1 .|| A .> 2) == [true, true, false, false, true, true]
964+
@test (A .>= 1 .&& A .<= 2) == [false, false, true, true, false, false]
965+
966+
mutable struct F5187; x; end
967+
(f::F5187)(x) = (f.x += x)
968+
@test (iseven.(1:4) .&& (F5187(0)).(ones(4))) == [false, 1, false, 2]
969+
@test (iseven.(1:4) .|| (F5187(0)).(ones(4))) == [1, true, 2, true]
970+
@test (.&&)(iseven.(1:4), (F5187(0)).(ones(4))) == [false, 1, false, 2]
971+
@test (.||)(iseven.(1:4), (F5187(0)).(ones(4))) == [1, true, 2, true]
972+
r = 1:4; o = ones(4); f = F5187(0);
973+
@test (@. iseven(r) && f(o)) == [false, 1, false, 2]
974+
@test (@. iseven(r) || f(o)) == [3, true, 4, true]
975+
976+
@test (iseven.(1:8) .&& iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
977+
@test (iseven.(1:8) .|| iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
978+
r = 1:8; o = ones(8); f1 = F5187(0); f2 = F5187(0)
979+
@test (@. iseven(r) && iseven(f1(o)) && f2(o)) == [false,false,false,1,false,false,false,2]
980+
@test (@. iseven(r) || iseven(f1(o)) || f2(o)) == [3,true,true,true,4,true,true,true]
981+
@test (.&&)(iseven.(1:8), (.&&)(iseven.((F5187(0)).(ones(8))), (F5187(0)).(ones(8)))) == [false,false,false,1,false,false,false,2]
982+
@test (.||)(iseven.(1:8), (.||)(iseven.((F5187(0)).(ones(8))), (F5187(0)).(ones(8)))) == [1,true,true,true,2,true,true,true]
983+
@test (iseven.(1:8) .&& (.&&)(iseven.((F5187(0)).(ones(8))), (F5187(0)).(ones(8)))) == [false,false,false,1,false,false,false,2]
984+
@test (iseven.(1:8) .|| (.||)(iseven.((F5187(0)).(ones(8))), (F5187(0)).(ones(8)))) == [1,true,true,true,2,true,true,true]
985+
@test (.&&)(iseven.(1:8), iseven.((F5187(0)).(ones(8))) .&& (F5187(0)).(ones(8))) == [false,false,false,1,false,false,false,2]
986+
@test (.||)(iseven.(1:8), iseven.((F5187(0)).(ones(8))) .|| (F5187(0)).(ones(8))) == [1,true,true,true,2,true,true,true]
987+
988+
@test map(.&&, iseven.(1:4), 1:4) == map((x,y)->x&&y, iseven.(1:4), 1:4) == [false, 2, false, 4]
989+
@test map(.||, iseven.(1:4), 1:4) == map((x,y)->x||y, iseven.(1:4), 1:4) == [1, true, 3, true]
990+
end
991+
954992
@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
955993
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
956994
@test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},

0 commit comments

Comments
 (0)