Is there a unified syntax for element-wise in-place operations on scalars and arrays in Julia?

Consider the following accumulator type, which works like an array in that you can push things to it, but only tracks its mean:

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator, term)
    acc.data += term       # <-- in-place addition
    acc.count += 1
    acc
end

mean(acc::Accumulator) = acc.data ./ acc.count

I want this to work whether T is a scalar or an array type. However, when T is an array type, the addition in push! creates a temporary. This is because in Julia, x += a is equivalent to x = x + a, and I suspect Julia cannot guarantee that acc.data and term do not alias.
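To make the rebinding concrete, here is a minimal sketch with plain vectors (the names x, a, original, y, and buffer are illustrative only). Since + on arrays returns a newly allocated array, x += a ends up pointing x at a different object, whereas broadcast assignment writes into the existing one:

```julia
x = [1.0, 2.0]
a = [10.0, 20.0]
original = x                 # second reference to the same array

x += a                       # same as x = x + a: a new array is allocated and x is rebound
rebound = x !== original     # x no longer refers to the original array

y = [1.0, 2.0]
buffer = y                   # second reference to the same array
y .+= a                      # broadcast assignment updates the existing array
in_place = y === buffer      # y still refers to the same array
```

After running this, original still holds its old contents, while buffer sees the updated values.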

A simple fix is to replace += with element-wise addition, .+=. However, this breaks scalar types, which cannot be the destination of an element-wise assignment. The only fix I have come up with is to add a specialization of the following form:
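The scalar failure can be reproduced in isolation. The sketch below uses a hypothetical ScalarBox type as a stand-in for an accumulator with scalar data. It only records that an exception is raised, without relying on the exact error message:

```julia
# Hypothetical stand-in for Accumulator{Int64}; not part of the original code.
mutable struct ScalarBox
    data::Int64
end

box = ScalarBox(1)

# Broadcast assignment needs a destination it can write into element by element;
# an Int64 value is immutable, so the update is expected to throw.
failed = try
    box.data .+= 2
    false
catch err
    true
end
```

The field is left unchanged when the update throws, so box.data is still 1 afterwards.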

function Base.push!(acc::Accumulator, term::AbstractArray)
    acc.data .+= term       # <-- element-wise addition
    acc.count += 1
    acc
end

This is, however, somewhat ugly and also brittle. Does anyone know a better way of doing this, preferably in a generic fashion and without creating the temporary?

Answer by Cameron Bieganek (best answer)

Oddly enough, Numbers are iterable in Julia, but that doesn't seem to help us here, because there is no setindex! method for Numbers.

Here are two different approaches. The first uses iterator traits and the second just patches up the method signatures a bit to address corner cases.

Iterator traits

We can use the IteratorSize trait to distinguish between scalars and arrays. For scalars, Base.IteratorSize(x) returns Base.HasShape{0}(). For arrays, Base.IteratorSize(x) returns Base.HasShape{N}(), where N is the number of dimensions of the array. The trait can also be queried on a type rather than a value, which is what the code below does.
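As a quick check of the trait values described above, the following sketch queries Base.IteratorSize on a few representative values and types (the variable names are illustrative only):

```julia
# The trait can be queried on a value or directly on a type.
scalar_trait = Base.IteratorSize(1.5)             # a Number
vector_trait = Base.IteratorSize([1, 2, 3])       # a 1-dimensional array
matrix_trait = Base.IteratorSize(zeros(2, 2))     # a 2-dimensional array
type_trait   = Base.IteratorSize(Vector{Int})     # queried on the type itself

# Not every iterable has a shape: a String reports HasLength instead.
string_trait = Base.IteratorSize("abc")

is_scalar = scalar_trait isa Base.HasShape{0}
is_vector = vector_trait isa Base.HasShape{1}
is_matrix = matrix_trait isa Base.HasShape{2}
```

Note that the trait-based push! shown next only defines methods for HasShape, so a data type whose trait is something else (such as the String case here) would presumably end in a MethodError rather than the ArgumentError.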

mutable struct Accumulator{T}
    data::T
    count::Int64
end

# Look up the traits of the data type and the term type, then dispatch on them.
function Base.push!(acc::Accumulator{T}, term::S) where {T, S}
    _push_acc!(Base.IteratorSize(T), Base.IteratorSize(S), acc, term)
end

# Scalar onto scalar: plain addition, which rebinds the field.
function _push_acc!(::Base.HasShape{0}, ::Base.HasShape{0}, acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Arrays with the same number of dimensions: add in place.
function _push_acc!(::Base.HasShape{N}, ::Base.HasShape{N}, acc::Accumulator, term) where {N}
    acc.data .+= term
    acc.count += 1
    acc
end

# Any other combination of shapes is rejected.
function _push_acc!(::Base.HasShape{M}, ::Base.HasShape{N}, ::Accumulator, ::Any) where {M, N}
    throw(ArgumentError("Accumulator and term have inconsistent shapes"))
end

In action at the REPL:

julia> a = Accumulator(1, 0)
Accumulator{Int64}(1, 0)

julia> b = Accumulator([1, 2], 0)
Accumulator{Array{Int64,1}}([1, 2], 0)

julia> push!(a, 42)
Accumulator{Int64}(43, 1)

julia> push!(b, [3, 4])
Accumulator{Array{Int64,1}}([4, 6], 1)

julia> push!(a, [5, 6])
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{0}, ::Base.HasShape{1}, ::Accumulator{Int64}, ::Array{Int64,1}) at ...
 [2] push!(::Accumulator{Int64}, ::Array{Int64,1}) at ...
 [3] top-level scope at REPL[6]:1

julia> push!(b, 10)
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{1}, ::Base.HasShape{0}, ::Accumulator{Array{Int64,1}}, ::Int64) at ...
 [2] push!(::Accumulator{Array{Int64,1}}, ::Int64) at ...
 [3] top-level scope at REPL[7]:1

Patching the method signatures

Instead of using iterator traits, we could just make a couple of small tweaks to your push! method signatures to prevent pushing an array onto a scalar.

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

Now we get a sensible error message if we try to push an array onto a scalar:

julia> a = Accumulator(42, 0)
Accumulator{Int64}(42, 0)

julia> push!(a, [1, 2])
ERROR: ArgumentError: Can't push an array onto a scalar
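For completeness, here is a self-contained sketch that repeats the three methods above and exercises two cases not shown in the REPL session: the array-onto-array path, checked for in-place behaviour with ===, and a scalar pushed onto an array accumulator. The variable names b, buffer, same_buffer, and scalar_onto_array are illustrative only.

```julia
mutable struct Accumulator{T}
    data::T
    count::Int64
end

# Generic fallback: rebinding addition (covers scalar onto scalar).
function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Array onto array: in-place broadcast addition.
function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

# Array onto anything else: rejected with a descriptive error.
function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

b = Accumulator([1, 2], 0)
buffer = b.data                     # keep a reference to the original array
push!(b, [3, 4])
same_buffer = b.data === buffer     # the in-place path keeps the original array

# A scalar term matches only the generic method, where adding an Int
# to a Vector with plain + is not defined.
scalar_onto_array = try
    push!(b, 10)
    nothing
catch err
    err
end
```

In this version, pushing a scalar onto an array accumulator is not caught by a dedicated method. It surfaces as a MethodError from + instead of an ArgumentError, and the count is left untouched because the addition fails first.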