I have a dataset with 5 labels, and I extract each image's label from its file path with the following code:
def get_label(
    file_path,
    class_names=('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'),
):
    """Return the integer class index encoded in *file_path*'s parent directory.

    The class is taken from the second-to-last path component, i.e. the
    directory that contains the image file.

    Args:
        file_path: Path of one image file. Presumably a scalar string tensor;
            ``parts[-2]`` below only indexes path components for a single path.
        class_names: Sequence of class-directory names. The position of a name
            in this sequence is the label returned for it. Pass e.g.
            ``('dog', 'cat')`` for a two-class dataset.

    Returns:
        The index of the matching entry in ``class_names``, as produced by
        ``tf.argmax`` (default output dtype int64).

    NOTE(review): if the directory name matches no entry in ``class_names``,
    the one-hot vector is all zeros and the returned index is meaningless.
    Make sure every class directory is listed.
    """
    # Convert the path to a list of path components.
    parts = tf.strings.split(file_path, os.path.sep)
    # Element-wise comparison against every class name yields a boolean
    # one-hot vector. tf.equal is used explicitly so the result does not
    # depend on how Tensor.__eq__ behaves in the current TF mode.
    one_hot = tf.equal(parts[-2], list(class_names))
    # tf.argmax does not accept bool tensors, so cast to an integer dtype
    # before taking the index of the single True entry.
    return tf.argmax(tf.cast(one_hot, tf.int32))
def decode_img(img):
    """Turn compressed JPEG bytes into a resized 3-channel image tensor.

    Args:
        img: The raw, still-compressed file contents (as returned by
            ``tf.io.read_file``).

    Returns:
        The decoded image resized to ``[img_height, img_width]``. Those two
        names are module-level values defined outside this function.

    NOTE(review): tf.image.resize generally returns a float tensor. Confirm
    that downstream code expects floats rather than uint8.
    """
    # Decode to a 3-channel (RGB) uint8 tensor.
    decoded = tf.image.decode_jpeg(img, channels=3)
    # Scale to the configured target size.
    return tf.image.resize(decoded, [img_height, img_width])
def process_path(file_path):
    """Map one image file path to an ``(image, label)`` pair.

    Intended as the per-element function for ``Dataset.map``. The label comes
    from ``get_label``; the image is read from disk and passed through
    ``decode_img``.
    """
    label = get_label(file_path)
    # Read the file's raw bytes, then decode/resize them into an image tensor.
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes), label
# Replace each file path in the dataset with its (image, label) pair; the
# parallelism level is taken from the module-level AUTOTUNE value.
train_ds = train_ds.map(map_func=process_path, num_parallel_calls=AUTOTUNE)
If I adapt this code to another dataset with 2 labels, `class_names = ['dog', 'cat']`, I get this error:
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64
How can I update `get_label(file_path)` to fix this?
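My guess is that `tf.argmax` requires one of the data types listed in the error message, and `bool` is not among them (I can't test this right now). The `==` comparison evaluates to a tensor of True/False values, which is not allowed. So all you need to do is convert the output of `parts[-2] == class_names` to an integer type before calling `tf.argmax`, e.g. `return tf.argmax(tf.cast(one_hot, tf.int32))`.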