Я использую tf.contrib.tpu.keras_to_tpu_model
, чтобы мой код мог работать на TPU, но завершение эпохи заняло 170 часов, в то время как процессор занимал столько же времени, а GPU - только 40 часов на одну эпоху. Я пытался настроить размер пакета, ноничего не изменилось. И я проверял, что функция ввода может занимать 20% времени работы при работе на GPU, поэтому я думаю, что это, возможно, не главная причина.
Вот мой код: https://github.com/WangHexie/DHNE/blob/master/src/hypergraph_embedding.py
Запуск на colab:
- ТПУ: https://colab.research.google.com/gist/WangHexie/30c385509f9cd93be747f04c39f039a4/tpu-error.ipynb
- GPU: https://colab.research.google.com/gist/WangHexie/5bfac53bf92ef0ad527f15ddbf8705e1/-gpu-ipynb.ipynb
Модель:
def build_model(self):
self.inputs = [Input(shape=(self.options.dim_feature[i], ), name='input_{}'.format(i), dtype='float') for i in range(3)]
self.encodeds = [Dense(self.options.embedding_size[i], activation='tanh', name='encode_{}'.format(i))(self.inputs[i]) for i in range(3)]
self.decodeds = [Dense(self.options.dim_feature[i], activation='sigmoid', name='decode_{}'.format(i),
activity_regularizer = regularizers.l2(0.0))(self.encodeds[i]) for i in range(3)]
self.merged = concatenate(self.encodeds, axis=1)
self.hidden_layer = Dense(self.options.hidden_size, activation='tanh', name='full_connected_layer')(self.merged)
self.ouput_layer = Dense(1, activation='sigmoid', name='classify_layer')(self.hidden_layer)
self.model = Model(inputs=self.inputs, outputs=self.decodeds+[self.ouput_layer])
self.model.compile(optimizer=tf.train.AdamOptimizer(learning_rate=self.options.learning_rate),
loss=[self.sparse_autoencoder_error]*3+['binary_crossentropy'],
loss_weights=[self.options.alpha]*3+[1.0],
metrics=dict([('decode_{}'.format(i), 'mse') for i in range(3)]+[('classify_layer', 'accuracy')]))
self.model = tf.contrib.tpu.keras_to_tpu_model(
self.model,
strategy=tf.contrib.tpu.TPUDistributionStrategy(
tf.contrib.cluster_resolver.TPUClusterResolver(
tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
)
)
self.model.summary()