Я пытаюсь использовать ТПУ для обучения моей модели. Вот код
def train(self, epochs=1, batchSize=128):
# Loading the data
print("loading data")
trainingData = os.listdir("/content/gdrive/My Drive/Thesis - Deep Fakes Detection/8- Experiments/faceFrames")
print("connecting with TPU")
resolver = tf.contrib.cluster_resolver.TPUClusterResolver('grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)
with strategy.scope():
# Return keras model
self.model = self.getModel()
print("compiling")
sgd = SGD(lr=0.000061, decay=0.36, momentum=0.8, nesterov=True)
self.model.compile(optimizer=sgd,
loss='binary_crossentropy',
metrics=['accuracy'])
'''
self.model.compile(optimizer="adam",
loss='binary_crossentropy',
metrics=['accuracy'])
'''
for index in range(1, epochs + 1):
print("Starting Epoch %d" % index)
loss = 0
accuracy = 0
for mIndex in tqdm(range(batchSize)):
realImagesBatch = random.sample(trainingData, batchSize)
(images, labels) = self.loadData(realImagesBatch) # Pre train discriminator on fake and real data before starting the gan.
[mLoss, mAcc] = self.model.train_on_batch(images, labels)
#print(labels)
loss = mLoss
accuracy = mAcc
tqdm._instances.clear()
#print('during training: loss is {} accuracy is {} for Epoch {}'.format(loss, accuracy, index))
print('during training: loss is {} accuracy is {} for Epoch {}'.format(loss, accuracy, index))
realImagesBatch = random.sample(trainingData, batchSize)
(images, labels) = self.loadData(realImagesBatch) # Pre train discriminator on fake and real data before starting the gan.
[loss2, accuracy2] = self.model.test_on_batch(images, labels)
print('during testing: loss is {} accuracy is {}'.format(loss2, accuracy2))
self.model.save_weights('/content/gdrive/My Drive/Thesis - Deep Fakes Detection/8- Experiments/dataMesoNet.h5', overwrite=True)
Однако я получаю ошибку
Use tf.where in 2.0, which has the same broadcast rule as np.where
0%| | 0/64 [00:00<?, ?it/s]Starting Epoch 1
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-3-57359621164a> in <module>()
1 datanet = DataNet()
----> 2 datanet(300, 64)
5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py in eager_learning_phase_scope(value)
413 global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
414 assert value in {0, 1}
--> 415 assert ops.executing_eagerly_outside_functions()
416 global_learning_phase_was_set = global_learning_phase_is_set()
417 if global_learning_phase_was_set:
AssertionError: