I have trained a recurrent neural network with TensorFlow in Python. I saved the model and restored it in a Java application, and this works. Now I feed my input tensors to the pretrained model and fetch the output. My problem is that the output is a Tensor, and I don't know how to get the Tensor's value (it is a simple integer tensor of shape [1]).
The Python code looks like this:
# Build the TF1 graph at module level, on one shared interactive session.
# NOTE(review): n_steps, n_inputs, n_neurons, n_outputs and learning_rate are
# not defined in this snippet; presumably set earlier in the file — confirm.
sess = tf.InteractiveSession()
# Input sequences; explicitly named "input_x" so a client can feed it by name.
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs], name="input_x")
# Integer class labels (one per example). No name is given, so TF assigns the
# default op name "Placeholder"; anything feeding it by name depends on this
# being the first unnamed placeholder created.
y = tf.placeholder(tf.int32, [ None])
# NOTE(review): keep_prob is not used by any op in this snippet (no dropout
# is applied) — confirm whether it is still needed.
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
# Basic RNN cell whose per-step outputs are projected to n_outputs units.
basic_cell = tf.contrib.rnn.OutputProjectionWrapper(tf.contrib.rnn.BasicRNNCell(num_units=n_neurons),output_size=n_outputs)
# Unroll the cell over the time axis of X; `states` is the final state.
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
# Class scores are computed from the final state (not from `outputs`).
logits = tf.layers.dense(states, n_outputs, name="logits")
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
# Per-example boolean: is the true label the top-1 prediction?
correct = tf.nn.in_top_k(logits, y,1, name="correct")
# Predicted class index per example (tf.argmax returns int64 by default).
# Clients must fetch this op by its name, "prediction".
pred = tf.argmax(logits, 1, name="prediction")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
# Created after minimize(), so it also initializes the Adam slot variables.
init = tf.global_variables_initializer()
def train_and_save_rnn():
    """Train the RNN one sample at a time, report accuracy, and export it.

    Relies on module-level graph objects (``X``, ``y``, ``training_op``,
    ``accuracy``, ``init``), the shared ``sess``, hyper-parameters
    (``n_epochs``, ``n_steps``, ``n_inputs``) and the helpers
    ``load_training_data`` / ``reshape_data``, all defined elsewhere.
    """
    # The Saver is only used to print the filename-tensor / restore-op names
    # that a checkpoint-based restore would need.
    saver = tf.train.Saver()
    saver_def = saver.as_saver_def()
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)

    data_train, labels_train = load_training_data("Train.csv")
    data_test, labels_test = load_training_data("Test.csv")
    dt_train = reshape_data(data_train)
    dt_test = reshape_data(data_test)

    X_test = dt_test.reshape((-1, n_steps, n_inputs))
    # Labels are shifted down by one to give 0-based class ids.
    y_test = labels_test - 1

    # Reuse the initializer built with the graph rather than adding a second
    # initializer op that would be exported along with the model.
    sess.run(init)

    for epoch in range(n_epochs):
        # One sample per training step. Iterate over *every* sample: the old
        # `range(dt_train.shape[0] - 1)` silently skipped the last one.
        for sample, label in zip(dt_train, labels_train):
            X_batch = sample.reshape((-1, n_steps, n_inputs))
            y_batch = (label - 1).reshape((1))
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Train accuracy is measured on the last sample fed this epoch.
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

    # NOTE(review): `builder` is not created in this function; presumably a
    # module-level tf.saved_model.builder.SavedModelBuilder — confirm.
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
    # as_text=True writes a human-readable saved_model.pbtxt.
    builder.save(as_text=True)
What I do in Java is:
// NOTE(review): graphDef is not used below, because SavedModelBundle.load reads the
// export directory itself. It is kept only because readAllBytesOrExit presumably
// aborts when saved_model.pbtxt is missing; confirm, then drop it.
byte[] graphDef = readAllBytesOrExit(Paths.get(IMPORT_DIRECTORY, "/saved_model.pbtxt"));
try (SavedModelBundle b = SavedModelBundle.load(IMPORT_DIRECTORY, "serve")) {
    // The bundle owns its Session and Graph and releases both when this try
    // block exits, so the statics s and g are only valid inside it.
    Session sess = b.session();
    s = sess;
    g = b.graph();
    // Debug input: a single sample of shape [1][1][41]. The tensor is closed
    // after use to release its native memory.
    try (Tensor t = Tensor.create(new float[][][] {{{(float)0.8231331,(float)-5.2657013,(float)-1.1111984,(float)0.0074825287,(float)0.075252056,(float)0.07835889,(float)-0.035752058,(float)-0.035610847,(float)0.045247793,(float)1.5594741,(float)57.78549,(float)-0.21489286,(float)0.011989355,(float)0.15965772,(float)13.370155,(float)3.4708557,(float)3.7776794,(float)-1.1115816,(float)0.72939104,(float)-0.44342846,(float)11.001129,(float)10.549805,(float)-50.719162,(float)-0.8261242,(float)0.71805984,(float)-0.1849739,(float)9.334606,(float)3.0003967,(float)-52.456577,(float)-0.1875816,(float)0.19306469,(float)0.004947722,(float)5.4054375,(float)-0.8630371,(float)-24.599575,(float)1.3387873,(float)-1.1488495,(float)-2.8362968,(float)22.174248,(float)-32.095154,(float)10.069847}}})) {
        runTensor(t);
    }
}
/**
 * Runs the restored model on one input and prints the predicted class index.
 *
 * <p>Uses the static session {@code s} taken from the SavedModelBundle. It is
 * deliberately not closed here: the bundle owns it and releases it on close.
 * Closing it here would make every call after the first one fail.
 *
 * @param inputTensor float input fed to the "input_x" placeholder; owned and
 *     closed by the caller
 */
public static void runTensor(Tensor inputTensor) throws IOException, FileNotFoundException {
    // "prediction" (argmax of the logits) depends only on input_x, so the
    // unnamed label placeholder does not need to be fed for inference.
    try (Tensor result = s.runner()
            .feed("input_x", inputTensor)
            .fetch("prediction") // the op was created with name="prediction", not "pred"
            .run()
            .get(0)) {
        // tf.argmax yields an int64 tensor with one entry per input row (shape [1]
        // here), so copy it into a long[] instead of calling intValue().
        long[] predictions = new long[(int) result.shape()[0]];
        result.copyTo(predictions);
        for (long predicted : predictions) {
            int gesture = (int) predicted;
            System.out.println(gesture);
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
The output should be an integer between 0 and 10 for the predicted class (I'm not sure whether it is working). How can I extract the integer from the Tensor in Java? Thank you in advance.
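Use `Tensor.intValue()` if the tensor is a scalar, or `Tensor.copyTo()` if it is not (so, for a scalar: `System.out.println(result.intValue());`).

Note that in this case the tensor is not an int32 scalar. `tf.argmax` returns an `int64` tensor by default, and here it has shape `[1]`, so `intValue()` will throw. Copy it into a `long[]` instead: `long[] out = new long[1]; result.copyTo(out); int gesture = (int) out[0];`.

Also, the operation was created with `name="prediction"`, so fetch `"prediction"`, not `"pred"`.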