Я пытаюсь создать базовый NN, используя MLP Classifier.
Когда я использую метод mlp.fit
a, получаю следующую ошибку:
ValueError: Неизвестный тип метки: (массив ([
Ниже моего простого кода
df_X_train = df_train[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_train = df_train["Eff_Th"]
df_X_test = df_test[["Pe/Pe_nom","Gas_cons","PthLoad"]]
df_Y_test = df_test["Eff_Th"]
X_train = np.asarray(df_X_train, dtype="float64")
Y_train = np.asarray(df_Y_train, dtype="float64")
X_test = np.asarray(df_X_test, dtype="float64")
Y_test = np.asarray(df_Y_test, dtype="float64")
from sklearn.neural_network import MLPClassifier
mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
mlp.fit(X_train, Y_train)
На самом деле я не понимаю, почему метод fit
не любит типы с плавающей точкой X_train
и Y_train
.
Просто чтобы все было понятно под размерами матрицы:
X_train.shape --> (720, 3)
Y_train.shape --> (720,)
Жаль, что я спросил это правильно, спасибо.
ниже полной ошибки:
> --------------------------------------------------------------------------- ValueError Traceback (most recent call
> last) <ipython-input-6-2efb224ab852> in <module>()
> 2
> 3 mlp = MLPClassifier(hidden_layer_sizes=(100,), verbose=True)
> ----> 4 mlp.fit(X_train, Y_train)
> 5
> 6 #y_pred_train = mlp.predict(X_train)
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in fit(self, X, y)
> 971 """
> 972 return self._fit(X, y, incremental=(self.warm_start and
> --> 973 hasattr(self, "classes_")))
> 974
> 975 @property
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _fit(self, X, y, incremental)
> 329 hidden_layer_sizes)
> 330
> --> 331 X, y = self._validate_input(X, y, incremental)
> 332 n_samples, n_features = X.shape
> 333
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\neural_network\multilayer_perceptron.py
> in _validate_input(self, X, y, incremental)
> 914 if not incremental:
> 915 self._label_binarizer = LabelBinarizer()
> --> 916 self._label_binarizer.fit(y)
> 917 self.classes_ = self._label_binarizer.classes_
> 918 elif self.warm_start:
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\preprocessing\label.py
> in fit(self, y)
> 282
> 283 self.sparse_input_ = sp.issparse(y)
> --> 284 self.classes_ = unique_labels(y)
> 285 return self
> 286
>
> C:\ProgramData\Anaconda3\lib\site-packages\sklearn\utils\multiclass.py
> in unique_labels(*ys)
> 94 _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
> 95 if not _unique_labels:
> ---> 96 raise ValueError("Unknown label type: %s" % repr(ys))
> 97
> 98 ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys))
>
> ValueError: Unknown label type: (array([1. , 0.89534884, 0.58139535, 0.37209302, 0.24418605,
0.15116279, 0.09302326, 0.23255814, 0.34883721, 0.37209302,
0.30232558, 0.23255814, 0.18604651, 0.12790698, 0.08139535,
0.08139535, 0.19767442, 0.27906977, 0.26744186, 0.22093023,
0.1744186 , 0.11627907, 0.06976744, 0.05813953, 0.1744186 ,
0.26744186, 0.34883721, 0.40697674, 0.46511628, 0.45348837,
0.38372093, 0.31395349, 0.26744186, 0.36046512, 0.44186047,
0.48837209, 0.53488372, 0.48837209, 0.40697674, 0.31395349,
0.24418605, 0.1744186 , 0.19767442, 0.29069767, 0.36046512,
0.3255814 , 0.26744186, 0.20930233, 0.13953488, 0.09302326,
0.04651163, 0.09302326, 0.19767442, 0.29069767, 0.26744186,
0.20930233, 0.1627907 , 0.11627907, 0.06976744, 0.03488372,
0.12790698, 0.24418605, 0.31395349, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
0.25581395, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
0.09302326, 0.05813953, 0.04651163, 0.1627907 , 0.26744186,
0.30232558, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
0.05813953, 0.06976744, 0.18604651, 0.27906977, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.10465116, 0.22093023, 0.29069767, 0.26744186, 0.22093023,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.12790698,
0.24418605, 0.30232558, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.03488372, 0.15116279, 0.26744186,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.09302326,
0.05813953, 0.09302326, 0.20930233, 0.29069767, 0.26744186,
0.22093023, 0.1627907 , 0.11627907, 0.06976744, 0.02325581,
0.12790698, 0.23255814, 0.31395349, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.13953488,
0.25581395, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.02325581, 0.11627907, 0.22093023,
0.29069767, 0.24418605, 0.19767442, 0.13953488, 0.09302326,
0.04651163, 0.02325581, 0.10465116, 0.20930233, 0.30232558,
0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.05813953,
0.03488372, 0.13953488, 0.24418605, 0.31395349, 0.25581395,
0.20930233, 0.15116279, 0.10465116, 0.15116279, 0.26744186,
0.3372093 , 0.36046512, 0.30232558, 0.24418605, 0.19767442,
0.1744186 , 0.25581395, 0.3255814 , 0.38372093, 0.41860465,
0.34883721, 0.29069767, 0.23255814, 0.1627907 , 0.1744186 ,
0.27906977, 0.34883721, 0.3255814 , 0.26744186, 0.20930233,
0.15116279, 0.09302326, 0.04651163, 0.10465116, 0.22093023,
0.30232558, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
0.05813953, 0.02325581, 0.12790698, 0.24418605, 0.30232558,
0.25581395, 0.20930233, 0.15116279, 0.10465116, 0.1627907 ,
0.26744186, 0.37209302, 0.45348837, 0.51162791, 0.55813953,
0.59302326, 0.62790698, 0.56976744, 0.48837209, 0.40697674,
0.36046512, 0.43023256, 0.47674419, 0.48837209, 0.39534884,
0.30232558, 0.23255814, 0.1627907 , 0.10465116, 0.19767442,
0.29069767, 0.31395349, 0.25581395, 0.20930233, 0.15116279,
0.10465116, 0.05813953, 0.02325581, 0.03488372, 0.15116279,
0.25581395, 0.25581395, 0.20930233, 0.15116279, 0.10465116,
0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.26744186,
0.25581395, 0.20930233, 0.1627907 , 0.11627907, 0.06976744,
0.03488372, 0. , 0.10465116, 0.20930233, 0.27906977,
0.22093023, 0.1744186 , 0.12790698, 0.08139535, 0.08139535,
0.19767442, 0.29069767, 0.36046512, 0.43023256, 0.48837209,
0.53488372, 0.56976744, 0.60465116, 0.52325581, 0.45348837,
0.38372093, 0.45348837, 0.51162791, 0.54651163, 0.54651163,
0.44186047, 0.36046512, 0.27906977, 0.20930233, 0.1744186 ,
0.25581395, 0.3372093 , 0.3372093 , 0.27906977, 0.22093023,
0.1627907 , 0.10465116, 0.05813953, 0.06976744, 0.18604651,
0.27906977, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.03488372, 0.10465116, 0.22093023, 0.30232558,
0.27906977, 0.22093023, 0.1744186 , 0.11627907, 0.19767442,
0.29069767, 0.36046512, 0.40697674, 0.34883721, 0.29069767,
0.23255814, 0.1744186 , 0.20930233, 0.30232558, 0.36046512,
0.34883721, 0.29069767, 0.23255814, 0.1744186 , 0.11627907,
0.06976744, 0.11627907, 0.22093023, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.12790698,
0.24418605, 0.3255814 , 0.27906977, 0.23255814, 0.1744186 ,
0.12790698, 0.08139535, 0.03488372, 0. , 0.11627907,
0.22093023, 0.27906977, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.04651163, 0.02325581, 0.11627907, 0.23255814,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.08139535, 0.19767442, 0.29069767, 0.29069767,
0.23255814, 0.18604651, 0.13953488, 0.08139535, 0.04651163,
0.06976744, 0.18604651, 0.27906977, 0.27906977, 0.23255814,
0.1744186 , 0.12790698, 0.08139535, 0.04651163, 0.12790698,
0.24418605, 0.3255814 , 0.27906977, 0.22093023, 0.1744186 ,
0.11627907, 0.06976744, 0.03488372, 0.13953488, 0.24418605,
0.30232558, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.02325581, 0.13953488, 0.24418605, 0.26744186,
0.22093023, 0.1744186 , 0.12790698, 0.06976744, 0.03488372,
0.08139535, 0.19767442, 0.27906977, 0.29069767, 0.24418605,
0.19767442, 0.13953488, 0.09302326, 0.11627907, 0.23255814,
0.3255814 , 0.30232558, 0.25581395, 0.19767442, 0.15116279,
0.09302326, 0.04651163, 0.08139535, 0.19767442, 0.27906977,
0.31395349, 0.25581395, 0.19767442, 0.15116279, 0.10465116,
0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.03488372, 0.15116279, 0.25581395, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.01162791,
0.12790698, 0.23255814, 0.31395349, 0.29069767, 0.24418605,
0.18604651, 0.13953488, 0.09302326, 0.05813953, 0.1744186 ,
0.27906977, 0.34883721, 0.29069767, 0.23255814, 0.1744186 ,
0.11627907, 0.06976744, 0.09302326, 0.19767442, 0.30232558,
0.31395349, 0.26744186, 0.20930233, 0.15116279, 0.10465116,
0.05813953, 0.09302326, 0.20930233, 0.30232558, 0.27906977,
0.23255814, 0.1744186 , 0.12790698, 0.08139535, 0.03488372,
0.08139535, 0.20930233, 0.29069767, 0.26744186, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.09302326,
0.20930233, 0.27906977, 0.23255814, 0.18604651, 0.13953488,
0.09302326, 0.04651163, 0.05813953, 0.18604651, 0.26744186,
0.3372093 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.1744186 , 0.27906977, 0.34883721, 0.30232558,
0.24418605, 0.18604651, 0.13953488, 0.08139535, 0.03488372,
0.04651163, 0.1627907 , 0.26744186, 0.26744186, 0.22093023,
0.1627907 , 0.11627907, 0.06976744, 0.03488372, 0.03488372,
0.15116279, 0.25581395, 0.27906977, 0.22093023, 0.1744186 ,
0.12790698, 0.08139535, 0.03488372, 0.01162791, 0.12790698,
0.23255814, 0.29069767, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.05813953, 0.05813953, 0.1744186 , 0.27906977,
0.29069767, 0.24418605, 0.18604651, 0.13953488, 0.09302326,
0.11627907, 0.23255814, 0.30232558, 0.34883721, 0.29069767,
0.24418605, 0.18604651, 0.12790698, 0.15116279, 0.25581395,
0.3255814 , 0.30232558, 0.24418605, 0.19767442, 0.13953488,
0.09302326, 0.12790698, 0.22093023, 0.30232558, 0.25581395,
0.20930233, 0.1627907 , 0.11627907, 0.05813953, 0.02325581,
0.05813953, 0.1744186 , 0.26744186, 0.22093023, 0.1744186 ,
0.12790698, 0.08139535, 0.04651163, 0.01162791, 0.11627907,
0.22093023, 0.25581395, 0.22093023, 0.1744186 , 0.12790698,
0.08139535, 0.03488372, 0.08139535, 0.19767442, 0.27906977,
0.34883721, 0.29069767, 0.24418605, 0.18604651, 0.13953488,
0.10465116, 0.22093023, 0.30232558, 0.3255814 , 0.27906977,
0.22093023, 0.1627907 , 0.10465116, 0.05813953, 0.02325581,
0.12790698, 0.24418605, 0.29069767, 0.24418605, 0.19767442,
0.13953488, 0.09302326, 0.05813953, 0.02325581, 0.10465116,
0.22093023, 0.30232558, 0.24418605, 0.19767442, 0.15116279,
0.09302326, 0.05813953, 0.02325581, 0.06976744, 0.18604651,
0.27906977, 0.25581395, 0.20930233, 0.1627907 , 0.10465116,
0.06976744, 0.03488372, 0.04651163, 0.1627907 , 0.25581395,
0.3255814 , 0.38372093, 0.44186047, 0.41860465, 0.34883721,
0.29069767, 0.24418605, 0.25581395, 0.34883721, 0.41860465,
0.46511628, 0.5 , 0.51162791, 0.41860465, 0.3372093 ,
0.26744186, 0.20930233, 0.20930233, 0.30232558, 0.37209302,
0.36046512, 0.29069767, 0.22093023, 0.15116279, 0.10465116,
0.09302326, 0.19767442, 0.27906977, 0.25581395, 0.20930233,
0.1627907 , 0.11627907, 0.06976744, 0.02325581, 0.08139535,
0.19767442, 0.26744186, 0.22093023, 0.1744186 , 0.13953488,
0.09302326, 0.04651163, 0.02325581, 0.13953488, 0.24418605,
0.26744186, 0.22093023, 0.1744186 , 0.12790698, 0.08139535,
0.1744186 , 0.26744186, 0.34883721, 0.40697674, 0.46511628,
0.41860465, 0.34883721, 0.27906977, 0.22093023, 0.18604651,
0.27906977, 0.34883721, 0.37209302, 0.30232558, 0.24418605,
0.1744186 , 0.11627907, 0.06976744, 0.03488372, 0.15116279]),)