Tripleloss от Tensorflow add on дает ошибку изменения формы - PullRequest
0 голосов
/ 14 января 2020

Я получаю эту ошибку и не могу понять это.

 ValueError: Cannot reshape a tensor with 48032 elements to shape [32,1] (32 elements) for 'Reshape' (op: 'Reshape') with input shapes: [32,1501], [2] and with input tensors computed as partial shapes: input[1] = [32,1].

Я пытаюсь использовать функцию tripleloss из библиотеки tenorflow_addons, используя приведенный здесь пример

https://www.tensorflow.org/addons/tutorials/losses_triplet

Я в значительной степени скопировал это и изменил данные. Мой набор данных содержит 1501 различных классов, разделенных на папки для каждого класса. Я использую генератор данных из tf.data.Dataset, который, кажется, тоже работает нормально.

Это то, что у меня есть

BATCH_SIZE = 32
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    shear_range=0,
    rotation_range=20,
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')


ds = tf.data.Dataset.from_generator(generator=train_datagen.flow_from_directory,
                                    args=[train_dir, (224, 224), 'categorical'],
                                    output_types=(tf.float32, tf.float32),
                                    output_shapes=([32, 224,224,3], [32,1501]))


model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(224, 224, 3)),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation=None),  # No activation on final dense layer
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))  # L2 normalize embeddings
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tfa.losses.TripletSemiHardLoss())
history = model.fit(ds, epochs=45, verbose=1, callbacks=None)

Это в значительной степени дословная копия, кроме набора данных.

Должен ли я сделать функцию карты, например, ds.map (функция)?

1 Ответ

1 голос
/ 03 апреля 2020

точно проблема, с которой я столкнулся. Как указано в https://www.tensorflow.org/addons/api_docs/python/tfa/losses/TripletSemiHardLoss «Мы ожидаем, что метки y_true будут представлены как 1-D целочисленный тензор с формой [batch_size] многоцелевых целочисленных меток.»

ImageDataGenerator производит [ тензор batchsize, nclasses], который должен быть предварительно обработан для подачи в TripletSemiHardLoss.

Лично я создал собственную функцию trainig вместо model.fit:

for e in range(EPOCHS):
    print('Epoch', e)
    for b in range(int(STEPS_PER_EPOCH)):
        batch=train_data_gen.next()
        x_batch=batch[0]
        y_batch=np.argmax(batch[1],axis=1)  # <- class labels: y_true: 1-D integer 
        history=model.fit(x_batch, y_batch) # 1 step fit
        print(e,b)

, которая обучает модель, однако в данный момент я борюсь со значением потерь, которое случайным образом составляет от 0 до 1 на каждом шаге. Должно быть, градиенты теряются. Рассматривая это.

Edit1 :

на самом деле, эта вещь работает:

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_data_gen = image_generator.flow_from_directory(directory=str(data_dir),
                                                     batch_size=BATCH_SIZE,
                                                     shuffle=True,
                                                     target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                     classes = list(CLASS_NAMES),
                                                     color_mode='grayscale',
                                                     class_mode='sparse')
model.fit(train_data_gen, epochs=10)

Волхвы c должны были использовать class_mode = ' разреженный '

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...