Я не уверен, что вы имеете в виду, но у вас может быть переменная в вашем слое, которая просто обновляется предыдущим значением другой переменной на каждом шаге обучения, что-то вроде этих строк:
import tensorflow as tf
class MyLayer(tf.keras.layers.Layer):
def __init__(self, units, **kwargs):
super(MyLayer, self).__init__(**kwargs)
self.units = units
def build(self, input_shape):
self.w = self.add_weight(shape=(input_shape[-1], self.units),
initializer='random_normal',
trainable=self.trainable,
name='W')
self.w_prev = self.add_weight(shape=self.w.shape,
initializer='zeros',
trainable=False,
name='W_prev')
def call(self, inputs, training=False):
# Only update value of w_prev on training steps
deps = []
if training:
deps.append(self.w_prev.assign(self.w))
with tf.control_dependencies(deps):
return tf.matmul(inputs, self.w)
Вот пример использования:
import tensorflow as tf
import numpy as np
tf.random.set_seed(0)
np.random.seed(0)
# Make a random linear problem
x = np.random.rand(50, 3)
y = x @ np.random.rand(3, 2)
# Make model
model = tf.keras.Sequential()
my_layer = MyLayer(2, input_shape=(3,))
model.add(my_layer)
model.compile(optimizer='SGD', loss='mse')
# Train
cbk = tf.keras.callbacks.LambdaCallback(
on_batch_begin=lambda batch, logs: (tf.print('batch:', batch),
tf.print('w_prev:', my_layer.w_prev, sep='\n'),
tf.print('w:', my_layer.w, sep='\n')))
model.fit(x, y, batch_size=10, epochs=1, verbose=0, callbacks=[cbk])
Вывод:
batch: 0
w_prev:
[[0 0]
[0 0]
[0 0]]
w:
[[0.0755531341 0.0211461019]
[-0.0209847465 -0.0518018603]
[-0.0618413948 0.0235136505]]
batch: 1
w_prev:
[[0.0755531341 0.0211461019]
[-0.0209847465 -0.0518018603]
[-0.0618413948 0.0235136505]]
w:
[[0.0770048052 0.0292659812]
[-0.0199236758 -0.04635958]
[-0.060054455 0.0332755931]]
batch: 2
w_prev:
[[0.0770048052 0.0292659812]
[-0.0199236758 -0.04635958]
[-0.060054455 0.0332755931]]
w:
[[0.0780589 0.0353098139]
[-0.0189863108 -0.0414136574]
[-0.0590113513 0.0387929156]]
batch: 3
w_prev:
[[0.0780589 0.0353098139]
[-0.0189863108 -0.0414136574]
[-0.0590113513 0.0387929156]]
w:
[[0.0793346688 0.042034667]
[-0.0173048507 -0.0330933407]
[-0.0573575757 0.0470812619]]
batch: 4
w_prev:
[[0.0793346688 0.042034667]
[-0.0173048507 -0.0330933407]
[-0.0573575757 0.0470812619]]
w:
[[0.0805450454 0.0485667922]
[-0.0159637 -0.0261840075]
[-0.0563304275 0.052557759]]
РЕДАКТИРОВАТЬ: Я все еще не уверен на 100%, как именно вам это нужно, чтобы это работало , но вот кое-что, что может сработать для вас:
import tensorflow as tf
class KCompetitive(Layer):
'''Applies K-Competitive layer.
# Arguments
'''
def __init__(self, topk, ctype, **kwargs):
self.topk = topk
self.ctype = ctype
self.uses_learning_phase = True
self.supports_masking = True
self.x_prev = None
super(KCompetitive, self).__init__(**kwargs)
def call(self, x):
if self.ctype == 'ksparse':
return K.in_train_phase(self.kSparse(x, self.topk), x)
elif self.ctype == 'kcomp':
return K.in_train_phase(self.k_comp_tanh(x, self.topk), x)
else:
warnings.warn("Unknown ctype, using no competition.")
return x
def get_config(self):
config = {'topk': self.topk, 'ctype': self.ctype}
base_config = super(KCompetitive, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def k_comp_tanh(self, x, topk, factor=6.26):
if self.x_prev is None:
self.x_prev = self.add_weight(shape=x.shape,
initializer='zeros',
trainable=False,
name='X_prev')
###Some modification on x so now the x becomes
x_modified = self.x_prev.assign(x + 1)
return x_modified
Вот пример использования:
import tensorflow as tf
tf.random.set_seed(0)
np.random.seed(0)
# Make model
model = tf.keras.Sequential()
model.add(tf.keras.Input(batch_shape=(3, 4)))
my_layer = KCompetitive(2, 'kcomp')
print(my_layer.x_prev)
# None
model.add(my_layer)
# The variable gets created after it is added to a model
print(my_layer.x_prev)
# <tf.Variable 'k_competitive/X_prev:0' shape=(3, 4) dtype=float32, numpy=
# array([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]], dtype=float32)>
model.compile(optimizer='SGD', loss='mse')
# "Train"
x = tf.zeros((3, 4))
cbk = tf.keras.callbacks.LambdaCallback(
on_epoch_begin=lambda batch, logs:
tf.print('initial x_prev:', my_layer.x_prev, sep='\n'),
on_epoch_end=lambda batch, logs:
tf.print('final x_prev:', my_layer.x_prev, sep='\n'),)
model.fit(x, x, epochs=1, verbose=0, callbacks=[cbk])
# initial x_prev:
# [[0 0 0 0]
# [0 0 0 0]
# [0 0 0 0]]
# final x_prev:
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]