How to avoid memory allocations in custom Julia iterators?

455 views Asked by At

Consider the following Julia "compound" iterator: it merges two iterators, `a` and `b`, each of which is assumed to be sorted according to `order`, into a single ordered sequence:

# Lazy merge of two inputs `a` and `b` that are each already sorted under
# `order`. `T` is the element type of the merged sequence: the promotion of
# the two inputs' element types.
struct MergeSorted{T,A,B,O}
    a::A        # first sorted input
    b::B        # second sorted input
    order::O    # ordering shared by both inputs (default: ascending)

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O}(a, b, order)
    end
end

# Element type of the merged sequence: the promoted type `T` computed by the
# constructor. Matching `<:MergeSorted{T}` lets this method apply whatever the
# remaining type parameters are.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# The merged length is not tracked, so tell `collect` and friends not to call
# `length`. The default `HasLength()` would raise a `MethodError`.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# `iterate` for the merged sequence (type-unstable version).
#
# State: a pair `(a_result, b_result)` holding the pending results of
# `iterate` on `a` and `b`. Each is `nothing` once that input is exhausted,
# otherwise `(current_item, inner_state)`. The default `state` primes both
# inputs on the first call.
#
# Returns `nothing` when both inputs are exhausted. Otherwise it returns the
# next item, converted to the promoted element type `T`, together with the new
# state pair. When `lt(order, a_curr, b_curr)` is false (including ties), the
# item from `b` is emitted first.
#
# NOTE: the returned state pair has a different concrete type depending on
# which branch is taken: `Tuple{Union{Nothing,...}, Tuple{...}}` in one case,
# `Tuple{Tuple{...}, Union{Nothing,...}}` in the other. Inference therefore
# widens the state to `Any`, which is what causes the allocations.
@inline function Base.iterate(self::MergeSorted{T}, 
                      state=(iterate(self.a), iterate(self.b))) where T
    a_result, b_result = state
    if b_result === nothing
        # `b` is exhausted: drain `a`, or stop if it is exhausted too.
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), (iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        # Take from `a` only when its head is strictly smaller under `order`.
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), (iterate(self.a, a_state), b_result)
    end
    # Either `a` is exhausted or its head is not smaller: emit `b`'s head.
    return T(b_curr), (a_result, iterate(self.b, b_state))
end

This code works, but it is type-unstable, because Julia's iteration protocol is inherently so: `iterate` returns either `nothing` or a tuple. In most cases the compiler can work this out automatically. Here, however, it cannot, as the following test code shows that temporaries are allocated:

julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x);
julia> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)

Note the allocation count.

Is there any way to debug such situations efficiently, other than playing around with the code and hoping that the compiler will be able to optimize away the type ambiguities? Does anyone know of a solution in this specific case that does not create temporaries?

1

There is 1 answer

3
Best answer, by Bogumił Kamiński

How to diagnose the problem?

Answer: use @code_warntype

Run:

julia> @code_warntype iterate(x, iterate(x)[2])
Variables
  #self#::Core.Const(iterate)
  self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
  state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Union{}
  @_7::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Tuple{Int64, Int64}
  a_result::Tuple{Int64, Int64}

Body::Tuple{Int64, Any}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
│         (a_result = Core.getfield(%9, 1))
│         (@_7 = Core.getfield(%9, 2))
│   %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
│         (b_result = Core.getfield(%12, 1))
│   %14 = (b_result === Main.nothing)::Core.Const(false)
└──       goto #3 if not %14
2 ─       Core.Const(:(a_result === Main.nothing))
│         Core.Const(:(%16))
│         Core.Const(:(return Main.nothing))
│         Core.Const(:(Base.indexed_iterate(a_result, 1)))
│         Core.Const(:(a_curr = Core.getfield(%19, 1)))
│         Core.Const(:(@_6 = Core.getfield(%19, 2)))
│         Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
│         Core.Const(:(a_state = Core.getfield(%22, 1)))
│         Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
│         Core.Const(:(Base.getproperty(self, :a)))
│         Core.Const(:(Main.iterate(%25, a_state)))
│         Core.Const(:(Core.tuple(%26, b_result)))
│         Core.Const(:(Core.tuple(%24, %27)))
└──       Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (b_curr = Core.getfield(%30, 1))
│         (@_5 = Core.getfield(%30, 2))
│   %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (b_state = Core.getfield(%33, 1))
│   %35 = (a_result !== Main.nothing)::Core.Const(true)
└──       goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%37, 1))
│         (@_4 = Core.getfield(%37, 2))
│   %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%40, 1))
│   %42 = Base.Order::Core.Const(Base.Order)
│   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│   %45 = a_curr::Int64
│   %46 = (%43)(%44, %45, b_curr)::Bool
└──       goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %49 = Base.getproperty(self, :a)::Vector{Int64}
│   %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
│   %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
└──       return %52
6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│   %55 = a_result::Tuple{Int64, Int64}
│   %56 = Base.getproperty(self, :b)::Vector{Int64}
│   %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
│   %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
└──       return %59

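You can see that each branch returns a differently typed state tuple: it depends on which input's pending result may be `nothing`. Because there are too many possible return types, Julia gives up specializing them and just assumes that the second element of the return type is `Any`.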

How to fix the problem?

Answer: reduce the number of possible return types of `iterate`.

Here is a quick write-up. I do not claim it is the tersest possible, and I have not tested it extensively, so there might be a bug; but it was simple enough to write quickly from your code to show how one could approach the problem. Note that I use special branches when one of the collections is empty, since then it should be faster to just iterate the other collection:

# Lazy merge of two inputs `a` and `b` that are each already sorted under
# `order`. The first `iterate` result of each input is computed once here and
# cached (`fa`, `fb`). Their types `F1`/`F2` become type parameters, so
# dispatch can tell empty inputs (`Nothing`) apart at compile time.
# NOTE(review): for stateful iterators this consumes the first element of
# each input at construction time.
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1    # result of `iterate(a)`; `nothing` if `a` is empty
    fb::F2    # result of `iterate(b)`; `nothing` if `b` is empty
    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        first_a = iterate(a)
        first_b = iterate(b)
        T = promote_type(eltype(A), eltype(B))
        return new{T,A,B,O,typeof(first_a),typeof(first_b)}(
            a, b, order, first_a, first_b)
    end
end

# Element type of the merged sequence: the promoted type `T` computed by the
# constructor. `MergeSorted` has six type parameters; matching `<:MergeSorted{T}`
# lets this method apply to any concrete instance type, whatever
# `A,B,O,F1,F2` are.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T

# The merged length is not tracked, so tell `collect` and friends not to call
# `length`. The default `HasLength()` would raise a `MethodError`.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# Iteration state of the merge: the pending `iterate` result of each input.
# A field is `nothing` once its input is exhausted. The fixed type parameters
# keep the state's concrete type the same across every branch.
struct State{Ta,Tb}
    a::Union{Ta,Nothing}    # pending result from the first input
    b::Union{Tb,Nothing}    # pending result from the second input
end

# Both inputs were empty at construction: there is nothing to yield.
Base.iterate(::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing

# Only `b` was empty at construction, so the merge degenerates to iterating `a`
# directly; the state is `a`'s own iteration state. Items are still converted
# to the declared eltype `T`, matching the merged path. For example,
# `MergeSorted([1, 2], Float64[])` yields `Float64`s, consistent with `eltype`.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    first_a = self.fa
    first_a === nothing && return nothing
    x, s = first_a
    return T(x), s
end

function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    res = iterate(self.a, state)
    res === nothing && return nothing
    x, s = res
    return T(x), s
end

# Only `a` was empty at construction, so the merge degenerates to iterating `b`
# directly; the state is `b`'s own iteration state. Items are still converted
# to the declared eltype `T`, matching the merged path. For example,
# `MergeSorted(Float64[], [1, 2])` yields `Float64`s, consistent with `eltype`.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    first_b = self.fb
    first_b === nothing && return nothing
    x, s = first_b
    return T(x), s
end

function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    res = iterate(self.b, state)
    res === nothing && return nothing
    x, s = res
    return T(x), s
end

# Neither input was empty at construction: seed the merge with the cached first
# results of `a` and `b`, then delegate to the two-argument method.
@inline Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2} =
    iterate(self, State{F1,F2}(self.fa, self.fb))

# Merge step used when neither input was empty at construction.
#
# The state is a `State{F1,F2}` whose fields hold the pending `iterate` results
# of `a` and `b` (`nothing` once that input is exhausted). Every branch
# returns a state of this same concrete type, so the return type can be
# inferred as `Union{Nothing, Tuple{T, State{F1,F2}}}`.
#
# Returns `nothing` when both inputs are exhausted. Otherwise it returns the
# next item converted to `T` plus the new state. When `lt(order, a_curr, b_curr)`
# is false (including ties), the item from `b` is emitted first.
#
# NOTE(review): `State{F1,F2}(...)` converts each later `iterate` result to
# `Union{Nothing,F1}` / `Union{Nothing,F2}`. This assumes the non-`nothing`
# result type of `iterate` never changes between steps. For iterators where
# it does change, the result is converted (or conversion throws).
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}, 
    state::State{F1,F2}) where {T,A,B,O,F1,F2}
    a_result, b_result = state.a, state.b

    if b_result === nothing
        # `b` is exhausted: drain `a`, or stop if it is exhausted too.
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        # Take from `a` only when its head is strictly smaller under `order`.
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
    end
    # Either `a` is exhausted or its head is not smaller: emit `b`'s head.
    return T(b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
end

And now you have:

julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);

julia> sum(x)
269

julia> @allocated sum(x)
0

julia> @code_warntype iterate(x, iterate(x)[2])
Variables
  #self#::Core.Const(iterate)
  self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
  @_4::Int64
  @_5::Int64
  @_6::Int64
  b_state::Int64
  b_curr::Int64
  a_state::Int64
  a_curr::Int64
  b_result::Union{Nothing, Tuple{Int64, Int64}}
  a_result::Union{Nothing, Tuple{Int64, Int64}}

Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
1 ─       nothing
│         Core.NewvarNode(:(@_4))
│         Core.NewvarNode(:(@_5))
│         Core.NewvarNode(:(@_6))
│         Core.NewvarNode(:(b_state))
│         Core.NewvarNode(:(b_curr))
│         Core.NewvarNode(:(a_state))
│         Core.NewvarNode(:(a_curr))
│   %9  = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
│   %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
│         (a_result = %9)
│         (b_result = %10)
│   %13 = (b_result === Main.nothing)::Bool
└──       goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└──       goto #4 if not %15
3 ─       return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%18, 1))
│         (@_6 = Core.getfield(%18, 2))
│   %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%21, 1))
│   %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %25 = Base.getproperty(self, :a)::Vector{Int64}
│   %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (b_curr = Core.getfield(%30, 1))
│         (@_5 = Core.getfield(%30, 2))
│   %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (b_state = Core.getfield(%33, 1))
│   %35 = (a_result !== Main.nothing)::Bool
└──       goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (a_curr = Core.getfield(%37, 1))
│         (@_4 = Core.getfield(%37, 2))
│   %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (a_state = Core.getfield(%40, 1))
│   %42 = Base.Order::Core.Const(Base.Order)
│   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│   %45 = a_curr::Int64
│   %46 = (%43)(%44, %45, b_curr)::Bool
└──       goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│   %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %50 = Base.getproperty(self, :a)::Vector{Int64}
│   %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %53
8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│   %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│   %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
│   %58 = Base.getproperty(self, :b)::Vector{Int64}
│   %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│   %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│   %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└──       return %61

EDIT: I have now realized that my implementation is not fully correct. It assumes that the non-`nothing` return value of `iterate` is type-stable, which it does not have to be; and if it is not type-stable, the compiler must allocate anyway. So a fully correct solution would first check whether `iterate` is type-stable. If it is, use my solution; if it is not, use e.g. your solution.