Python ValueError: n_splits = 3 не может быть больше, чем количество членов в каждом классе - PullRequest
0 голосов
/ 08 мая 2020

Я работаю над проектом распознавания лиц, в котором у меня есть два человека с двумя лицами на каждом

1. personA
    image1.jpg
    image2.jpg


2. personB
    image1.jpg
    image2.jpg

Я пытаюсь обучить модель встраиванию лица из вышеуказанного набора данных, как показано ниже:

params = {"C": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], "gamma": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]}
model = GridSearchCV(SVC(kernel="rbf", gamma="auto", probability=True), params, cv=3, n_jobs=-1)
model.fit(data["embeddings"], labels)

, где длина data["embeddings"] и labels равна 4. data["embeddings'] содержит ndarray встраивания лиц personA, personB

data['embeddings'] = [
                         [0.02331057, -0.01995077, ..], 
                         [-0.00034041,  0.02753334, ..], 
                         [0.02454563, -0.03797123, ...], 
                         [0.10561685, -0.08444008, ...]
                     ]

labels = [0 0 1 1]

Но я получаю сообщение об ошибке ниже model.fit(data["embeddings"], labels):

ValueError: n_splits=3 cannot be greater than the number of members in each class.

Я не могу понять эту ошибку . Кто-нибудь может объяснить мне эту проблему и как я могу ее решить?

1 Ответ

1 голос
/ 08 мая 2020

При внимательном чтении сообщение об ошибке ясно и самоочевидно; он просто сообщает вам, что, поскольку у вас всего два (2) образца для каждого из ваших классов, у вас не может быть перекрестной проверки с 3 сгибами. Для этого потребуется минимум 3 образца для каждого из ваших классов.

Думаю, он должен работать с cv=2 без каких-либо ошибок, но весь ваш подход (т.е. набор данных только с 4 образца) представляется весьма сомнительным.

...