
I am using a function in TensorFlow that maps a set of tensors to another arrangement of tensors. For example, you might write:

def _function(a, b, c):
    # Receives the three tensors of one dataset element and returns a new tuple
    return (a + 1, b, c)

data = data.map(_function)

So here, you pass _function to map as a function object, and map calls it with three tensors, which are transformed in some way (here, just adding one to a) and returned as a new tuple.


My question is: is there a way to pass additional variables to _function?

If I want to compute a + x instead of a + 1, how can I pass in the additional variable x?

You can't write something like data.map(_function(x)), because that passes the result of calling the function, not the function itself.

I've experimented with *args, but I can't find a way. Any help is greatly appreciated.

1 Answer


You can do something like this:

def extra_func(x):
    # Returns a new three-argument function; x is captured by the closure
    def _function(a, b, c):
        return (a + x, b, c)
    return _function

Then you can call data.map(extra_func(x)), because extra_func(x) returns a function rather than a result.
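A runnable sketch of the closure approach. Since it should run without TensorFlow, fake_map is a hypothetical stand-in for map that calls the function once per element and unpacks each tuple into positional arguments, mirroring how the question's map hands a, b and c to _function. Plain values stand in for tensors.

```python
def extra_func(x):
    # Factory: returns a new three-argument function with x captured in a closure
    def _function(a, b, c):
        return (a + x, b, c)
    return _function


def fake_map(fn, elements):
    # Hypothetical stand-in for map: apply fn to each element,
    # unpacking the tuple into positional arguments a, b, c.
    return [fn(*elem) for elem in elements]


elements = [(1, "b0", "c0"), (2, "b1", "c1")]
add_ten = extra_func(10)  # a function object, not a result
print(fake_map(add_ten, elements))  # [(11, 'b0', 'c0'), (12, 'b1', 'c1')]
```

In TensorFlow the equivalent call would be data.map(extra_func(10)).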

Alternatively, you can use functools.partial to fix some of a function's parameters.


4 Comments

Except now, in this case, extra_func() is being passed a, b, c by the .map() function.
I can't get your first method to work for my specific problem, but I can solve it by using functools.partial to add further parameters to the function. Thanks so much for the helping hand.
From your question, I understand that map needs to be given a three-parameter function. extra_func(10) returns a three-parameter function with x fixed to 10.
Oh snap it does. Yeah, you're right, that would also fix it.
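The point in the comments above can be checked directly: extra_func itself takes only x, while the function it returns takes a, b and c, which is what map calls. A small sketch using inspect.signature from the standard library:

```python
import inspect


def extra_func(x):
    # Returns a three-argument function with x captured in a closure
    def _function(a, b, c):
        return (a + x, b, c)
    return _function


print(inspect.signature(extra_func))      # (x)
print(inspect.signature(extra_func(10)))  # (a, b, c)
```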
