Есть ли способ изменить переменную во время вызова? - PullRequest
0 голосов
/ 22 июня 2019

tenorflow2.0 имеет формат класса init и call
например

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)

model = MyModel()

мой вопрос в том, что если я хочу изменить

> def call(self, x):
>     x = self.conv1(x)
>     x = self.flatten(x)
>     x = self.d1(x)
>     return self.d2(x,activation='relu')

это вызывает ошибку.
если я хочу изменить атрибут во время какого-либо процесса Как я должен это делать?

1 Ответ

0 голосов
/ 22 июня 2019

Если вы хотите изменить поведение прямого прохода в зависимости от условия, вы можете просто добавить параметр в метод call.

Из вашего примера похоже, что вы хотите изменить функцию активации последнего слоя. Таким образом, вы можете просто определить последний слой с линейной функцией активации и применить желаемую функцию активации в зависимости от условия.

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
    self.flatten = tf.keras.layers.Flatten()
    self.d1 = tf.keras.layers.Dense(128, activation='relu')
    # note: no activation = linear activation
    self.d2 = tf.keras.layers.Dense(10)
    # Create two activation layers
    self.relu =  tf.keras.layers.ReLU()
    self.softmax = tf.keras.layers.Softmax()

  def call(self, x, condition):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    x = self.d2(x)
    # Change the activation depending on the condition
    if condition:
      tf.print("callign with activation=relu")
      x = self.relu(x)
    return self.softmax(x)

model = MyModel()

fake_input = tf.zeros((1, 28, 28, 1))

tf.print("condition false")
tf.print(model(fake_input, condition=False))
tf.print("condition true")
tf.print(model(fake_input, condition=True))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...