TensorFlow: Can't use tf.case with input argument


I need to create a variable epsilon_n whose definition (and value) changes based on the current step. Since I have more than two cases, it seems I can't use tf.cond. I am trying to use tf.case as follows:

import tensorflow as tf

####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)


def fn1(step):
    return tf.constant([1.])

def fn2(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1])

def fn3(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2])

epsilon_n = tf.case(
        pred_fn_pairs=[
            (tf.less(step, 3e4), lambda step: fn1(step)),
            (tf.less(step, 6e4), lambda step: fn2(step)),
            (tf.less(step, 1e5), lambda step: fn3(step))],
        default=lambda: tf.constant([1e5]),
        exclusive=False)

However, I keep getting this error message:

TypeError: <lambda>() missing 1 required positional argument: 'step'
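The message comes from plain Python, not from TensorFlow: tf.case calls each branch callable with no arguments, so a lambda that declares a `step` parameter cannot be satisfied. The sketch below reproduces this without TensorFlow. `call_like_tf_case` is a hypothetical stand-in that only mimics the zero-argument call, not the real tf.case.

```python
def call_like_tf_case(fn):
    """Hypothetical stand-in: invoke a branch callable with no arguments."""
    return fn()

# A branch that declares a parameter fails when called with zero arguments.
error_message = None
try:
    call_like_tf_case(lambda step: step + 1)
except TypeError as err:
    error_message = str(err)

# A zero-argument lambda that closes over `step` works instead.
step = 5
result = call_like_tf_case(lambda: step + 1)

print(error_message)  # <lambda>() missing 1 required positional argument: 'step'
print(result)         # 6
```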

I tried the following:

epsilon_n = tf.case(
        pred_fn_pairs=[
            (tf.less(step, 3e4), fn1),
            (tf.less(step, 6e4), fn2),
            (tf.less(step, 1e5), fn3)],
        default=lambda: tf.constant([1e5]),
        exclusive=False)

Still, I get the same error. The examples in the TensorFlow documentation only cover callables that take no input argument, and I couldn't find much more about tf.case online. Any help would be appreciated.
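The second attempt fails for the same reason: fn1, fn2 and fn3 still declare a `step` parameter, and tf.case passes none. Besides wrapping each function in a zero-argument lambda, `functools.partial` can pre-bind the argument. The pure-Python sketch below shows both behaviours. `call_like_tf_case` is again a hypothetical stand-in for the zero-argument call, and fn2 uses ordinary float arithmetic in place of the tensor expression. Since tf.case only requires a callable, the same `functools.partial(fn2, step)` object should be usable as a branch in real TensorFlow code.

```python
import functools

EPSILON_DELTA_PHASE1 = 33e-4

def call_like_tf_case(fn):
    """Hypothetical stand-in: invoke a branch callable with no arguments."""
    return fn()

def fn2(step):
    # Plain float arithmetic standing in for the tensor expression.
    return 1. + step * EPSILON_DELTA_PHASE1

step = 4e4

# Passing fn2 itself fails: it still expects `step`.
try:
    call_like_tf_case(fn2)
    direct_failed = False
except TypeError:
    direct_failed = True

# functools.partial binds `step` ahead of time, giving a zero-argument callable.
bound_fn2 = functools.partial(fn2, step)
value = call_like_tf_case(bound_fn2)
print(direct_failed, value)
```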

Accepted answer by Ishant Mrinal

Here are a few changes you need to make. For consistency, you can also return a variable from every branch.

# Since step is a scalar, a scalar shape (() or [], not None) must be provided
step = tf.placeholder(dtype=tf.float32, shape=())


def fn1(step):
    return tf.constant([1.])

# Here you need a Variable, not a constant, since the value is computed from the placeholder
def fn2(step):
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1])

def fn3(step):
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2])

epsilon_n = tf.case(
    pred_fn_pairs=[
        (tf.less(step, 3e4), lambda: fn1(step)),
        (tf.less(step, 6e4), lambda: fn2(step)),
        (tf.less(step, 1e5), lambda: fn3(step))],
    default=lambda: tf.constant([1e5]),
    exclusive=False)
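For reference, a self-contained version that can be run end to end is sketched below. It assumes TensorFlow 2.x with the TF1-style graph built through `tf.compat.v1` (under TF 1.x the `compat.v1` prefixes and the `disable_eager_execution()` call can be dropped), and the fed step values are arbitrary test inputs. One deliberate departure from the answer: instead of `tf.Variable`, the branches return ordinary tensors reshaped to `[1]`, so every branch matches the shape of the default and no variable is created inside a conditional branch.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x running the TF1-style graph via tf.compat.v1.
tf.compat.v1.disable_eager_execution()

EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5

# Scalar placeholder, so tf.less(step, ...) yields scalar predicates.
step = tf.compat.v1.placeholder(dtype=tf.float32, shape=())

def fn1(step):
    return tf.constant([1.])

def fn2(step):
    # Ordinary tensor arithmetic, reshaped to [1] to match the other branches.
    return tf.reshape(1. + step * EPSILON_DELTA_PHASE1, [1])

def fn3(step):
    return tf.reshape(1. + step * EPSILON_DELTA_PHASE2, [1])

epsilon_n = tf.case(
    pred_fn_pairs=[
        (tf.less(step, 3e4), lambda: fn1(step)),
        (tf.less(step, 6e4), lambda: fn2(step)),
        (tf.less(step, 1e5), lambda: fn3(step))],
    default=lambda: tf.constant([1e5]),
    exclusive=False)

results = {}
with tf.compat.v1.Session() as sess:
    for s in [0., 4e4, 8e4, 2e5]:
        # Each run returns a length-1 array; keep its single value.
        results[s] = float(sess.run(epsilon_n, feed_dict={step: s})[0])

print(results)
```

With `exclusive=False`, tf.case selects the first predicate that is true, so a step of 4e4 falls through to `fn2`, and any step at or above 1e5 gets the default.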