Я создаю пользовательский слой для tf.keras, и он правильно работает при выводе / обучении, но время начальной загрузки очень медленное.Очевидно, это происходит из вложенного цикла for, который я использую для разделения моих данных в call (), но я не уверен, как векторизовать / или просто ускорить этот процесс.Любые предложения хорошие!Спасибо!
Я пытался использовать tf.dynamic_partition, но я не уверен, что полностью понимаю подход только после прочтения документации на сайте tenorflow.
class AttentionLayer(tf.keras.layers.Layer):
def __init__(self, output_units):
super(AttentionLayer, self).__init__()
self.output_units = output_units
def build(self, input_shape):
self.len = input_shape[1]
self.cells = 100
if len(input_shape) == 3:
self.c = input_shape[2]
else:
self.c = 1
self.WQs = []
self.WKs = []
self.WA = self.add_variable("WA", [int(self.len - 1), 1],initializer=tf.glorot_uniform_initializer)
for idx in range(self.cells):
WQ = self.add_variable("WQ" + str(idx), [self.c , self.output_units]
,initializer=tf.glorot_uniform_initializer)
WK = self.add_variable("WK" + str(idx), [self.c , self.output_units]
,initializer=tf.glorot_uniform_initializer)
self.WQs.append(WQ)
self.WKs.append(WK)
def call(self, input):
attention = []
array = tf.reshape(input, [-1, self.len, self.c])
batch_size = tf.shape(input)[0]
for idx in range(self.cells):
print(str(idx), end="\r")
Q = tf.reshape(array[:, idx], [-1, 1, self.c])
#print("Q: ", Q)
context_list = []
for cdx in range(self.len):
if idx != cdx:
context_list.append(tf.reshape(array[:, cdx], [-1, 1, self.c]))
K = tf.concat(context_list, 1)
#print("K: ", K)
WQ_expand = tf.expand_dims(self.WQs[idx], axis=0)
WK_expand = tf.expand_dims(self.WKs[idx], axis=0)
WQ_tile = tf.tile(WQ_expand, [batch_size, 1, 1])
WK_tile = tf.tile(WK_expand, [batch_size, 1, 1])
Q = tf.matmul(Q, WQ_tile)
K = tf.matmul(K, WK_tile)
a = tf.nn.sigmoid(tf.matmul(Q,tf.reshape(K, [-1, self.output_units, int(self.len - 1)]))/27.5)
#print("a: ", a)
attention.append(a)
A = tf.concat(attention,1)
#print("A: ", A)
WA_expand = tf.expand_dims(self.WA, axis=0)
WA_tile = tf.tile(WA_expand, [batch_size, 1, 1])
Z = tf.reshape(tf.nn.sigmoid(tf.matmul(A,WA_tile)), [-1, self.cells, 1])
#print("Z: ", Z)
return Z
Слой принимает трехмерный ввод (размер пакета, ширина * высота, каналы). Оператор print должен показывать, какие слои инициализируются.