Как узнать, привязан ли словарь классов, загружаемых в fit_generator, к нужным классам? - PullRequest
1 голос
/ 13 июля 2020

У меня несбалансированный набор данных, поэтому я использовал словарь весов класса, я вычисляю веса класса с помощью этой функции:

from sklearn.utils import class_weight
def create_dict(dirs=['class1/','class2/','class3/','class4'], path  = '/home/trainset/'): 
    all = []
    for i in range(len(dirs)):
        c_path = path + dirs[i]
        l = [i for f in listdir(c_path) if isfile(join(c_path, f))]
        print(len(l))
        print('************************************************************')
        all+= l     
    y  = np.asarray(all)
    print(len(all))

    class_weights = class_weight.compute_class_weight('balanced', np.unique(y), y)
    print(class_weights)
    class_weights = dict(enumerate(class_weights))
    print(class_weights)
    f = open('weights_dict.txt','w')
    f.write(str(class_weights))
    return class_weights


class_weights = create_dict(dirs)

data_generator = ImageDataGenerator(rescale=1.0/255.0,                                  
                               horizontal_flip=True,
                               vertical_flip=True,
                               brightness_range=[0.2,1.0])

train_generator = data_generator.flow_from_directory(
    '/home/trainset/',  
    target_size=(image_size, image_size),
    batch_size=BATCH_SIZE_TRAINING,
    seed = 7)

Затем я передаю его функции fit_generator следующим образом:

fit_history = model.fit_generator(train_generator,
        epochs = NUM_EPOCHS,
        class_weight = class_weights,
        validation_data=validation_generator,
        verbose=2,      
        callbacks = [cb_checkpointer, cb_early_stopper, reduce_lr])

Модели работают хорошо по точности, но я не знаю, правильный ли вес прикреплен к правильному классу, как я могу быть уверен! а как соотнести каждый класс со своим весом?

1008 * PS. Я использую Keras для реализации, и я использую функцию compute_class_weight из библиотеки Sklearn.
...