Я моделирую нейронную сеть с использованием Keras и пытаюсь оценить ее с помощью графика acc
и val_acc
.У меня есть 3 ошибки в следующих строках кода:
- В
print(history.keys())
Ошибка function' object has not attribute 'keys'
- В
y_pred = classifier.predict(X_test)
Ошибка name 'classifier' is not defined
- In
plt.plot(history.history['acc'])
Ошибка: 'History' object is not subscriptable
Я также пытаюсь построить график кривой ROC, как я могу это сделать?
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import keras
from keras.models import Sequential
from keras.layers import Dense
from sklearn import cross_validation
from matplotlib import pyplot
from keras.utils import plot_model
dataset = pd.read_csv('Data_BP.csv')
X = dataset.iloc[:, 0:11].values
y = dataset.iloc[:, -1].values
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size = 0.2, random_state = 0)
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
def Model():
classifier = Sequential()
classifier.add(Dense(units = 12, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11))
classifier.add(Dense(units = 8, kernel_initializer = 'uniform', activation = 'relu'))
classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid'))
classifier.compile(optimizer = 'adam', loss = 'mean_squared_error', metrics = ['mse', 'acc'])
return classifier
classifier = Model()
history = classifier.fit(X_train, y_train, validation_split=0.25, batch_size = 10, epochs = 5)
print('\n', history.history.keys())
y_pred = classifier.predict(X_test)
y_pred = (y_pred > 0.5)
from sklearn.metrics import recall_score, classification_report, auc, roc_curve
cm = confusion_matrix(y_test, y_pred)
print(cm)
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
Какие функции должныбыть добавленным?