I need to create a variable epsilon_n that changes definition (and value) based on the current step. Since I have more than two cases, it seems that I can't use tf.cond. I am trying to use tf.case as follows:
import tensorflow as tf
####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)
def fn1(step):
    return tf.constant([1.])
def fn2(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
    pred_fn_pairs=[
        (tf.less(step, 3e4), lambda step: fn1(step)),
        (tf.less(step, 6e4), lambda step: fn2(step)),
        (tf.less(step, 1e5), lambda step: fn3(step))],
    default=lambda: tf.constant([1e5]),
    exclusive=False)
However, I keep getting this error message:
TypeError: <lambda>() missing 1 required positional argument: 'step'
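The message is raised by plain Python rather than by TensorFlow itself: tf.case invokes each branch callable with no arguments, so a lambda that declares a step parameter can never be satisfied. A minimal sketch of the mechanism, where call_without_args is a hypothetical helper standing in for the way tf.case calls its branches:

```python
def call_without_args(fn):
    # Hypothetical stand-in for tf.case: branch callables are invoked with no arguments.
    return fn()


step = 4e4

# Declares a parameter named `step`, so calling it with no arguments raises TypeError.
broken = lambda step: step * 2

# Takes no parameters; `step` is captured from the enclosing scope (a closure).
fixed = lambda: step * 2

error_message = None
try:
    call_without_args(broken)
except TypeError as exc:
    error_message = str(exc)

fixed_result = call_without_args(fixed)
print(error_message)
print(fixed_result)
```

The broken lambda reproduces the same TypeError as in the question, while the zero-argument version returns a value because step is looked up in the surrounding scope.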
I tried the following:
epsilon_n = tf.case(
    pred_fn_pairs=[
        (tf.less(step, 3e4), fn1),
        (tf.less(step, 6e4), fn2),
        (tf.less(step, 1e5), fn3)],
    default=lambda: tf.constant([1e5]),
    exclusive=False)
I still get the same error. The examples in the TensorFlow documentation only cover cases where no input argument is passed to the callables. I couldn't find much information about tf.case online. Any help would be appreciated.
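Passing fn1 directly fails for the same reason: tf.case calls it as fn1() without step. A common fix is to keep step out of the lambda's parameter list and let the lambda close over it, e.g. (tf.less(step, 3e4), lambda: fn1(step)). The sketch below checks the intended schedule in plain Python, assuming a hypothetical case function that mimics tf.case(..., exclusive=False) (the first true predicate wins, otherwise the default runs); it is not TensorFlow's implementation. As a separate, hedged note: fn2 and fn3 wrap a tensor expression in tf.constant, which in graph mode expects a concrete value, so returning 1. + step * EPSILON_DELTA_PHASE1 directly (for example through tf.reshape(..., [1]) if a shape-[1] result is wanted) will likely also be needed.

```python
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5


def case(pred_fn_pairs, default):
    # Hypothetical pure-Python stand-in for tf.case(..., exclusive=False):
    # return the result of the first branch whose predicate is true,
    # otherwise the result of the default branch. All callables take no arguments.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    return default()


def fn1(step):
    return 1.0


def fn2(step):
    return 1.0 + step * EPSILON_DELTA_PHASE1


def fn3(step):
    return 1.0 + step * EPSILON_DELTA_PHASE2


def epsilon_n(step):
    # Each branch is a zero-argument lambda that closes over `step`,
    # instead of a lambda that expects `step` to be passed in.
    return case(
        [
            (step < 3e4, lambda: fn1(step)),
            (step < 6e4, lambda: fn2(step)),
            (step < 1e5, lambda: fn3(step)),
        ],
        default=lambda: 1e5,
    )


for s in (1e4, 4e4, 8e4, 2e5):
    print(s, epsilon_n(s))
```

Because the predicates are checked in order, a step of 4e4 fails the first test and falls into the second phase, which is why exclusive=False suits these cascading thresholds.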
Here are a few changes you need to make. For consistency, you can make all the return values variables.