I've built a very simple neural network for reinforcement learning. However, calling predict raises an error, so I cannot predict anything.

Error in question:

Error when checking input: expected dense_203_input to have shape (1202,) but got array with shape (1,)

Model in question:

def _build_compile_model(self):
    model = Sequential()
    model.add(Dense(300, activation='relu', input_dim=1202))
    model.add(Dense(300, activation='relu'))
    model.add(Dense(200, activation='relu'))
    model.add(Dense(self._action_size, activation='softmax'))

    model.compile(loss='mse', optimizer=self._optimizer)
    return model

The error occurs when calling model.predict(state), where state is an array of shape (1202, 1).

The full error message is:

ValueError                                Traceback (most recent call last)
<ipython-input-148-06b7a01facef> in <module>
     18     new_state, reward = env.step(action, new_demand_a, new_demand_b) # Take action, get new state and reward
     19     new_state = np.reshape(new_state, [1202, -1])
---> 20     agent.update(old_state, new_state, action, reward) # Let the agent update internal
     21     average_reward.append(reward) # Keep score
     22     if i % 100 == 0 and i != 0: # Print out metadata every 100th iteration

<ipython-input-145-142ae54ce43f> in update(self, old_state, new_state, action, reward)
     49     def update(self, old_state, new_state, action, reward):
     50         print(old_state.shape)
---> 51         target = self.q_network.predict(old_state)
     52         t = self.target_network.predict(new_state)
     53         target[0][action] = reward + self.gamma * np.amax(t)

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1011         max_queue_size=max_queue_size,
   1012         workers=workers,
-> 1013         use_multiprocessing=use_multiprocessing)
   1014 
   1015   def reset_metrics(self):

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in predict(self, model, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    496         model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
    497         steps=steps, callbacks=callbacks, max_queue_size=max_queue_size,
--> 498         workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
    499 
    500 

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _model_iteration(self, model, mode, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    424           max_queue_size=max_queue_size,
    425           workers=workers,
--> 426           use_multiprocessing=use_multiprocessing)
    427       total_samples = _get_total_number_of_samples(adapter)
    428       use_sample = total_samples is not None

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    644     standardize_function = None
    645     x, y, sample_weights = standardize(
--> 646         x, y, sample_weight=sample_weights)
    647   elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
    648     standardize_function = standardize

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2408           feed_input_shapes,
   2409           check_batch_axis=False,  # Don't enforce the batch size.
-> 2410           exception_prefix='input')
   2411 
   2412     # Get typespecs for the input data and sanitize it if necessary.

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    580                              ': expected ' + names[i] + ' to have shape ' +
    581                              str(shape) + ' but got array with shape ' +
--> 582                              str(data_shape))
    583   return data
    584 

ValueError: Error when checking input: expected dense_211_input to have shape (1202,) but got array with shape (1,)

1 Answer

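Keras always reads the first axis of the input array as the batch size, so the array you pass must have one more dimension than a single sample. There are two ways to declare the input of your model: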

1st Option: Using the input_shape

model.add(Dense(300, activation='relu', input_shape=(1202,1)))

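Here each sample is 2D, with shape (1202, 1), so the array you feed the network must be 3D (rank 3): the leading axis is the batch size. Note that Dense only acts on the last axis, so the output for this input has shape (BATCH_SIZE, 1202, units) rather than one vector per sample.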

Example input:

state = np.ones((BATCH_SIZE, 1202, 1))  # BATCH_SIZE = number of states per call, e.g. 1
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input
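
For a single state like the one in the question, the batch axis can be prepended with NumPy. This is only a minimal sketch: it assumes state is a NumPy array of shape (1202, 1), and the zero-filled array is a stand-in for the real environment state.

```python
import numpy as np

# Stand-in for the environment state from the question, shape (1202, 1).
state = np.zeros((1202, 1))

# Prepend a batch axis of size 1 so the array becomes rank 3: (batch, 1202, 1).
batched_state = np.expand_dims(state, axis=0)

print(batched_state.shape)  # (1, 1202, 1)
print(batched_state.ndim)   # 3
```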

2nd Option: Using the input_dim

model.add(Dense(300, activation='relu', input_dim=1202))

Here each sample is 1D, with shape (1202,), so the array you feed the network must be 2D (rank 2): the leading axis is the batch size. This is the form used by the model in the question.

Example input:

state = np.ones((1, 1202))  # one sample with 1202 features
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input
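
Applied to the question, where state has shape (1202, 1), the first axis is read as the batch, so the network sees 1202 samples of shape (1,), which is exactly what the error message reports. A minimal sketch of the fix, assuming state is a NumPy array (the zero-filled array below is a placeholder for the real state):

```python
import numpy as np

# Placeholder for the state from the question, shape (1202, 1).
state = np.zeros((1202, 1))

# Per-sample shape once the leading (batch) axis is dropped.
# This is the shape that gets compared with the expected (1202,).
per_sample_shape = state.shape[1:]
print(per_sample_shape)  # (1,)

# Reshape to one sample with 1202 features: (batch=1, features=1202).
fixed_state = state.reshape(1, -1)
print(fixed_state.shape)      # (1, 1202)
print(fixed_state.shape[1:])  # (1202,)
```

In the loop shown in the traceback, this would presumably mean reshaping with np.reshape(new_state, [1, -1]) instead of [1202, -1] before the state is passed to predict.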