NumPy/PyTorch dtype conversion / compatibility


I'm trying to find some documentation to understand how dtypes are combined. For example:

x : np.int32 = ...
y : np.float64 = ...
  • What is going to be the type of x + y?
  • Does it depend on the operator (here +)?
  • Does it depend on where the result is stored (z = x + y vs z[...] = x + y)?

I'm looking for the part of the documentation that describes these kinds of scenarios, but so far I've come up empty-handed.

Best answer, by kmario23

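If the dtypes don't match, NumPy upcasts the operands to a common dtype, generally the "larger" of the two (for int32 and float64, that is float64). For the arithmetic operators on these two types the result is float64 whichever operator you use, although some operations have their own result rules: true division / of two integer arrays returns float64, and comparisons return bool. The result dtype also doesn't depend on the variable you assign it to: z = x + y simply binds the float64 result, while z[...] = x + y computes the same float64 result and then casts it into the existing dtype of z. Here is a small illustration: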

In [14]: x = np.arange(3, dtype=np.int32)
In [15]: y = np.arange(3, dtype=np.float64)

# `+` is equivalent to `numpy.add()`
In [16]: summed = x + y

In [17]: summed.dtype
Out[17]: dtype('float64')

In [18]: np.add(x, y).dtype
Out[18]: dtype('float64')
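The promotion rules can also be queried directly with np.result_type and np.promote_types, and the z[...] = x + y case from the question can be checked the same way. A minimal sketch (the names z and w are only for illustration):

```python
import numpy as np

x = np.arange(3, dtype=np.int32)
y = np.arange(3, dtype=np.float64)

# Ask NumPy which dtype the promotion rules pick, without computing anything
print(np.result_type(x, y))                     # float64
print(np.promote_types(np.int32, np.float64))   # float64

# Binding the result to a new name keeps the promoted dtype
z = x + y
print(z.dtype)                                  # float64

# Writing into an existing int32 array: the sum is computed in float64
# and then cast (truncated) into the target's dtype
w = np.zeros(3, dtype=np.int32)
w[...] = x + y + 0.5                            # float64 values [0.5, 2.5, 4.5]
print(w.dtype, w)                               # int32 [0 2 4]
```

Note that item assignment casts silently, so the fractional parts are dropped without any warning.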

If you don't explicitly request a dtype, the result is upcast to the larger of the input dtypes. You can override this: numpy.add(), like other ufuncs, accepts a dtype keyword argument that specifies the dtype of the result.
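A short sketch of the dtype keyword, plus the related out= keyword. With out=, the ufunc checks the cast into the output buffer against its casting rule (default 'same_kind'), which differs from the silent cast done by z[...] = .... The variable names are illustrative:

```python
import numpy as np

x = np.arange(3, dtype=np.int32)
y = np.arange(3, dtype=np.float64)

# Default: int32 + int32 stays int32
print(np.add(x, x).dtype)                  # int32

# dtype= overrides the default and computes in float64 instead
r = np.add(x, x, dtype=np.float64)
print(r.dtype)                             # float64

# With out=, the float64 result must be castable into the buffer under the
# default casting='same_kind' rule; float64 -> int32 is not, so NumPy raises
# a TypeError subclass
out = np.empty(3, dtype=np.int32)
try:
    np.add(x, y, out=out)
except TypeError as exc:
    print("rejected:", type(exc).__name__)

# casting='unsafe' explicitly allows the float64 -> int32 cast
np.add(x, y, out=out, casting="unsafe")
print(out)                                 # [0 2 4]
```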


You can check whether one dtype can be safely cast to another under NumPy's casting rules with numpy.can_cast().
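A few example calls. The default rule is casting='safe', and looser rules ('same_kind', 'unsafe') can be passed explicitly:

```python
import numpy as np

# 'safe' (the default): only casts that preserve all values are allowed
print(np.can_cast(np.int32, np.float64))    # True
print(np.can_cast(np.float64, np.int32))    # False
print(np.can_cast(np.float64, np.float32))  # False

# Looser rules accept more casts
print(np.can_cast(np.float64, np.float32, casting="same_kind"))  # True
print(np.can_cast(np.float64, np.int32, casting="same_kind"))    # False
print(np.can_cast(np.float64, np.int32, casting="unsafe"))       # True
```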

For completeness, here is the full numpy.can_cast() matrix, generated with the following helper:

>>> import numpy as np
>>> def print_casting_matrix(ntypes):
...     # Header row: "X" followed by every type code (the target dtypes)
...     print("X" + ntypes)
...     for row in ntypes:
...         # 1 if dtype `row` can be safely cast to dtype `col`, else 0
...         cells = "".join(str(int(np.can_cast(row, col))) for col in ntypes)
...         print(row + cells)

>>> print_casting_matrix(np.typecodes['All'])

The output is the following matrix. Rows are the dtypes being cast from and columns are the dtypes being cast to; 1 means the row dtype can be safely cast to the column dtype, and 0 means it cannot:

# to casting -----> ----->
X?bhilqpBHILQPefdgFDGSUVOMm
?11111111111111111111111101
b01111110000001111111111101
h00111110000000111111111101
i00011110000000011011111101
l00001110000000011011111101
q00001110000000011011111101
p00001110000000011011111101
B00111111111111111111111101
H00011110111110111111111101
I00001110011110011011111101
L00000000001110011011111101
Q00000000001110011011111101
P00000000001110011011111101
e00000000000001111111111100
f00000000000000111111111100
d00000000000000011011111100
g00000000000000001001111100
F00000000000000000111111100
D00000000000000000011111100
G00000000000000000001111100
S00000000000000000000111100
U00000000000000000000011100
V00000000000000000000001100
O00000000000000000000001100
M00000000000000000000001110
m00000000000000000000001101
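Reading a single cell explains a promotion result that often surprises people. Row i (int32) has a 0 under column f (float32) because float32 cannot represent every int32 exactly, so int32 combined with float32 is promoted to float64. Row h (int16) has a 1 under f, so int16 combined with float32 stays float32. A quick check:

```python
import numpy as np

# int32 -> float32 is not a safe cast (the 0 in row 'i', column 'f') ...
print(np.can_cast("i", "f"))                       # False
# ... so promotion has to go up to float64
print(np.promote_types(np.int32, np.float32))      # float64

# int16 -> float32 is safe (the 1 in row 'h', column 'f'), so float32 suffices
print(np.can_cast("h", "f"))                       # True
print(np.promote_types(np.int16, np.float32))      # float32
```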

Since the type codes are cryptic, the following loop maps each one to its NumPy scalar type, which makes the casting matrix above easier to read:

In [74]: for char in np.typecodes['All']:
    ...:     # np.typeDict is deprecated; np.dtype(code).type gives the scalar type
    ...:     print(char, " --> ", np.dtype(char).type)

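The output was as follows. Some mappings are platform-dependent (for example, l and L are 32-bit on Windows, and g and G are not 128-bit everywhere):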

?  -->  <class 'numpy.bool_'>
b  -->  <class 'numpy.int8'>
h  -->  <class 'numpy.int16'>
i  -->  <class 'numpy.int32'>
l  -->  <class 'numpy.int64'>
q  -->  <class 'numpy.int64'>
p  -->  <class 'numpy.int64'>
B  -->  <class 'numpy.uint8'>
H  -->  <class 'numpy.uint16'>
I  -->  <class 'numpy.uint32'>
L  -->  <class 'numpy.uint64'>
Q  -->  <class 'numpy.uint64'>
P  -->  <class 'numpy.uint64'>
e  -->  <class 'numpy.float16'>
f  -->  <class 'numpy.float32'>
d  -->  <class 'numpy.float64'>
g  -->  <class 'numpy.float128'>
F  -->  <class 'numpy.complex64'>
D  -->  <class 'numpy.complex128'>
G  -->  <class 'numpy.complex256'>
S  -->  <class 'numpy.bytes_'>
U  -->  <class 'numpy.str_'>
V  -->  <class 'numpy.void'>
O  -->  <class 'numpy.object_'>
M  -->  <class 'numpy.datetime64'>
m  -->  <class 'numpy.timedelta64'>
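To go the other way, from a dtype to the one-character code used as a row or column label in the matrix, np.dtype(...).char works, and .name turns a code back into a readable dtype name. A small sketch using codes whose meaning does not vary across platforms:

```python
import numpy as np

# dtype -> type code (the labels used in the casting matrix)
print(np.dtype(np.float64).char)    # d
print(np.dtype(np.float32).char)    # f
print(np.dtype(np.complex64).char)  # F

# type code -> readable dtype name
print(np.dtype("e").name)           # float16
print(np.dtype("B").name)           # uint8
print(np.dtype("h").name)           # int16
```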