Обучение модели на TPU завершается ошибкой, связанной с глобальной переменной - PullRequest
0 голосов
/ 24 февраля 2020

Я пытаюсь использовать TPU для обучения своей модели. Вот код:

def train(self, epochs=1, batchSize=128):
    """Build, compile and train ``self.model`` under a TPU distribution strategy.

    Steps, in order:

    1. List the entries of the ``faceFrames`` directory on the mounted Drive.
    2. Connect to the TPU whose address is read from the ``COLAB_TPU_ADDR``
       environment variable and create a ``TPUStrategy`` for it.
    3. Inside the strategy scope, obtain the model from ``self.getModel()``
       and compile it with Nesterov-momentum SGD and binary cross-entropy.
    4. Run ``epochs`` epochs; each epoch performs ``batchSize`` training
       steps, every step drawing ``batchSize`` random entries and loading
       them through ``self.loadData``.
    5. Evaluate once on a further random batch and save the weights to an
       ``.h5`` file on Drive.

    Args:
        epochs: Number of epochs to run.
        batchSize: Used both as the number of entries sampled per step and
            as the number of steps per epoch.

    Raises:
        KeyError: If ``COLAB_TPU_ADDR`` is not set in the environment.
        ValueError: From ``random.sample`` when the directory holds fewer
            than ``batchSize`` entries.
    """
    print("loading data")
    frameNames = os.listdir("/content/gdrive/My Drive/Thesis - Deep Fakes Detection/8- Experiments/faceFrames")

    print("connecting with TPU")
    tpuAddress = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpuAddress)
    tf.contrib.distribute.initialize_tpu_system(resolver)
    strategy = tf.contrib.distribute.TPUStrategy(resolver)

    # Model creation and compilation both happen inside the strategy scope.
    with strategy.scope():
        self.model = self.getModel()
        print("compiling")
        optimizer = SGD(lr=0.000061, decay=0.36, momentum=0.8, nesterov=True)
        self.model.compile(
            optimizer=optimizer,
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )

    def drawBatch():
        # Pick batchSize random directory entries and let loadData turn them
        # into whatever (images, labels) pair the model consumes.
        return self.loadData(random.sample(frameNames, batchSize))

    for epoch in range(1, epochs + 1):
        print("Starting Epoch %d" % epoch)
        # Defaults are reported unchanged when batchSize is 0 (no steps run).
        loss, accuracy = 0, 0
        for _ in tqdm(range(batchSize)):
            images, labels = drawBatch()
            # Only the most recent step's metrics are kept; no averaging.
            # NOTE(review): the AssertionError quoted with this snippet is
            # presumably raised from this call — confirm against the TF
            # version in use.
            loss, accuracy = self.model.train_on_batch(images, labels)
            # Clears tqdm's registry of live progress-bar instances.
            tqdm._instances.clear()

        print('during training: loss is {} accuracy is {} for Epoch {}'.format(loss, accuracy, epoch))

    # Single evaluation pass on a fresh random batch drawn from the same
    # directory listing that was used for training.
    images, labels = drawBatch()
    testLoss, testAccuracy = self.model.test_on_batch(images, labels)
    print('during testing: loss is {} accuracy is {}'.format(testLoss, testAccuracy))

    self.model.save_weights('/content/gdrive/My Drive/Thesis - Deep Fakes Detection/8- Experiments/dataMesoNet.h5', overwrite=True)

Однако я получаю ошибку

Use tf.where in 2.0, which has the same broadcast rule as np.where
  0%|          | 0/64 [00:00<?, ?it/s]Starting Epoch 1
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-3-57359621164a> in <module>()
      1 datanet = DataNet()
----> 2 datanet(300, 64)

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py in eager_learning_phase_scope(value)
    413   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
    414   assert value in {0, 1}
--> 415   assert ops.executing_eagerly_outside_functions()
    416   global_learning_phase_was_set = global_learning_phase_is_set()
    417   if global_learning_phase_was_set:

AssertionError: 
...