
I have a fine-tuned VGG model that I created with the tensorflow.keras functional API and saved using tf.contrib.saved_model.save_keras_model. The model is saved with this structure: an assets folder (which contains the saved_model.json file), a saved_model.pb file, and a variables folder, which contains checkpoint, variables.data-00000-of-00001, and variables.index.

I can easily load my model in Python and get predictions using tf.contrib.saved_model.load_keras_model(saved_model_path), but I have no idea how to load the model in Java. I googled a lot and found this How to export Keras .h5 to tensorflow .pb? for exporting a .pb file, and then tried loading it following this link: Loading in Java. I was not able to freeze the graph. I also tried simple_save, but tensorflow.keras does not support it (AttributeError: module 'tensorflow.contrib.saved_model' has no attribute 'simple_save'). Can someone help me figure out what steps are needed to load my model (a tensorflow.keras functional API model) in Java?

Is the saved_model.pb file that I have good enough to be loaded on the Java side? Do I need to create my own input/output placeholders? If so, how can I export them?
I appreciate your help.

  • You can use TensorFlow Lite instead: tensorflow.org/lite – Commented Dec 12, 2018 at 15:49

1 Answer


If you have a model saved in the SavedModel format (which it appears you do, and things like tf.contrib.saved_model.save_keras_model can help create), then in Java you can use SavedModelBundle.load to load and serve it. You do not need to "freeze" the model.

You may find the following useful:

But the basic idea is that your code will look something like:

try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
  try (Tensor<?> input = makeInputTensor();
       Tensor<?> output = model.session().runner()
                               .feed("INPUT_TENSOR", input)
                               .fetch("OUTPUT_TENSOR")
                               .run()
                               .get(0)) {
    // Use output
  }
}

Where "INPUT_TENSOR" and "OUTPUT_TENSOR" are the names of the input and output tensors in the TensorFlow graph. The saved_model_cli command-line tool, installed when you install TensorFlow for Python, can show you these names: for example, saved_model_cli show --dir <directory> --all lists the inputs and outputs of each signature in your model.
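As for makeInputTensor(): it is a placeholder for code you write yourself to build the input Tensor. A minimal sketch follows; the [1, 224, 224, 3] VGG-style shape and the 0–1 pixel scaling are assumptions here, so check your model's actual input shape and preprocessing with saved_model_cli:

```java
// Sketch: pack interleaved RGB bytes into the nested float array that
// the TensorFlow Java API's Tensors.create(...) accepts for a 4-D tensor.
// Assumed shape: [1, 224, 224, 3] (batch, height, width, RGB channels).
public class InputTensorSketch {
    static final int H = 224, W = 224, C = 3;

    // Convert raw RGB bytes (0..255) to a normalized [1][H][W][C] float array.
    static float[][][][] toBatch(byte[] rgb) {
        float[][][][] batch = new float[1][H][W][C];
        int i = 0;
        for (int y = 0; y < H; y++)
            for (int x = 0; x < W; x++)
                for (int c = 0; c < C; c++)
                    batch[0][y][x][c] = (rgb[i++] & 0xFF) / 255.0f; // scale to [0,1]
        return batch;
    }

    public static void main(String[] args) {
        byte[] rgb = new byte[H * W * C];
        rgb[0] = (byte) 255;                     // first red channel fully on
        float[][][][] batch = toBatch(rgb);
        System.out.println(batch[0][0][0][0]);   // prints 1.0
        System.out.println(batch[0].length + "x" + batch[0][0].length
                           + "x" + batch[0][0][0].length);
        // With the TensorFlow Java API on the classpath, wrap it like:
        // Tensor<Float> input = org.tensorflow.Tensors.create(batch);
    }
}
```

Whether you scale to [0,1], subtract the ImageNet channel means, or do something else entirely depends on how the model was trained, so mirror whatever preprocessing your Python code applies before prediction.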

Note that the TensorFlow Java API may be better suited to server/desktop applications than TensorFlow Lite, which another commenter suggested. The TensorFlow Lite runtime, while optimized (in memory footprint etc.) for small devices, cannot run all models yet, whereas the TensorFlow Java API uses the exact same runtime as TensorFlow for Python and thus has the exact same abilities.

Hope that helps.


5 Comments

I did exactly the same, and it finally worked fine when I was using the Inception model as the pre-trained model. But when I use the VGG model as a base, my model cannot be loaded in Java. Have you seen any tutorial that loads the VGG model, fine-tunes it, and then loads it in Java? I will send you the errors that I get when I try the VGG model. I do appreciate your help.
This is the error that I get: Matrix size-incompatible: In[0]: [1,8192], In[1]: [25088,256] [[{{node dense/MatMul}} = MatMul[T=DT_FLOAT, _output_shapes=[[?,256]], transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](flatten/Reshape, dense/MatMul/ReadVariableOp)]]
The Example Link in the response is not valid anymore.
what is makeInputTensor()? Something the user would code that returns a Tensor of the appropriate dimension?
And on the next line, output is used before it is defined, seemingly.
