обучение CNN на ТПУ с использованием Google Colab - PullRequest
0 голосов
/ 20 февраля 2019

Мне нужно настроить CNN (vgg16) для большого набора данных изображений.Я использую Google Colab, и мне нужно использовать TPU для ускорения обучения.Как указано в примере документации Я использую tf.keras вместо keras. После создания модели, замены полностью подключенных слоев и указания обучаемых я скомпилировал модель, используя categorical_crossentropy lossи создал генератор поездов:

model.compile(loss='categorical_crossentropy', 
          optimizer= tensorflow.train.RMSPropOptimizer(learning_rate=1e-4),
          metrics = ['accuracy'])
train_data_gen = ImageDataGenerator(rescale=1./255,rotation_range = 20, 
                                width_shift_range = 0.2, 
                                height_shift_range = 0.2, 
                                horizontal_flip = True)
train_gen= train_data_gen.flow_from_directory('/content/drive/My Drive/data/train', 
                                          target_size=(224, 224), 
                                          batch_size = 1024, 
                                          class_mode='categorical' )

Затем я преобразовал модель в совместимую модель TPU с тензорным потоком:

tf.logging.set_verbosity(tf.logging.INFO)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model,
strategy=tf.contrib.tpu.TPUDistributionStrategy(
    tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

Когда я пытаюсь обучить модель с помощью следующей команды:

history = tpu_model.fit_generator(train_gen, 
                    steps_per_epoch=train_gen.samples/train_gen.batch_size, 
                    epochs=1000, 
                    verbose=2, 
                    shuffle= False)

Я получаю это сообщение об ошибке:

RuntimeError: Compilation failed: Compilation failure: 
Detected unsupported operations when trying to compile graph 
cluster_4125121914893370080[] on XLA_TPU_JIT: Placeholder 
(No registered 'Placeholder' OpKernel for XLA_TPU_JIT devices compatible          
with node {{node tpu_140154018445800/input_1}} Registered:  
device='TPU'
device='CPU'
device='GPU'
device='XLA_CPU'
){{node tpu_140154018445800/input_1}}

Может кто-нибудь помочь мне решить эту проблему?

...