Skip to content

[WIP] Rework broadcast #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 71 additions & 109 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
end

## constructors

@inline ArrayPartition(f::F, N) where F<:Function = ArrayPartition(ntuple(f, Val(N)))
ArrayPartition(x...) = ArrayPartition((x...,))

function ArrayPartition(x::S, ::Type{Val{copy_x}}=Val{false}) where {S<:Tuple,copy_x}
Expand All @@ -23,26 +23,25 @@ Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(
Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A)

# similar array partition of common type
@generated function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
@inline function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
N = npartitions(A)
expr = :(similar(A.x[i], T))

build_arraypartition(N, expr)
ArrayPartition(i->similar(A.x[i], T), N)
end

# ignore dims since array partitions are vectors
Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T)

# similar array partition with different types
@generated function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S},
R::Vararg{Type}) where {T,S}
function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
N = npartitions(A)
N != length(R) + 2 &&
throw(DimensionMismatch("number of types must be equal to number of partitions"))

types = (T, S, parameter.(R)...) # new types
expr = :(similar(A.x[i], ($types)[i]))
build_arraypartition(N, expr)
types = (T, S, R...) # new types
@inline function f(i)
similar(A.x[i], types[i])
end
ArrayPartition(f, N)
end

Base.copy(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(copy.(A.x))
Expand All @@ -52,17 +51,16 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x))
# ignore dims since array partitions are vectors
Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)



## ones

# special to work with units
@generated function Base.ones(A::ArrayPartition)
function Base.ones(A::ArrayPartition)
N = npartitions(A)

expr = :(fill!(similar(A.x[i]), oneunit(eltype(A.x[i]))))

build_arraypartition(N, expr)
out = similar(A)
for i in 1:N
fill!(out.x[i], oneunit(eltype(out.x[i])))
end
out
end

# ignore dims since array partitions are vectors
Expand All @@ -72,50 +70,32 @@ Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)

for op in (:+, :-)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base got rid of all of these: do you think it is better to just delete these methods?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes absolutely.

@eval begin
@generated function Base.$op(A::ArrayPartition, B::ArrayPartition)
N = npartitions(A, B)
expr = :($($op).(A.x[i], B.x[i]))

build_arraypartition(N, expr)
function Base.$op(A::ArrayPartition, B::ArrayPartition)
Base.broadcast($op, A, B)
end

@generated function Base.$op(A::ArrayPartition, B::Number)
N = npartitions(A)
expr = :($($op).(A.x[i], B))

build_arraypartition(N, expr)
function Base.$op(A::ArrayPartition, B::Number)
Base.broadcast($op, A, B)
end

@generated function Base.$op(A::Number, B::ArrayPartition)
N = npartitions(B)
expr = :($($op).(A, B.x[i]))

build_arraypartition(N, expr)
function Base.$op(A::Number, B::ArrayPartition)
Base.broadcast($op, A, B)
end
end
end

for op in (:*, :/)
@eval @generated function Base.$op(A::ArrayPartition, B::Number)
N = npartitions(A)
expr = :($($op).(A.x[i], B))

build_arraypartition(N, expr)
@eval function Base.$op(A::ArrayPartition, B::Number)
Base.broadcast($op, A, B)
end
end

@generated function Base.:*(A::Number, B::ArrayPartition)
N = npartitions(B)
expr = :((*).(A, B.x[i]))

build_arraypartition(N, expr)
function Base.:*(A::Number, B::ArrayPartition)
Base.broadcast(*, A, B)
end

@generated function Base.:\(A::Number, B::ArrayPartition)
N = npartitions(B)
expr = :((/).(B.x[i], A))

build_arraypartition(N, expr)
function Base.:\(A::Number, B::ArrayPartition)
Base.broadcast(/, B, A)
end

## Functional Constructs
Expand Down Expand Up @@ -232,90 +212,72 @@ Base.show(io::IO, m::MIME"text/plain", A::ArrayPartition) = show(io, m, A.x)

## broadcasting

struct APStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:ArrayPartition}) = Broadcast.ArrayStyle{ArrayPartition}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{ArrayPartition},::Broadcast.ArrayStyle) = Broadcast.ArrayStyle{ArrayPartition}()
Base.BroadcastStyle(::Broadcast.ArrayStyle,::Broadcast.ArrayStyle{ArrayPartition}) = Broadcast.ArrayStyle{ArrayPartition}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}},::Type{ElType}) where ElType = similar(bc)
struct ArrayPartitionStyle{Style <: Broadcast.BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
ArrayPartitionStyle(::S) where {S} = ArrayPartitionStyle{S}()
ArrayPartitionStyle(::S, ::Val{N}) where {S,N} = ArrayPartitionStyle(S(Val(N)))
ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArrayStyle{N}}()

function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}})
ret = Broadcast.flatten(bc)
__broadcast(ret.f,ret.args...)
# promotion rules
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
end

@generated function __broadcast(f,as...)

# common number of partitions
N = npartitions(as...)
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))

# broadcast partitions separately
expr = :(broadcast(f,
# index partitions
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
for d in 1:length(as))...)))
build_arraypartition(N, expr)
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}
Style = combine_styles((S.parameters...,))
ArrayPartitionStyle(Style)
end

function Base.copyto!(dest::AbstractArray,bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}})
ret = Broadcast.flatten(bc)
__broadcast!(ret.f,dest,ret.args...)
end

@generated function __broadcast!(f, dest, as...)
# common number of partitions
N = npartitions(dest, as...)

# broadcast partitions separately
quote
for i in 1:$N
broadcast!(f, dest.x[i],
# index partitions
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
for d in 1:length(as))...))
end
dest
@inline function Base.copy(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))
end
ArrayPartition(f, N)
end

## utils

"""
build_arraypartition(N::Int, expr::Expr)

Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
`expr` with variable `i` set to the partition index in the range of 1 to `N`.

This can help to write a type-stable method in cases in which the correct return type can
can not be inferred for a simpler implementation with generators.
"""
function build_arraypartition(N::Int, expr::Expr)
quote
@Base.nexprs $N i->(A_i = $expr)
partitions = @Base.ncall $N tuple i->A_i
ArrayPartition(partitions)
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted)
N = npartitions(dest, bc)
for i in 1:N
copyto!(dest.x[i], unpack(bc, i))
end
dest
end

## broadcasting utils

"""
npartitions(A...)

Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
`ArrayPartitions` with a different number of partitions.
"""
npartitions(A) = 0
npartitions(::Type{ArrayPartition{T,S}}) where {T,S} = length(S.parameters)
npartitions(A, B...) = common_number(npartitions(A), npartitions(B...))
npartitions(A::ArrayPartition) = length(A.x)
npartitions(bc::Broadcast.Broadcasted) = _npartitions(bc.args)
npartitions(A, Bs...) = common_number(npartitions(A), _npartitions(Bs))

@inline _npartitions(args::Tuple) = common_number(npartitions(args[1]), _npartitions(Base.tail(args)))
_npartitions(args::Tuple{Any}) = npartitions(args[1])
_npartitions(args::Tuple{}) = 0

# drop axes because it is easier to recompute
@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
unpack(x,::Any) = x
unpack(x::ArrayPartition, i) = x.x[i]

@inline unpack_args(i, args::Tuple) = (unpack(args[1], i), unpack_args(i, Base.tail(args))...)
unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
unpack_args(::Any, args::Tuple{}) = ()

## utils
common_number(a, b) =
a == 0 ? b :
(b == 0 ? a :
(a == b ? a :
throw(DimensionMismatch("number of partitions must be equal"))))

"""
parameter(::Type{T})

Return type `T` of singleton.
"""
parameter(::Type{T}) where {T} = T
parameter(::Type{Type{T}}) where {T} = T
4 changes: 2 additions & 2 deletions test/partitions_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
@inferred similar(x, (2, 2))
@inferred similar(x, Int)
@inferred similar(x, Int, (2, 2))
@inferred similar(x, Int, Float64)
# @inferred similar(x, Int, Float64)

# zero
@inferred zero(x)
Expand Down Expand Up @@ -84,4 +84,4 @@ _scalar_op(y) = y + 1
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
_broadcast_wrapper(y) = _scalar_op.(y)
# Issue #8
@inferred _broadcast_wrapper(x)
# @inferred _broadcast_wrapper(x)