Type-stability in Julia's product iterator


I am trying to make A in the following code type-stable.

using Primes: factor

function f(n::T, p::T, k::T) where {T<:Integer}
    return rand(T, n * p^k)
end

function g(m::T, n::T) where {T<:Integer}
    i = 0
    for A in Iterators.product((f(n, p, T(k)) for (p, k) in factor(m))...)
        i = sum(A)
    end
    return i
end

Note that f is type-stable. The variable A is not, because the product iterator returns tuples whose length depends on the number of distinct prime factors of m, which is only known at run time. If there were an iterator like the product iterator that returned a Vector instead of a Tuple, I believe the type instability would go away.

Does anyone have any suggestions to make A type-stable in the above code?

Edit: I should add that f returns a variable-length Vector{T}.

One way I have solved the type instability is this:

function g(m::T, n::T) where {T<:Integer}
    B = Vector{T}[T[]]
    for (p, k) in factor(m)
        C = Vector{T}[]
        for (b, r) in Iterators.product(B, f(n, p, T(k)))
            c = copy(b)
            push!(c, r)
            push!(C, c)
        end
        B = C
    end

    i = 0  # must be defined outside the loop, otherwise `return i` throws UndefVarError
    for A in B
        i = sum(A)
    end

    return i
end

This (and in particular, A) is now type-stable, but at the cost of a lot of memory, since B stores every combination before the final loop runs. I'm not sure of a better way to do this.
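A possible lower-memory alternative, as a minimal sketch: walk the combinations with an odometer-style index vector and reuse a single Vector as the buffer, so the element handed to the loop body is always a Vector{T}, however many factors there are. The name foreach_product is hypothetical (it is not part of Base or Primes), and the callback must not keep a reference to the buffer, because it is overwritten on every step.

```julia
# Hypothetical helper (not from Base or the original post): calls `fn` on every
# combination of one element from each vector in `vs`, reusing a single buffer.
function foreach_product(fn, vs::Vector{Vector{T}}) where {T}
    # If any vector is empty, the whole product is empty.
    any(isempty, vs) && return nothing
    idx = ones(Int, length(vs))   # current position in each vector
    buf = T[v[1] for v in vs]     # current combination, overwritten in place
    while true
        fn(buf)
        # Advance like an odometer; the first vector varies fastest,
        # which matches the order of Iterators.product.
        d = 1
        while d <= length(vs)
            if idx[d] < length(vs[d])
                idx[d] += 1
                buf[d] = vs[d][idx[d]]
                break
            end
            idx[d] = 1            # this position wrapped around, carry to the next
            buf[d] = vs[d][1]
            d += 1
        end
        d > length(vs) && return nothing   # every position wrapped around: done
    end
end

# Usage with plain vectors standing in for the outputs of f:
total = Ref(0)
foreach_product([[1, 2], [10, 20], [100]]) do A
    total[] += sum(A)   # A is always a Vector{Int} here
end
```

In g, the input could be built with a typed comprehension such as Vector{T}[f(n, p, T(k)) for (p, k) in factor(m)]. Only the index vector and one buffer are allocated besides the inputs; this variant has not been benchmarked against the versions above.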

1 answer

Answer by DNF (best answer)

It's not easy to get this completely type stable, but you can isolate the type instability with a function barrier. Convert the factorization to a tuple in an outer function, then pass that tuple to an inner function. Because the tuple's length is part of its type, the inner function is type stable. This gives just one dynamic dispatch, instead of many:

# inner, type stable
function _g(n, tup)
    i = 0
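    # oftype(n, k) converts the Int exponent to n's type, since f requires all three arguments to share one type
    for A in Iterators.product((f(n, p, oftype(n, k)) for (p, k) in tup)...)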
        i += sum(A)  # or i = sum(A), whatever
    end
    return i
end

# outer function
g(m::T, n::T) where {T<:Integer} = _g(n, Tuple(factor(m)))

Some benchmarks:

julia> @btime g(7, 210);  # OP version
  149.600 μs (7356 allocations: 172.62 KiB)

julia> @btime g(7, 210);  # my version
  1.140 μs (6 allocations: 11.91 KiB)

You should expect to hit compilation occasionally: each time m has a number of distinct prime factors that hasn't come up before, the tuple has a new length, and _g has to be compiled for that new tuple type.
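The reason is that a tuple's length is part of its type, so the inner function gets one specialization per length. A small sketch of this, using plain prime => exponent pairs as a stand-in for Tuple(factor(m)) so that it runs without Primes; ntuple_length is a hypothetical helper for illustration only:

```julia
# Stand-ins for Tuple(factor(m)): each entry is a prime => exponent pair.
one_factor  = Tuple([7 => 1])           # like m = 7
two_factors = Tuple([2 => 1, 3 => 1])   # like m = 6
also_two    = Tuple([2 => 2, 5 => 1])   # like m = 20

# Hypothetical helper: N is read from the tuple's type, not from its values,
# so the compiler knows it when specializing a method on the tuple.
ntuple_length(::NTuple{N,Any}) where {N} = N

# Different lengths give different types (a new specialization is needed) ...
different = typeof(one_factor) != typeof(two_factors)
# ... while the same length gives the same type (the compiled method is reused).
same = typeof(two_factors) == typeof(also_two)
```

So numbers like 6 and 20 share one compiled method, while 7 needs another.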