
I have a TensorFlow program running in Python, and for convenience reasons I want to run the same program in Java, so I have to save my model and load it in my Java application.

My problem is that I don't know how to save a Tensor object. Here is my code:

import tensorflow as tf  # TensorFlow 1.x API

class Main:
    def __init__(self, checkpoint):
        ...
        self.g = tf.Graph()
        self.sess = tf.Session()

        self.img_placeholder = tf.placeholder(tf.float32,
                                              shape=(1, 679, 1024, 3),
                                              name='img_placeholder')

        # self.preds is an instance of Tensor
        self.preds = transform(self.img_placeholder)

        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, checkpoint)

    def ffwd(...):
        ...
        _preds = self.sess.run(self.preds,
                               feed_dict={self.img_placeholder: self.X})
        ...

Since I can't create my Tensor myself (the transform function builds the neural network behind the scenes), I have to save it and reload it in Java. I have found ways of saving the session, but not Tensor instances.

Could someone give me some insight into how to achieve this?

  • You may want to look at TensorFlow Lite; it uses a different model format, and that should be the only way to load the model in Java. tensorflow.org/images/tflite-architecture.jpg?hl=zh-cn Commented Feb 26, 2018 at 2:47
  • That seems to be a way of saving models for Android devices, which is not what I'm looking for. I already know how to save the session in Python and load it in Java; what I can't save/load is the Tensor instance (used as a fetch when running the session). Commented Feb 26, 2018 at 10:06

1 Answer


Python Tensor objects are symbolic references to a specific output of an operation in the graph.

An operation in a graph can be uniquely identified by its string name. A specific output of that operation is identified by an integer index into the list of outputs of that operation. That index is typically zero, since the vast majority of operations produce a single output.
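To make this identification scheme concrete, here is a minimal pure-Python sketch. `GRAPH`, `TensorRef` and `resolve` are hypothetical names used only for illustration; they model the idea, not TensorFlow's actual `tf.Graph`/`tf.Tensor` classes. The one TensorFlow convention borrowed here is that a tensor's name is spelled `<op name>:<output index>`.

```python
from typing import Dict, NamedTuple

# Hypothetical toy graph: operation name -> number of outputs it produces.
# Illustration only; these are not TensorFlow classes.
GRAPH: Dict[str, int] = {
    "img_placeholder": 1,
    "foo": 2,  # hypothetical op with two outputs
}


class TensorRef(NamedTuple):
    """A symbolic reference: which operation, and which of its outputs."""
    op_name: str
    value_index: int = 0

    @property
    def name(self) -> str:
        # TensorFlow spells tensor names as "<op name>:<output index>".
        return f"{self.op_name}:{self.value_index}"


def resolve(ref: TensorRef, graph: Dict[str, int]) -> TensorRef:
    """Check that the referenced op exists and has the requested output."""
    if ref.op_name not in graph:
        raise KeyError(f"no operation named {ref.op_name!r}")
    if not 0 <= ref.value_index < graph[ref.op_name]:
        raise IndexError(f"{ref.op_name!r} has no output {ref.value_index}")
    return ref


print(resolve(TensorRef("img_placeholder"), GRAPH).name)  # img_placeholder:0
print(resolve(TensorRef("foo", 1), GRAPH).name)           # foo:1
```

The point of the sketch is that the pair (op name, output index) is all that is needed to find a tensor again; nothing about the Python object itself has to be saved.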

To obtain the name of the Operation and the output index referred to by a Tensor object in Python, you could do something like:

print(preds.op.name)
print(preds.value_index)  # Most likely will be 0
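In TensorFlow 1.x, `preds.name` also encodes both pieces as a single string such as `"foo:1"`. If that is what you have at hand, a small helper can split it back into the two values a fetch-by-name call needs. `split_tensor_name` is a hypothetical helper, not part of the TensorFlow API; it is a sketch assuming op names never contain a colon.

```python
from typing import Tuple


def split_tensor_name(tensor_name: str) -> Tuple[str, int]:
    """Split "op_name:index" into (op_name, index); a bare name means output 0."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        # No ":" suffix, e.g. "img_placeholder" -> output 0.
        return tensor_name, 0
    return op_name, int(index)


# These pairs correspond to the (operation, index) arguments of a fetch by name.
print(split_tensor_name("foo:1"))            # ('foo', 1)
print(split_tensor_name("img_placeholder"))  # ('img_placeholder', 0)
```

Using `rpartition` splits on the last colon, so scoped op names like `"scope/op:2"` still come out as `("scope/op", 2)`.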

And then in Java, you can feed/fetch nodes by name. Let's say preds.op.name returned the string foo, and preds.value_index returned the integer 1, then in Java, you'd do the following:

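// imgTensor: an org.tensorflow.Tensor holding the input image
List<Tensor<?>> outputs = session.runner().feed("img_placeholder", imgTensor).fetch("foo", 1).run();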

(See the Javadoc for org.tensorflow.Session.Runner for details.)

You may find the slides linked to in https://github.com/tensorflow/models/tree/master/samples/languages/java along with the speaker notes in those slides useful.

Hope that helps.
