Numba type changes when indexing


I'm facing a strange problem. The following code (part of a function)

from numba import njit

@njit
def treedist(treedists, An, Bn, w, M, Theta):
    print(An)
    print(Bn)
    print(An[1])
    print(Bn[1])

prints the following:

[(0.0, 1), (1.0, 18.071077087009371), (0.0, 0)]
[(0.0, 1), (1.0, 25.897262991223062), (0.0, 0)]
(1.0, 18)
(1.0, 25)

For some reason the float64 second element of those tuples gets converted to an int64. Can anyone tell me why this is happening?
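Here is a self-contained version that reproduces it (I've stripped the function down to just the prints):

from numba import njit

# Minimal reproduction: a list whose tuples mix int and float second elements
@njit
def show(An):
    print(An)
    print(An[1])  # the 18.07... second element comes back truncated to 18

An = [(0.0, 1), (1.0, 18.071077087009371), (0.0, 0)]
show(An)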

Thanks!

1 Answer

JoshAdel (accepted answer):

I believe the problem is the following: Numba can only handle lists whose elements all share one constant type, so it inspects the first element of your list and decides every element is of type (float64, int64). You can see this if you look at:

treedist.inspect_types()

or

treedist.inspect_llvm()

after running the function. It then carries that type assumption forward for every element. If you change all of the tuples to have a consistent type:

An = [(0.0, 1.0), (1.0, 18.071077087009371), (0.0, 0.0)]

You won't get the cast to int when you print An[1].
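To see the fix in isolation, here is a stripped-down sketch (peek is a stand-in for your treedist, whose full body isn't shown):

from numba import njit

@njit
def peek(An):
    print(An[1])

# Every tuple is now (float64, float64), so one element type fits the whole list
An = [(0.0, 1.0), (1.0, 18.071077087009371), (0.0, 0.0)]
peek(An)              # prints (1.0, 18.071077087009371): the float survives
peek.inspect_types()  # the annotated source shows An typed as a list of (float64 x 2) tuples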

If you have a list of inconsistently typed items, Numba is going to fail (and unfortunately, here it fails silently). See the docs, which state that lists must be strictly homogeneous:

http://numba.pydata.org/numba-doc/0.29.0/reference/pysupported.html#list
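One defensive option (a sketch of my own, not something Numba requires) is to coerce the pairs on the Python side before calling the jitted function, so the inferred element type is always (float64, float64):

# Hypothetical helper: normalize mixed int/float pairs before they reach the JIT
def as_float_pairs(pairs):
    return [(float(a), float(b)) for a, b in pairs]

An = as_float_pairs([(0.0, 1), (1.0, 18.071077087009371), (0.0, 0)])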

The fact that it isn't "rejecting" your tuples outright may be because it doesn't properly handle compound objects that don't follow that convention.