Если вы хотите изменить поведение прямого прохода в зависимости от условия, вы можете просто добавить параметр в метод call
.
Из вашего примера похоже, что вы хотите изменить функцию активации последнего слоя. Таким образом, вы можете просто определить последний слой с линейной функцией активации и применить желаемую функцию активации в зависимости от условия.
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.d1 = tf.keras.layers.Dense(128, activation='relu')
# note: no activation = linear activation
self.d2 = tf.keras.layers.Dense(10)
# Create two activation layers
self.relu = tf.keras.layers.ReLU()
self.softmax = tf.keras.layers.Softmax()
def call(self, x, condition):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
x = self.d2(x)
# Change the activation depending on the condition
if condition:
tf.print("callign with activation=relu")
x = self.relu(x)
return self.softmax(x)
model = MyModel()
fake_input = tf.zeros((1, 28, 28, 1))
tf.print("condition false")
tf.print(model(fake_input, condition=False))
tf.print("condition true")
tf.print(model(fake_input, condition=True))