Как исправить ошибку «UnsupportedOperationException: слишком много классов» в Python H2O - PullRequest
0 голосов
/ 11 июля 2019

Я работаю с нейронными сетями в Python, используя библиотеку машинного обучения H2O, и получаю странную ошибку Java, когда пытаюсь обучить свою сеть, поскольку H2O - это пакет на основе Java, который может быть включен в Python.

Я запускаю эту программу, используя Python 3.6, и до сих пор у меня не было проблем с ней. Однако теперь, когда я пытаюсь обучить сеть, я получаю следующие сообщения об ошибках

Из того, что я прочитал, это проблема, возникающая при попытке использовать массив в качестве обучающего набора данных, однако это не объясняет, почему ошибка появляется только сейчас. Моя Java была недавно обновлена, и я подозреваю, что это возможно, но я не могу подтвердить, что это причина.

Ниже приведен код Python из моей программы:

import h2o 
import matplotlib
from h2o.estimators.deeplearning import H2ODeepLearningEstimator 
import os
from tkinter import *
from tkinter import filedialog, messagebox

# Used for selecting a CSV file to import into H2O
def file_dialog():
    Tk().withdraw()
    return filedialog.askopenfilename(initialdir=os.environ["HOMEPATH"] + "\\Desktop", 
                                      title="Select a file", 
                                      filetypes=[("CSV files", ("*.csv", "*.xlsx"))])

def main():

    h2o.init(nthreads = 4, max_mem_size = 8)

    csv_file = file_dialog()
    data = h2o.import_file(csv_file)

    # This is used for classifying a variable in the CSV called "Discrepancy" as output
    data['Discrepancy'] = data['Discrepancy'].asfactor()
    data['Discrepancy'].levels()

    # Used for splitting the dataset into training, testing, and validation sets
    splits = data.split_frame(ratios=[0.7, 0.15], seed=1)  

    train = splits[0]
    valid = splits[1]
    test = splits[2]

    y = 'Discrepancy'
    x = list(data.columns)

    # Removes "Discrepancy" from the list of inputs
    x.remove(y)  

    dl_fit3 = H2ODeepLearningEstimator(model_id='dl_fit3', 
                                       epochs=1,
                                       hidden=[60,40,20],
                                       stopping_rounds=0,         # Used for early stopping
                                       stopping_tolerance=0.0001, # Used for early stopping
                                       seed=1)

    # This is the line where the error occurs
    dl_fit3.train(x=x, y=y, training_frame=train, validation_frame=valid)

    dl_perf3 = dl_fit3.model_performance(test)

    # This line may cause the program to pause, but is not related to the error at hand
    dl_perf3.plot()

if __name__ == "__main__":
    os.chdir(os.environ["HOMEPATH"] + "\\Desktop")
    main()

Размер или характер данных, по которым я пытаюсь обучить сеть, не имеет отношения к возникновению ошибки. Единственное различие, которое я заметил, заключается в том, что при использовании большего набора данных для возникновения ошибки в процессе обучения требуется немного больше времени.

Буду признателен за любую помощь в выявлении причин возникновения этой ошибки и способов ее устранения.

...