Changing function arguments to keywords in Julia seems to introduce type instability

388 views Asked by At

I have a program in which the main() function takes four arguments. When I run @code_warntype on the function, nothing seems untoward: all the variables have specified types, and there are no instances of `Union` or other obvious warning signs.

Apologies, the program is rather long but I'm not sure how to shorten it while retaining the problem:

# Simulate `n` turns of a single player moving around the 40-square board below and
# report the most frequently visited squares.
#
# Each turn rolls two `dice`-sided dice:
# - A third consecutive double sends the player to JAIL (and resets the streak).
# - Landing on G2J also sends the player to JAIL.
# - Landing on a CC or CH square draws the top card of that deck. Both decks are shuffled
#   once at the start, and each deck is rotated by one position after every draw.
#
# Arguments
# - `n`: number of turns to simulate.
# - `dice`: number of faces on each die (default 6).
# - `start`: first turn whose final square is counted (default 1). Earlier turns are
#   simulated but not recorded.
# - `modal`: number of most-visited squares to report (default 3, at most 40).
#
# Returns `(modal_string, modal_squares, modal_percents)`, ordered from most to least
# visited (ties in board order):
# - `modal_string`: the zero-based board indices of those squares, each written as two
#   digits, concatenated.
# - `modal_squares`: their names.
# - `modal_percents`: the percentage of counted turns that ended on each.
#
# Throws `ArgumentError` if no turn would be counted (`n < max(start, 1)`) or if
# `modal > 40`.
function main(n::Int, dice::Int=6, start::Int=1,
              modal::Int=3)::Tuple{String, Vector{String}, Vector{Float64}}
    board = String["GO", "A1", "CC1", "A2", "T1", "R1", "B1", "CH1", "B2", "B3",
        "JAIL", "C1", "U1", "C2", "C3", "R2", "D1", "CC2", "D2", "D3",
        "FP", "E1", "CH2", "E2", "E3", "R3", "F1", "F2", "U2", "F3",
        "G2J", "G1", "G2", "CC3", "G3", "R4", "CH3", "H1", "T2", "H2"]
    nsquares = length(board)

    # Validate up front. With no counted turn, every percentage would be 0/0 = NaN.
    # A `modal` larger than the board would otherwise fail later with a BoundsError.
    n >= max(start, 1) || throw(ArgumentError(
        "no turn is counted: need n >= max(start, 1), got n = $n, start = $start"))
    modal <= nsquares || throw(ArgumentError(
        "modal must be at most $nsquares, got $modal"))

    # Resolve every square a card can send the player to once, instead of searching the
    # board on every draw. The predicate form of `findfirst` exists on both Julia 0.6
    # and 1.x.
    square_of(name::String) = findfirst(s -> s == name, board)::Int
    go_sq = square_of("GO")
    jail_sq = square_of("JAIL")
    g2j_sq = square_of("G2J")
    c1_sq = square_of("C1")
    e3_sq = square_of("E3")
    h2_sq = square_of("H2")
    r1_sq = square_of("R1")
    r2_sq = square_of("R2")
    r3_sq = square_of("R3")
    u1_sq = square_of("U1")
    u2_sq = square_of("U2")

    cc_cards = shuffle(collect(1:16))
    ch_cards = shuffle(collect(1:16))

    # Read the top card, then rotate the deck in place (the last card moves to the
    # front), exactly as the former `pop!` + `unshift!` pair did.
    function draw!(cards::Vector{Int})::Int
        card = cards[1]
        insert!(cards, 1, pop!(cards))
        return card
    end

    # Community Chest: cards 1 and 2 move the player; the other 14 leave them in place.
    function take_cc_card(square::Int, cards::Vector{Int})::Int
        card = draw!(cards)
        card == 1 && return go_sq
        card == 2 && return jail_sq
        return square
    end

    # Chance card effects:
    # - Cards 1-6 jump to a fixed square.
    # - Cards 7-8 go to the next railway; card 9 goes to the next utility.
    # - Card 10 goes back three squares.
    # - Cards 11-16 leave the player in place.
    function take_ch_card(square::Int, cards::Vector{Int})::Int
        card = draw!(cards)
        here = board[square]
        if card == 1
            return go_sq
        elseif card == 2
            return jail_sq
        elseif card == 3
            return c1_sq
        elseif card == 4
            return e3_sq
        elseif card == 5
            return h2_sq
        elseif card == 6
            return r1_sq
        elseif card == 7 || card == 8
            here == "CH1" && return r2_sq
            here == "CH2" && return r3_sq
            here == "CH3" && return r1_sq
        elseif card == 9
            here == "CH1" && return u1_sq
            here == "CH2" && return u2_sq
            here == "CH3" && return u1_sq
        elseif card == 10
            # `mod1` keeps the result in 1:nsquares. The previous expression
            # `square - 3 % 40 == 0` parsed as `square - 3 == 0`, because `%` binds
            # tighter than `-`.
            return mod1(square - 3, nsquares)
        end
        return square
    end

    result = zeros(Int, nsquares)
    consec_doubles = 0
    square = go_sq
    for i in 1:n
        # `rand(1:dice)` samples the range directly; no array is allocated per roll.
        throw_1 = rand(1:dice)
        throw_2 = rand(1:dice)
        consec_doubles = throw_1 == throw_2 ? consec_doubles + 1 : 0
        if consec_doubles == 3
            # Third consecutive double: go straight to jail and reset the streak.
            square = jail_sq
            consec_doubles = 0
        else
            # `mod1` wraps past the last square back to 1 (e.g. 41 -> 1).
            square = mod1(square + throw_1 + throw_2, nsquares)
            if square == g2j_sq
                square = jail_sq
            elseif startswith(board[square], "CC")
                square = take_cc_card(square, cc_cards)
            elseif startswith(board[square], "CH")
                square = take_ch_card(square, ch_cards)
                # Going back three squares from CH3 lands on CC3, whose card must be
                # drawn as well.
                if startswith(board[square], "CC")
                    square = take_cc_card(square, cc_cards)
                end
            end
        end
        if i >= start
            result[square] += 1
        end
    end

    # Every counted turn increments exactly one square, so `total` is the number of
    # counted turns (> 0 by the validation above).
    total = sum(result)
    # `sortperm` breaks ties by ascending index, matching the stable sort used before.
    ranked = sortperm(result, rev=true)[1:modal]
    modal_squares = board[ranked]
    modal_percents = Float64[result[k] * 100 / total for k in ranked]
    # Each square contributes its zero-based index as exactly two digits ("00".."39").
    modal_string = join(String[lpad(string(k - 1), 2, "0") for k in ranked])
    return modal_string, modal_squares, modal_percents
end

# Positional-argument version: the report shows no `Union` or other warning signs.
@code_warntype main(1_000_000, 4, 101, 5)

However, when I change the last three arguments to keywords by inserting a semicolon rather than a comma after the first argument...

function main(n::Int; dice::Int=6, start::Int=1, modal::Int=3) ::Tuple{String, Vector{String}, Vector{Float64}}

...I seem to run into type stability problems.

# Keyword-argument version: the report now shows an `Any`-typed temporary and a `Union`.
@code_warntype main(1_000_000, dice=4, start=101, modal=5)

I'm now getting a temporary variable of type `Any` and an instance of `Union` in the main body of the output when I run @code_warntype.

Curiously, this doesn't seem to come with a performance hit: averaged over three benchmark runs, the 'argument' version runs in 431.594 ms and the 'keyword' version runs in 413.149 ms. However, I'm curious to know:

(a) why this is happening;

(b) whether, as a general rule, the appearance of temporary variables of type `Any` is a cause for concern; and

(c) whether, as a general rule, there is any advantage from a performance perspective from using keywords rather than normal function arguments.

1

There is 1 answer

1
Bogumił Kamiński — best answer

Here is my take on the three questions. In this answer I assume Julia 0.6.3, unless I explicitly state that I am referring to Julia 0.7 (at the end of the post).

(a) The code with the `Any` variable is part of the code responsible for handling keyword arguments (e.g. making sure that each passed keyword argument is allowed by the function signature). The reason is that keyword arguments are received inside a function as a `Vector{Any}`. The vector holds tuples ([argument name], [argument value]). The actual "work" the function does happens after this part with the `Any` variable.

You can see this by comparing calls:

# Keyword arguments passed: the report includes the keyword-handling code.
@code_warntype main(1_000_000, dice=4, start=101, modal=5)

and

# No keyword arguments passed: the keyword-handling part of the report disappears.
@code_warntype main(1_000_000)

for the function with keyword arguments. The report for the second call contains only the last line of the report for the first call; all the other lines are responsible for handling the passed keyword arguments.

(b) As a general rule this can of course be a concern, but in this case it cannot be helped. The variable of type `Any` holds information about the name of a keyword argument.

(c) In general you can assume that positional arguments are never slower than keyword arguments, and can be faster. Here is an MWE (if you run @code_warntype f(a=10) you will see this `Any` variable as well):

julia> using BenchmarkTools

julia> f(;a::Int=1) = a+1
f (generic function with 1 method)

julia> g(a::Int=1) = a+1
g (generic function with 2 methods)

julia> @benchmark f()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.974 ns (0.00% GC)
  maximum time:     14.463 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark f(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     52.994 ns (0.00% GC)
  median time:      54.413 ns (0.00% GC)
  mean time:        65.207 ns (10.65% GC)
  maximum time:     3.466 μs (94.78% GC)
  --------------
  samples:          10000
  evals/sample:     986

julia> @benchmark g()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.954 ns (0.00% GC)
  maximum time:     13.062 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g(10)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.949 ns (0.00% GC)
  maximum time:     13.063 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

Now you can see that the keyword-argument penalty is actually paid only when a keyword argument is passed. This is exactly the case in which the `Any` variable appears in @code_warntype, because Julia has to do more work then. Note that the penalty is small, and it will only be visible in functions that do very little work. For functions that do a lot of computation, it can be ignored most of the time.

Additionally, note that if you do not specify the type of a keyword argument, the penalty is much bigger when a value for it is passed explicitly, because Julia does not dispatch on keyword argument types (you can also run @code_warntype to witness this):

julia> h(;a=1) = a+1
h (generic function with 1 method)

julia> @benchmark h()
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.865 ns (0.00% GC)
  median time:      1.866 ns (0.00% GC)
  mean time:        1.960 ns (0.00% GC)
  maximum time:     13.996 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     75.433 ns (0.00% GC)
  median time:      77.355 ns (0.00% GC)
  mean time:        89.037 ns (7.87% GC)
  maximum time:     2.128 μs (89.73% GC)
  --------------
  samples:          10000
  evals/sample:     971

In Julia 0.7, keyword arguments are received as a `Base.Iterators.Pairs` wrapping a `NamedTuple`, so Julia knows the types of the passed arguments at compile time. This means that keyword arguments are faster than in Julia 0.6.3 (but again, you should not expect them to be faster than positional arguments). You can see this by running benchmarks similar to those above under Julia 0.7. I have only changed what the functions do slightly, to give the Julia compiler a bit more work. You can also look at @code_warntype on these functions to see that type inference works better in Julia 0.7:

julia> using BenchmarkTools

julia> f(;a::Int=1) = [a]
f (generic function with 1 method)

julia> g(a::Int=1) = [a]
g (generic function with 2 methods)

julia> h(;a=1) = [a]
h (generic function with 1 method)

julia> @benchmark f()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.576 ns (22.80% GC)
  maximum time:     53.465 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark f(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.739 ns (22.83% GC)
  maximum time:     55.303 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.529 ns (22.77% GC)
  maximum time:     54.501 μs (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark g(10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.523 ns (0.00% GC)
  mean time:        50.899 ns (23.27% GC)
  maximum time:     56.246 μs (99.90% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h()
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.257 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.924 ns (22.87% GC)
  maximum time:     55.724 μs (99.88% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark h(a=10)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     31.724 ns (0.00% GC)
  median time:      34.057 ns (0.00% GC)
  mean time:        50.864 ns (22.60% GC)
  maximum time:     53.389 μs (99.83% GC)
  --------------
  samples:          10000
  evals/sample:     1000