Getting a Tensor's value in Java

3.6k views Asked by At

I have trained a recurrent neural network with TensorFlow using Python. I saved the model and restored it in a Java application. This is working. Now I feed my input tensors to the pretrained model and fetch the output. My problem now is that the output is a Tensor and I don't know how to get the Tensor's value (it is a simple integer tensor of shape 1).

The python code looks like this:

# Build the classification graph at module level. The statement order matters:
# ops without an explicit name get auto-generated names in creation order.
# accuracy.eval() is later called without an explicit session, which relies on
# the InteractiveSession being the default session.
sess = tf.InteractiveSession()

# Input sequences, fed as [batch, n_steps, n_inputs]; explicitly named so that
# other runtimes can address it as "input_x".
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs], name="input_x")
# Integer class labels, one per batch entry. No name is given, so the op gets
# TensorFlow's default name ("Placeholder" for the first unnamed placeholder).
y = tf.placeholder(tf.int32, [ None])

# NOTE(review): keep_prob is declared but not used by any op in the visible code.
keep_prob = tf.placeholder(tf.float32, name="keep_prob")

# Basic RNN cell whose per-step output is projected to n_outputs units.
basic_cell = tf.contrib.rnn.OutputProjectionWrapper(tf.contrib.rnn.BasicRNNCell(num_units=n_neurons),output_size=n_outputs)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

# Class scores are computed from the final RNN state (not from `outputs`).
logits = tf.layers.dense(states, n_outputs, name="logits")
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)

# Mean cross-entropy over the batch, minimized with Adam.
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
# True for each batch entry whose label is the top-1 logit.
correct = tf.nn.in_top_k(logits, y,1, name="correct")
# Index of the largest logit per batch entry. The graph op is named
# "prediction"; `pred` is only the Python variable. tf.argmax returns int64
# unless output_type is given.
pred = tf.argmax(logits, 1, name="prediction")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))


# NOTE(review): train_and_save_rnn creates and runs its own initializer op, so
# this `init` is not used in the visible code.
init = tf.global_variables_initializer()







def train_and_save_rnn():
    """Train the module-level RNN graph and export it as a SavedModel.

    Prints the Saver's filename-tensor and restore-op names, trains on
    "Train.csv" with one sample per step for ``n_epochs`` epochs, prints the
    train/test accuracy after every epoch, and finally writes the model
    through ``builder``.

    Relies on names defined outside this function: ``sess``, ``X``, ``y``,
    ``training_op``, ``accuracy``, ``builder``, ``n_steps``, ``n_inputs``,
    ``n_epochs``, ``load_training_data`` and ``reshape_data``.
    """
    # The SaverDef carries the tensor/op names another runtime needs in order
    # to restore the variables, so print them for reference.
    saver_def = tf.train.Saver().as_saver_def()
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)

    train_raw, train_labels = load_training_data("Train.csv")
    test_raw, test_labels = load_training_data("Test.csv")

    # NOTE: labels are used as loaded; they are not reshaped into sequences.
    train_samples = reshape_data(train_raw)
    test_inputs = reshape_data(test_raw).reshape((-1, n_steps, n_inputs))
    # Labels are shifted down by one before being fed
    # (presumably 1-based in the CSV files -- TODO confirm).
    test_targets = test_labels - 1

    sess.run(tf.global_variables_initializer())

    for epoch in range(n_epochs):
        # One training step per sample; the range stops at shape[0] - 1, so
        # the last row of the training set is never visited.
        for step in range(train_samples.shape[0] - 1):
            raw_x = train_samples[step]
            raw_y = train_labels[step] - 1
            sample_x = raw_x.reshape((-1, n_steps, n_inputs))
            sample_y = raw_y.reshape(1)
            sess.run(training_op, feed_dict={X: sample_x, y: sample_y})

        # "Train accuracy" is measured on the last sample of the epoch only.
        last_sample_feed = {X: sample_x, y: sample_y}
        held_out_feed = {X: test_inputs, y: test_targets}
        acc_train = accuracy.eval(feed_dict=last_sample_feed)
        acc_test = accuracy.eval(feed_dict=held_out_feed)
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

    # Export graph and variables under the SERVING tag; save(True) writes the
    # human-readable text form of the model.
    serving_tags = [tf.saved_model.tag_constants.SERVING]
    builder.add_meta_graph_and_variables(sess, serving_tags)
    builder.save(True)

What I do in Java is:

            // NOTE(review): graphDef is read here but never used by the visible code below;
            // SavedModelBundle.load reads the export directory on its own.
            byte[] graphDef = readAllBytesOrExit(Paths.get(IMPORT_DIRECTORY, "/saved_model.pbtxt"));
    /*List<String> labels =
            readAllLinesOrExit(Paths.get(IMPORT_DIRECTORY, "trained_model.txt"));
            */

        // Load the model exported under the "serve" tag. try-with-resources closes the bundle
        // when this block exits, which releases the bundle's Graph and Session as well, so
        // s and g must not be used after this block.
        try (SavedModelBundle b = SavedModelBundle.load(IMPORT_DIRECTORY, "serve")) {
            // create the session from the Bundle
            Session sess = b.session();
            // s and g are declared outside this snippet; runTensor reads them.
            s = sess;
            g = b.graph();
            // This is just a sample Tensor for debugging:
            // a 3-D float array with one outer and one middle element, matching the
            // three-dimensional "input_x" placeholder.
            // NOTE(review): t is never closed in the visible code.
            Tensor t = Tensor.create(new float[][][] {{{(float)0.8231331,(float)-5.2657013,(float)-1.1111984,(float)0.0074825287,(float)0.075252056,(float)0.07835889,(float)-0.035752058,(float)-0.035610847,(float)0.045247793,(float)1.5594741,(float)57.78549,(float)-0.21489286,(float)0.011989355,(float)0.15965772,(float)13.370155,(float)3.4708557,(float)3.7776794,(float)-1.1115816,(float)0.72939104,(float)-0.44342846,(float)11.001129,(float)10.549805,(float)-50.719162,(float)-0.8261242,(float)0.71805984,(float)-0.1849739,(float)9.334606,(float)3.0003967,(float)-52.456577,(float)-0.1875816,(float)0.19306469,(float)0.004947722,(float)5.4054375,(float)-0.8630371,(float)-24.599575,(float)1.3387873,(float)-1.1488495,(float)-2.8362968,(float)22.174248,(float)-32.095154,(float)10.069847}}});
            runTensor(t);

        }


 /**
  * Runs the restored model on {@code inputTensor} and prints the predicted class index(es).
  *
  * <p>Uses the shared {@code Session} held in {@code s}. The session is not closed here;
  * it stays owned by whoever created it. Any exception raised while running the graph is
  * caught and its stack trace printed.
  *
  * @param inputTensor float tensor fed to the "input_x" placeholder; the caller keeps
  *                    ownership and is responsible for closing it
  */
 public static void runTensor(Tensor inputTensor) throws IOException, FileNotFoundException {

    // s (and the graph g) belong to the SavedModelBundle that produced them. Declaring
    // them as try-with-resources variables would release them here and again when the
    // bundle itself is closed, so only a plain reference is taken.
    Session sess = s;

    // "prediction" is the graph name of the argmax op (name="prediction" in the Python
    // code); "pred" is only the Python variable name and does not exist in the graph.
    // The label placeholder is not fed: the prediction depends only on the input.
    // The fetched tensor is closed by try-with-resources to free its native memory.
    try (Tensor result = sess.runner()
                .feed("input_x", inputTensor)
                .fetch("prediction")
                .run().get(0)) {
        System.out.println(result);

        // tf.argmax produces INT64 by default, so the values are read as long, not int.
        if (result.numDimensions() == 0) {
            // Scalar result: read it directly.
            System.out.println(result.longValue());
        } else {
            // Expected case: one predicted class per batch entry (a rank-1 tensor).
            long[] predictions = new long[result.numElements()];
            result.copyTo(predictions);
            System.out.println(java.util.Arrays.toString(predictions));
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}

The output should (I'm not sure if it's working) be an integer between 0 and 10 for the predicted class. How can I extract the integer in Java from the Tensor? Thank you in advance.

1

There is 1 answer

0
ash On

Use Tensor.intValue() if it is a scalar, or Tensor.copyTo() if it is not. (So System.out.println(result.intValue());)