Я работаю с нейронными сетями в Python, используя библиотеку машинного обучения H2O, и получаю странную ошибку Java, когда пытаюсь обучить свою сеть, поскольку H2O - это пакет на основе Java, который может быть включен в Python.
Я запускаю эту программу, используя Python 3.6, и до сих пор у меня не было проблем с ней. Однако теперь, когда я пытаюсь обучить сеть, я получаю следующие сообщения об ошибках
Из того, что я прочитал, это проблема, возникающая при попытке использовать массив в качестве обучающего набора данных, однако это не объясняет, почему ошибка появляется только сейчас. Моя Java была недавно обновлена, и я подозреваю, что это возможно, но я не могу подтвердить, что это причина.
Ниже приведен код Python из моей программы:
import h2o
import matplotlib
from h2o.estimators.deeplearning import H2ODeepLearningEstimator
import os
from tkinter import *
from tkinter import filedialog, messagebox
# Used for selecting a CSV file to import into H2O
def file_dialog():
Tk().withdraw()
return filedialog.askopenfilename(initialdir=os.environ["HOMEPATH"] + "\\Desktop",
title="Select a file",
filetypes=[("CSV files", ("*.csv", "*.xlsx"))])
def main():
h2o.init(nthreads = 4, max_mem_size = 8)
csv_file = file_dialog()
data = h2o.import_file(csv_file)
# This is used for classifying a variable in the CSV called "Discrepancy" as output
data['Discrepancy'] = data['Discrepancy'].asfactor()
data['Discrepancy'].levels()
# Used for splitting the dataset into training, testing, and validation sets
splits = data.split_frame(ratios=[0.7, 0.15], seed=1)
train = splits[0]
valid = splits[1]
test = splits[2]
y = 'Discrepancy'
x = list(data.columns)
# Removes "Discrepancy" from the list of inputs
x.remove(y)
dl_fit3 = H2ODeepLearningEstimator(model_id='dl_fit3',
epochs=1,
hidden=[60,40,20],
stopping_rounds=0, # Used for early stopping
stopping_tolerance=0.0001, # Used for early stopping
seed=1)
# This is the line where the error occurs
dl_fit3.train(x=x, y=y, training_frame=train, validation_frame=valid)
dl_perf3 = dl_fit3.model_performance(test)
# This line may cause the program to pause, but is not related to the error at hand
dl_perf3.plot()
if __name__ == "__main__":
os.chdir(os.environ["HOMEPATH"] + "\\Desktop")
main()
Размер или характер данных, по которым я пытаюсь обучить сеть, не имеет отношения к возникновению ошибки. Единственное различие, которое я заметил, заключается в том, что при использовании большего набора данных для возникновения ошибки в процессе обучения требуется немного больше времени.
Буду признателен за любую помощь в выявлении причин возникновения этой ошибки и способов ее устранения.