Consider the following Julia "compound" iterator: it merges two iterators, a and b, each of which is assumed to be sorted according to order, into a single ordered sequence:
    struct MergeSorted{T,A,B,O}
        a::A
        b::B
        order::O

        MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O} =
            new{promote_type(eltype(A),eltype(B)),A,B,O}(a, b, order)
    end

    Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O} = T

    @inline function Base.iterate(self::MergeSorted{T},
                                  state=(iterate(self.a), iterate(self.b))) where T
        a_result, b_result = state
        if b_result === nothing
            a_result === nothing && return nothing
            a_curr, a_state = a_result
            return T(a_curr), (iterate(self.a, a_state), b_result)
        end

        b_curr, b_state = b_result
        if a_result !== nothing
            a_curr, a_state = a_result
            Base.Order.lt(self.order, a_curr, b_curr) &&
                return T(a_curr), (iterate(self.a, a_state), b_result)
        end
        return T(b_curr), (a_result, iterate(self.b, b_state))
    end
This code works, but it is type-unstable, because Julia's iteration protocol inherently is: iterate returns either nothing or a (value, state) tuple. In most cases the compiler can optimize this away automatically, but here it does not. The following test code shows that temporaries are created:

    >>> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
    >>> sum(x);
    >>> @time sum(x);
    0.000013 seconds (61 allocations: 2.312 KiB)

Note the allocation count.
Is there any way to efficiently debug such situations, other than playing around with the code and hoping that the compiler will be able to optimize out the type ambiguities? Does anyone know of a solution for this specific case that does not create temporaries?
How to diagnose the problem?
Answer: use @code_warntype. Run it on the iterate call and you will see that there are too many possible return types, so Julia gives up specializing on them (and just assumes the second element of the return type is Any).
How to fix the problem?
Answer: reduce the number of return type options of iterate. Here is a quick write-up (I do not claim it is the most terse, and I have not tested it extensively, so there might be a bug, but it was simple enough to write quickly using your code to show how one could approach your problem; note that I use special branches when one of the collections is empty, as then it should be faster to just iterate one collection):
And now you have:
EDIT: I have now realized that my implementation is not fully correct, as it assumes that the return value of iterate, if it is not nothing, is type stable (which it does not have to be). But if it is not type stable, then the compiler must allocate. So a fully correct solution would first check whether iterate is type stable. If it is, use my solution; if it is not, use e.g. your solution.
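To make the idea of narrowing iterate's return types concrete while sidestepping the caveat in the EDIT, here is a minimal sketch. It is not the code from this answer: it swaps the generic wrapped iterators for index-based access, so it assumes both inputs are sorted AbstractVectors with Int indices. The name MergeSortedVectors and the helper count_allocs are hypothetical. Because the state is always a Tuple{Int,Int}, iterate can only return nothing or one tuple type.

```julia
# Hypothetical type for illustration; assumes a and b are sorted AbstractVectors.
struct MergeSortedVectors{T,A<:AbstractVector,B<:AbstractVector,O<:Base.Order.Ordering}
    a::A
    b::B
    order::O
end

# Outer constructor: the element type T is computed from the inputs.
function MergeSortedVectors(a::AbstractVector, b::AbstractVector,
                            order::Base.Order.Ordering=Base.Order.Forward)
    T = promote_type(eltype(a), eltype(b))
    return MergeSortedVectors{T,typeof(a),typeof(b),typeof(order)}(a, b, order)
end

Base.eltype(::Type{<:MergeSortedVectors{T}}) where {T} = T
Base.length(m::MergeSortedVectors) = length(m.a) + length(m.b)

# The state is a pair of indices, so its type never changes between calls.
function Base.iterate(m::MergeSortedVectors{T},
                      state::Tuple{Int,Int}=(firstindex(m.a), firstindex(m.b))) where {T}
    i, j = state
    a_left = i <= lastindex(m.a)
    b_left = j <= lastindex(m.b)
    if a_left && (!b_left || Base.Order.lt(m.order, m.a[i], m.b[j]))
        # a is strictly smaller, or b is exhausted: take from a
        return T(m.a[i]), (i + 1, j)
    elseif b_left
        # ties go to b, matching the original MergeSorted
        return T(m.b[j]), (i, j + 1)
    else
        return nothing
    end
end

# Measure allocations inside a function to avoid global-scope effects.
count_allocs(m) = @allocated sum(m)

m = MergeSortedVectors([1, 4, 5, 9, 32, 44], [0, 7, 9, 24, 134])
merged = collect(m)
total = sum(m)
count_allocs(m)            # first call may include compilation
println(count_allocs(m))   # allocation count of a warmed-up sum
```

The trade-off is generality: this version cannot wrap arbitrary iterators such as generators. Running @code_warntype iterate(m) on it is one way to check whether the inferred return type stays a small union.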