
I'm trying to train a model with GradientTape in Keras. Here is the code:

@tf.function
def train_step(x,y):
    
    with tf.GradientTape() as tape:
                
        predictions = model.predict(x)
        
        loss = compute_loss(y, predictions)
    
    grads = tape.gradient(loss, model.trainable_variables)
        
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    return loss

history = []

for iter in tqdm(range(num_iters)):
    
    x_batch, y_batch = get_batch(x_train, y_train, batch_dim)
    loss = train_step(x_batch, y_batch)
    history.append(loss.numpy().mean())
    

This code leads to the following error:

ValueError: When using data tensors as input to a model, you should specify the `steps` argument.

However, if I call predict outside the function, as follows:

history = []

for iter in tqdm(range(num_iters)):
    
    x_batch, y_batch = get_batch(x_train, y_train, batch_dim)       
    x_hat = model.predict(x_batch)

I get no error...

Can someone explain why I get this behavior from Keras?

  • could you print out the result of the get_batch function? Commented Sep 7, 2020 at 18:10
  • This previous discussion may help Question-54547681 Commented Sep 7, 2020 at 20:19
  • @tornikeo get_batch outputs two numpy arrays of shape (batch_size, 512, 512, 1) Commented Sep 8, 2020 at 6:30
  • @stephen_mugisha I already checked that answer before posting this one, but I couldn't get rid of my problem... Commented Sep 8, 2020 at 6:31
  • @francesco in order to use tf.function you must specify input and output shapes and data types. That is likely causing problems. Does your error change if you remove the tf.function decorator? Commented Sep 8, 2020 at 8:00

1 Answer

Answering here for the benefit of the community, even though the user has already answered in the comments.

Convert x_batch and y_batch to float32, then call model(x_batch) directly instead of model.predict(x_batch) to compute the predictions.
This resolves the issue while keeping train_step decorated with @tf.function.
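The fix described above might look like the following minimal sketch. The small convolutional model, the Adam optimizer, the mean-squared-error compute_loss, and the random 8x8 batches are hypothetical stand-ins, since the question does not show the real ones. The relevant parts are the float32 cast and the direct model(x) call inside the tape.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins: the asker's real model, optimizer and loss are not shown.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 1)),
    tf.keras.layers.Conv2D(1, 3, padding="same"),
])
optimizer = tf.keras.optimizers.Adam(1e-3)


def compute_loss(y_true, y_pred):
    # Assumed loss: per-example mean squared error, shape (batch_size,).
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=[1, 2, 3])


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # Call the model directly so the forward pass is recorded on the tape.
        # model.predict() runs its own batching loop and returns NumPy arrays,
        # so it cannot be differentiated through.
        predictions = model(x, training=True)
        loss = compute_loss(y, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Stand-in for get_batch: NumPy arrays cast to float32 before the call.
x_batch = np.random.rand(4, 8, 8, 1).astype(np.float32)
y_batch = np.random.rand(4, 8, 8, 1).astype(np.float32)

loss = train_step(x_batch, y_batch)
print(loss.numpy().mean())
```

With this structure, the loop from the question can stay as it is, provided get_batch returns float32 arrays or the cast is applied before train_step is called.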
