I'm playing with the Dataset API in TensorFlow v1.3. It's great.
It is possible to map a dataset with a function, as described here. I would like to know how I can pass a function that takes an additional argument, for example arg1:
import tensorflow as tf

def _parse_function(example_proto, arg1):
    # arg1 is the extra argument I would like to pass in
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
                # parse_single_example supports only tf.string, tf.int64 and tf.float32
                "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features["image"], parsed_features["label"]
Of course,

dataset = dataset.map(_parse_function)

will not work, since there is no way to pass in arg1.
You can also use a partial function (functools.partial) instead to wrap your parameter. The parameter order of your function is changed to fit the partial application (the extra argument comes first), and then you can wrap your function with a parameter value like the following:
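As a minimal sketch of the reordering idea, independent of TensorFlow: the function name parse_with_arg, the value "my_arg" and the sample records are hypothetical, and Python's built-in map stands in for Dataset.map because both call the wrapped function once per element with a single argument.

```python
from functools import partial

# Hypothetical stand-in for _parse_function: the extra argument (arg1)
# is moved to the FRONT so that partial() can bind it positionally,
# leaving the dataset element as the only remaining argument.
def parse_with_arg(arg1, example):
    # The real version would call tf.parse_single_example here;
    # this sketch just pairs each element with arg1 to show it arrives.
    return (example, arg1)

# Bind arg1 once; the result is a one-argument callable.
parse_fn = partial(parse_with_arg, "my_arg")

# The built-in map calls parse_fn(element) for each element,
# analogous to how Dataset.map calls its function per element.
records = ["rec0", "rec1", "rec2"]
results = list(map(parse_fn, records))
print(results)  # [('rec0', 'my_arg'), ('rec1', 'my_arg'), ('rec2', 'my_arg')]
```

In plain Python, binding by keyword, e.g. partial(_parse_function, arg1=value), would also produce a one-argument callable without reordering the parameters; the positional version above simply follows the approach described in the answer.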