I have a TensorFlow program running in Python, and for convenience I want to run the same program in Java, so I need to save my model and load it in my Java application.
My problem is that I don't know how to save a Tensor object. Here is my code:
class Main:
    def __init__(self, checkpoint):
        ...
        self.g = tf.Graph()
        self.sess = tf.Session()
        self.img_placeholder = tf.placeholder(tf.float32,
                                              shape=(1, 679, 1024, 3),
                                              name='img_placeholder')
        # self.preds is an instance of Tensor
        self.preds = transform(self.img_placeholder)
        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, checkpoint)

    def ffwd(...):
        ...
        _preds = self.sess.run(self.preds,
                               feed_dict={self.img_placeholder: self.X})
        ...
Since I can't create my Tensor myself (the transform function builds the neural network behind the scenes), I have to save it and reload it in Java. I have found ways of saving the session, but not Tensor instances.
Could someone give me some insight into how to achieve this?
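For context, here is a minimal sketch of one direction that might work, assuming the TensorFlow 1.x graph API (imported here through `tensorflow.compat.v1`). As far as I understand, a Tensor is only a named node in the graph, so it may be enough to export the graph and its variables together and then look the tensor up by name. The `transform` below is a hypothetical stand-in for my real network, the `tf.identity(..., name='preds')` wrapper and the export directory are assumptions, and the variables are initialized instead of restored from the checkpoint.

```python
import os
import tempfile

# Assumption: TF 1.x-style graph API, reachable as tf.compat.v1 in TF 1.14+ and 2.x.
import tensorflow.compat.v1 as tf


def transform(img):
    # Hypothetical stand-in for the real network: one trainable scale factor.
    scale = tf.Variable(1.0, name='scale')
    # Give the output an explicit name so it can be fetched by name later.
    return tf.identity(img * scale, name='preds')


# Assumption: a fresh sub-directory, since the export directory must not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), 'export')

with tf.Graph().as_default() as graph:
    img_placeholder = tf.placeholder(tf.float32,
                                     shape=(1, 679, 1024, 3),
                                     name='img_placeholder')
    preds = transform(img_placeholder)

    with tf.Session(graph=graph) as sess:
        # In the real program this would be saver.restore(sess, checkpoint).
        sess.run(tf.global_variables_initializer())

        # Writes the graph definition and the variable values as a SavedModel
        # tagged "serve".
        tf.saved_model.simple_save(
            sess,
            export_dir,
            inputs={'img_placeholder': img_placeholder},
            outputs={'preds': preds})

# Tensor names that would be used to feed and fetch from another runtime.
print(img_placeholder.name, preds.name)
```

If this is the right idea, the Java side would presumably load the directory with `SavedModelBundle.load(exportDir, "serve")` and then feed `img_placeholder` and fetch `preds` by name through the session runner, rather than loading a Tensor object as such.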