Как поймать любое исключение во время обучения модели в Tensorflow 2 - PullRequest
0 голосов
/ 04 ноября 2019

Я тренирую модель Unet, используя Tensorflow. Если возникает проблема с каким-либо из изображений, которые я передаю модели для обучения, возникает исключение. Иногда это может происходить через час или два после тренировки. Можно ли поймать какие-либо такие исключения в будущем, чтобы моя модель могла перейти к следующему изображению и возобновить обучение? Я попытался добавить блок try/catch к функции process_path, показанной ниже, но это не имеет никакого эффекта ...

def process_path(filePath):
    # catching exceptions here has no effect
    parts = tf.strings.split(filePath, '/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName, '.')
    prefix = tf.convert_to_tensor(maskDir, dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.png", dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2], suffix))
    maskPath = tf.strings.join((prefix, maskFileName), separator='/')

    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.io.read_file(maskPath)
    oneHot = decodeMask(mask)
    img.set_shape([256, 256, 3])
    oneHot.set_shape([256, 256, 10])
    return img, oneHot

trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
batchSize = 32

allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))

trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.shuffle(1000).repeat()
validDataSet = validDataSet.map(process_path)
validDataSet = validDataSet.batch(batchSize)

imageHeight = 256
imageWidth = 256
channels = 3

inputImage = Input((imageHeight, imageWidth, channels), name='img') 
model = baseUnet.get_unet(inputImage, n_filters=16, dropout=0.05, batchnorm=True)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1),
    ModelCheckpoint(outputModel, verbose=1, save_best_only=True, save_weights_only=False)
]

BATCH_SIZE = 32
BUFFER_SIZE = 1000
EPOCHS = 20

stepsPerEpoch = int(trainSize / BATCH_SIZE)
validationSteps = int(validSize / BATCH_SIZE)

model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)

Следующая ссылка показывает аналогичный случай иобъясняет, что « функция Python выполняется только один раз для построения графа функции и попытки, и операторы кроме этого не будут иметь никакого эффекта. » Хотя ссылка показывает, как перебирать набор данных и отлавливать ошибки ...

dataset = ...
iterator = iter(dataset)

while True:
  try:
    elem = next(iterator)
    ...
  except InvalidArgumentError:
    ...
  except StopIteration:
    break

... Однако я ищу способ отловить ошибку во время тренировки. Это возможно?

...