What is the TensorFlow counterpart to PyTorch's forward function?

I am trying to translate some PyTorch code to TensorFlow.

1 Answer

The forward() method of nn.Module in PyTorch can be replaced by the __call__() method of tf.Module in TensorFlow, or by the call() method of tf.keras.layers.Layer in Keras. Here is an example of a simple dense layer written both ways:

TensorFlow:

import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, input_dim, output_size, name=None):
    super().__init__(name=name)
    # Trainable weight matrix and bias vector.
    self.w = tf.Variable(tf.random.normal([input_dim, output_size]), name='w')
    self.b = tf.Variable(tf.zeros([output_size]), name='b')

  def __call__(self, x):
    # Plays the role of forward(): affine transform followed by ReLU.
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

Keras:

import tensorflow as tf

class Dense(tf.keras.layers.Layer):
  def __init__(self, units=32):
    super().__init__()
    self.units = units

  def build(self, input_shape):
    # Weights are created lazily, once the input shape is known.
    self.w = self.add_weight(shape=(input_shape[-1], self.units),
                             initializer='random_normal',
                             trainable=True)
    self.b = self.add_weight(shape=(self.units,),
                             initializer='random_normal',
                             trainable=True)

  def call(self, inputs):
    # Plays the role of forward(); invoked through layer(inputs).
    return tf.matmul(inputs, self.w) + self.b
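
In each case the layer is normally invoked as layer(x), not by calling forward() or call() directly. In PyTorch and Keras the base class provides __call__, which does some bookkeeping (hooks in PyTorch, a one-time build() in Keras) and then delegates to the method you wrote. tf.Module provides no such wrapper, which is why you define __call__ yourself there. The following is a rough, framework-free sketch of that delegation pattern. SketchLayer and Scale are hypothetical names made up for illustration; this is not the actual library code, and it uses plain Python lists instead of tensors.

```python
class SketchLayer:
    """Hypothetical, heavily simplified stand-in for a framework base class."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Subclasses create their weights here once the input shape is known.
        pass

    def call(self, inputs):
        # Subclasses put their forward computation here.
        raise NotImplementedError

    def __call__(self, inputs):
        # Bookkeeping first: run build() exactly once, on the first call.
        if not self.built:
            self.build((len(inputs),))
            self.built = True
        # Then delegate to the user-defined forward logic.
        return self.call(inputs)


class Scale(SketchLayer):
    """Toy layer: multiplies each input by a factor and adds a zero offset."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def build(self, input_shape):
        self.offset = [0.0] * input_shape[-1]

    def call(self, inputs):
        return [self.factor * x + b for x, b in zip(inputs, self.offset)]


layer = Scale(2.0)
result = layer([1.0, 2.0, 3.0])  # goes through __call__, then build(), then call()
```

The practical consequence when porting: rename forward() to call() for a Keras layer, or to __call__() for a tf.Module, and keep invoking the object as layer(x) so that any base-class bookkeeping still runs.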

You can check the following links for more details:

  1. https://www.tensorflow.org/api_docs/python/tf/Module
  2. https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer