Зависание при запуске тренировки керас - PullRequest
1 голос
/ 21 февраля 2020

Как решить проблему зависания, когда я начинаю тренировать свою модель keras? Это причина CUDA, CUDNN или того, как я называю ImageDataGenerator?

, это мой код:

import tensorflow as tf
import cv2
import keras

batch_s = 5

**CREATE THE MODEL**
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

adam = Adam(lr=0.0001)

def create_model():
    model = Sequential([
        Conv2D(kernel_size=(9, 9), kernel_initializer='glorot_uniform', filters=128, padding='valid',
               use_bias=True, activation='relu', input_shape=[None, None, 3]),
        Conv2D(kernel_size=(3, 3), kernel_initializer='glorot_uniform', filters=64, padding='same',
               use_bias=True, activation='relu'),
        Conv2D(kernel_size=(5, 5), kernel_initializer='glorot_uniform', filters=1, padding='valid',
               use_bias=True, activation='linear')])
    model.compile(optimizer=adam, loss='mse', metrics=['mse'])
    model.summary()
    return model

train_path = '/source'
val_path = '/images'

train_dat = ImageDataGenerator()
val_dat = ImageDataGenerator()

train_generator = train_dat.flow_from_directory(directory=train_path,
                                                target_size=(512, 512),
                                                interpolation=cv2.INTER_LINEAR,
                                                save_format='bmp',
                                                batch_size=batch_s)
val_generator = val_dat.flow_from_directory(directory=val_path,
                                            target_size=(512, 512),
                                            interpolation=cv2.INTER_LINEAR,
                                            save_format='bmp',
                                            batch_size=batch_s)

**SAVE THE WEIGHT**
model = create_model()
model_checkpoint = tf.keras.callbacks.ModelCheckpoint('Weight{epoch:02d}.h5', save_best_only=True)

history = model.fit_generator(train_generator,
                    #batch_size=batch_s,
                    steps_per_epoch=4,
                    epochs=3,
                    callbacks=[model_checkpoint],
                    validation_data=val_generator,
                    validation_steps=4,
                    verbose=1,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=True)

Если у кого-то есть предложение по поводу правильной формы ImageDataGenerator пожалуйста, дайте мне знать. Я надеюсь получить ответ немедленно, спасибо заранее.

это то, что я получил

1 Ответ

0 голосов
/ 21 февраля 2020

Из того, что я вижу, вы запускаете это на небольшом графическом процессоре. Попробуйте уменьшить размер партии. Или проблема может быть в генераторе. Попробуйте сначала протестировать его вручную и посмотрите, вернет ли он правильные партии.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...