Type stability for lists of closures


I'm trying to design some code in Julia which will take a list of user-supplied functions and essentially apply some algebraic operations to them.

It appears that the return type of calling these functions cannot be inferred when they are closures stored in a list, which makes the code type-unstable according to @code_warntype.

I tried annotating the closures with a return type, but could not find the correct syntax.

Here is an example:

functions = Function[x -> x]

function f(u)
    ret = zeros(eltype(u), length(u))

    for func in functions
        ret .+= func(u)
    end

    ret
end

Run this:

u0 = [1.0, 2.0, 3.0]
@code_warntype f(u0)

and obtain

Body::Array{Float64,1}
1 ─ %1  = (Base.arraylen)(u)::Int64
│   %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Float64,1}, svec(Any, Int64), :(:ccall), 2, Array{Float64,1}, :(%1), :(%1)))::Array{Float64,1}
│   %3  = invoke Base.fill!(%2::Array{Float64,1}, 0.0::Float64)::Array{Float64,1}
│   %4  = Main.functions::Any
│   %5  = (Base.iterate)(%4)::Any
│   %6  = (%5 === nothing)::Bool
│   %7  = (Base.not_int)(%6)::Bool
└──       goto #4 if not %7
2 ┄ %9  = φ (#1 => %5, #3 => %15)::Any
│   %10 = (Core.getfield)(%9, 1)::Any
│   %11 = (Core.getfield)(%9, 2)::Any
│   %12 = (%10)(u)::Any
│   %13 = (Base.broadcasted)(Main.:+, %3, %12)::Any
│         (Base.materialize!)(%3, %13)
│   %15 = (Base.iterate)(%4, %11)::Any
│   %16 = (%15 === nothing)::Bool
│   %17 = (Base.not_int)(%16)::Bool
└──       goto #4 if not %17
3 ─       goto #2
4 ┄       return %3

So, how do I make this code type stable?


There are 2 answers

tholy (best answer)

If you want type stability for arbitrary functions, you'll have to pass them as a tuple, which allows Julia to know in advance which function will be applied at which stage.

function fsequential(u, fs::Fs) where Fs<:Tuple
    ret = similar(u)
    fill!(ret, 0)
    return fsequential!(ret, u, fs...)
end

@inline function fsequential!(ret, u, f::F, fs...) where F
    ret .+= f(u)
    return fsequential!(ret, u, fs...)
end
fsequential!(ret, u) = ret

julia> u0 = [1.0, 2.0, 3.0]
3-element Array{Float64,1}:
 1.0
 2.0
 3.0

julia> fsequential(u0, (identity, x-> x .+ 1))
3-element Array{Float64,1}:
 3.0
 5.0
 7.0

If you inspect this with @code_warntype you'll see it's inferrable.
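As a rough programmatic alternative to reading the @code_warntype output, the sketch below uses `@inferred` from the `Test` standard library, which throws an error when the inferred return type differs from the actual one. The definitions are repeated from above so the snippet runs on its own; the variable name `result` is only illustrative.

```julia
using Test  # standard library; provides @inferred

function fsequential(u, fs::Fs) where Fs<:Tuple
    ret = similar(u)
    fill!(ret, 0)
    return fsequential!(ret, u, fs...)
end

# Peel off one function per call until the varargs are exhausted.
@inline function fsequential!(ret, u, f::F, fs...) where F
    ret .+= f(u)
    return fsequential!(ret, u, fs...)
end
fsequential!(ret, u) = ret

u0 = [1.0, 2.0, 3.0]

# @inferred throws if the call's return type is not inferred concretely.
result = @inferred fsequential(u0, (identity, x -> x .+ 1))
```

If the call were not inferrable, the last line would throw instead of returning the summed vector.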

fsequential! is an example of what is sometimes called "lispy tuple programming" in which you iteratively process one argument at a time until all vararg arguments have been exhausted. It's a powerful paradigm that allows much more flexible inference than a for-loop with an array (because it allows Julia to compile separate code for each "loop iteration"). However, it's generally only useful if the number of elements in the container is fairly small, otherwise you end up with insanely long compile times.

The type parameters F and Fs look unnecessary, but they are designed to force Julia to specialize the code for the particular functions you pass in.

Bogumił Kamiński

There are several layers of issues in your code (unfortunately for type stability):

  1. functions is a non-const global variable, so fundamentally your code will not be type stable.
  2. Even if you moved functions inside the function definition as a vector, the code would still be type unstable, because the container would have an abstract eltype (this would remain true even if you removed the Function prefix before [, as long as the vector held more than one distinct function).
  3. If you changed the vector to a tuple, the collection functions itself would be concretely typed, but f would still be type unstable, because the loop cannot infer the return type of func(u) on each iteration.
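The container-type claims in points 2 and 3 can be checked directly with `isconcretetype`. This is only a small sketch; the names `f1`, `f2`, `fs_vec`, `fs_typed` and `fs_tup` are made up for illustration.

```julia
# Two distinct anonymous functions; each one has its own concrete type.
f1 = x -> x
f2 = x -> 2 .* x

fs_vec   = [f1, f2]        # element type is the join of two different function types
fs_typed = Function[f1]    # element type forced to the abstract type Function
fs_tup   = (f1, f2)        # the tuple type records each function's type separately

vec_is_concrete   = isconcretetype(eltype(fs_vec))
typed_is_concrete = isconcretetype(eltype(fs_typed))
tup_is_concrete   = isconcretetype(typeof(fs_tup))
```

Both vectors report an abstract element type, while the type of the tuple is concrete. A plain loop over the tuple still yields a different element type on each iteration, which is the problem described in point 3.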

The solution would be to use a @generated function that unrolls the loop into a sequence of consecutive applications of func(u); then your code would be type stable.

However, in general I think, assuming that func(u) is an expensive operation, that the type instability in your code should not be very problematic as in the end you convert the return value of func(u) to Float64 anyway.
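If you decide to keep the vector of functions, one possible way to limit the instability is a type assertion at the call site. This is a different technique from the ones proposed in the answers, shown here only as a sketch: the name `f_asserted` is hypothetical, and the code assumes every function returns an array of the same type as `u` (the assertion throws a `TypeError` otherwise).

```julia
# The call func(u) is still dispatched dynamically, but the assertion tells the
# compiler what type comes back, so the broadcast that follows can be inferred.
function f_asserted(u, functions::Vector{Function})
    ret = zeros(eltype(u), length(u))
    for func in functions
        ret .+= func(u)::typeof(u)  # assumption: func returns the same type as u
    end
    return ret
end

u0 = [1.0, 2.0, 3.0]
total = f_asserted(u0, Function[x -> x, x -> 2 .* x])
```

The functions are passed as an argument rather than read from a global, which also avoids the first issue in the list above.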

EDIT: A @generated version, for comparison with what is proposed by Tim Holy.

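@generated function fgenerated(u, functions::Tuple{Vararg{Function}})
    # Inside the generator, `functions` is the tuple *type*, not its value.
    expr = :(ret = zeros(eltype(u), size(u)))
    for fun in functions.parameters
        # `fun.instance` is the singleton instance of the function type;
        # it only exists for functions that capture no variables.
        expr = :($expr; ret .+= $(fun.instance)(u))
    end
    return expr
end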
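The version above relies on `fun.instance`, which is only defined for function types without captured variables. As a hedged variant (not part of the original answer), the sketch below indexes into the tuple value instead, so closures that capture data should work as well. The names `fgenerated_indexed` and `make_shift` are hypothetical.

```julia
# Helper that returns a closure capturing `c`; its type is not a singleton.
make_shift(c) = x -> x .+ c

@generated function fgenerated_indexed(u, functions::Tuple{Vararg{Function}})
    # In the generator, `functions` is the tuple type; fieldcount gives its length.
    body = Expr(:block, :(ret = zeros(eltype(u), size(u))))
    for i in 1:fieldcount(functions)
        # In the generated code, `functions` refers to the runtime tuple value.
        push!(body.args, :(ret .+= functions[$i](u)))
    end
    push!(body.args, :(return ret))
    return body
end

u0 = [1.0, 2.0, 3.0]
shifted = fgenerated_indexed(u0, (identity, make_shift(1.0)))
```

Because each index is a literal constant in the generated code, the compiler can see the concrete type of every `functions[i]`, which is what the unrolling is meant to achieve.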