I am stuck on a TensorFlow problem.

I am trying to call a Python function from within TensorFlow. According to the TensorFlow manual, py_func() can be used to invoke a Python function.

I want to achieve something like the following:

a = np.array([1,2,3,4], dtype='float32')
b = np.array([[5,6,7,8],[9,8,1,2], [3,2,3,1],[4,5,1,3]], dtype='float32')


def pyfunction(inputIN):
    return np.array(inputIN + a)

def tfFunction():
    inp = tf.placeholder(tf.float32, [2,4])
    out = tf.py_func(pyfunction, [inp], tf.float32)

tfFunction()   
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())    
    for inp_ in (b[0:2], b[2:4]):
        feed_dict = {inp:inp_}
        output = sess.run([out], feed_dict=feed_dict)
        print (output)

The output I require is:

[[  6.   8.  10.  12.]
 [ 10.  10.   4.   6.]]
[[ 4.  4.  6.  5.]
 [ 5.  7.  4.  7.]]

Using the above code I get an error:

TypeError: Expected list for attr Tout

I think I understand what the error says, but I can't figure out a solution.

Please note: the above is just a dummy problem, but the code I want to end up with is very similar. I am working on an image processing task and have a few image processing steps (using OpenCV) inside a Python function. I need to call that Python function for every image while running the graph.

I understand that I could preprocess the data beforehand and store it as batches, but I have a few other tasks lined up, so I have to stick to the above format.

Any help will be appreciated. Thanks

1 Answer

This code produces the desired result:

import tensorflow as tf
import numpy as np

a = np.array([1,2,3,4], dtype='float32')
b = np.array([[5,6,7,8],[9,8,1,2], [3,2,3,1],[4,5,1,3]], dtype='float32')


def pyfunction(inputIN):
    return np.array(inputIN + a)

inp = tf.placeholder(tf.float32, [2,4])
out = tf.py_func(pyfunction, [inp], tf.float32)

with tf.Session() as sess:

    sess.run(tf.initialize_all_variables())
    for inp_ in (b[0:2], b[2:4]):
        feed_dict = {inp:inp_}
        output = sess.run(out, feed_dict=feed_dict)
        print (output)

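Note that 'out' is passed to sess.run(..) without square brackets, so the result is the array itself rather than a one-element list. Also, inp and out are now created at module level instead of inside tfFunction(), so they are in scope for the session loop. It prints: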

[[  6.   8.  10.  12.]
 [ 10.  10.   4.   6.]]
[[ 4.  4.  6.  5.]
 [ 5.  7.  4.  7.]]
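
These values can be cross-checked without TensorFlow. The sketch below reuses a and b from the code above and relies only on NumPy broadcasting, which adds the 1-D array a to each row of a 2x4 slice of b. The name `expected` is just an illustrative variable, not part of the original code.

```python
import numpy as np

a = np.array([1, 2, 3, 4], dtype='float32')
b = np.array([[5, 6, 7, 8], [9, 8, 1, 2], [3, 2, 3, 1], [4, 5, 1, 3]], dtype='float32')

# Broadcasting: a has shape (4,), each slice of b has shape (2, 4),
# so a is added to every row of the slice.
expected = [b[0:2] + a, b[2:4] + a]

for batch in expected:
    print(batch)
```

This is the same computation pyfunction performs on each batch fed through the placeholder.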

2 Comments

Sorry, it gives the same error even with the code you provided. Is it a version-specific error? The version of TensorFlow I am running is '0.10.0'.
I used 1.x. It does look version-specific. For future questions, you should state the library version in the question description.
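
On the version difference: the message "Expected list for attr Tout" suggests that older releases such as 0.10.0 want the output types passed as a list, whereas 1.x also accepts a single dtype. The sketch below passes `[tf.float32]` and takes element `[0]` of the returned list, which should work on both. Treat it as a sketch under assumptions: it has not been run on 0.10.0, the `run_batches` helper is a hypothetical name, and the `tensorflow.compat.v1` fallback import is only there so the same code can also run on newer installs.

```python
import numpy as np

try:
    # Assumption: on newer installs the graph-mode API lives under compat.v1.
    import tensorflow.compat.v1 as tf
except ImportError:
    # Older releases expose the same names at the top level.
    import tensorflow as tf

a = np.array([1, 2, 3, 4], dtype='float32')
b = np.array([[5, 6, 7, 8], [9, 8, 1, 2], [3, 2, 3, 1], [4, 5, 1, 3]], dtype='float32')


def pyfunction(inputIN):
    # Runs as ordinary Python/NumPy code each time the op is evaluated.
    return np.array(inputIN + a)


def run_batches():
    # Hypothetical helper: builds the graph and returns one result per batch.
    results = []
    with tf.Graph().as_default():
        inp = tf.placeholder(tf.float32, [2, 4])
        # Tout is given as a list, so py_func returns a list of tensors;
        # [0] picks the single output tensor.
        out = tf.py_func(pyfunction, [inp], [tf.float32])[0]
        # No variables are created, so no initializer is run.
        with tf.Session() as sess:
            for inp_ in (b[0:2], b[2:4]):
                results.append(sess.run(out, feed_dict={inp: inp_}))
    return results


for output in run_batches():
    print(output)
```

Building the placeholder and the op inside the same function that runs the session also avoids the scoping problem in the question, where inp and out were local to tfFunction().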
