Я реализовал собственный генератор, который вызывает метод ImageDataGenerator
flow_from_dataframe
для каждого пакета, чтобы создать пакет со случайной выборкой, который имеет по одному экземпляру каждого класса:
def __getitem__(self, idx):
batch_df = pd.DataFrame()
# create a dataframe with one random sample for each class
for class_name in batch_class_names:
rand_row = self.df_dict[class_name].sample(n = 1)
batch_df = batch_df.append(rand_row, ignore_index=True)
# create generator
batch_gen = self.generator.flow_from_dataframe(
dataframe=batch_df, directory=self.directory, x_col="filename",
y_col="brand", classes=self.class_names,
class_mode="categorical",
target_size=self.image_dims, color_mode="rgb", batch_size=batch_df.shape[0], shuffle=self.shuffle)
# return batch
return next(batch_gen)
Это вызывает flow_from_dataframe to
print новую строку для каждого пакета , что портит выходные данные эпох. Вместо того, чтобы печатать новую строку после каждой эпохи (которая имеет 100 пакетов), новая строка печатается после каждой партии:
1/100 [..............................] - ETA: 20:05 - loss: 3.4795 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
2/100 [..............................] - ETA: 11:11 - loss: 3.4328 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
3/100 [..............................] - ETA: 8:10 - loss: 3.4140 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
4/100 [>.............................] - ETA: 6:39 - loss: 3.4309 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
5/100 [>.............................] - ETA: 5:45 - loss: 3.4323 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
6/100 [>.............................] - ETA: 5:08 - loss: 3.4221 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
7/100 [=>............................] - ETA: 4:41 - loss: 3.4188 - tpr_metric: 0.0000e+00Found 9 validated image filenames belonging to 28 classes.
Я пробовал разместить sys.stdout = open(os.devnull, 'w')
перед вызовом flow_from_dataframe
и sys.stdout = sys.__stdout__
сразу после этого, но это остановило всю печать (также печать эпох).
Есть предложения?