Я пытаюсь создать нейронную сеть с одним скрытым слоем с нуля, и моя неправильная классификация и частота ошибок, очевидно, неверны.Я не могу найти свою ошибку, и у меня заканчивается время.Кто-нибудь может мне помочь?Я выложу остальную часть кода, если это необходимо.
def trainNN(Xtrain, ttrain, M, K, W1, W2, iterations,eta):
N = len(Xtrain)
Etotal = np.zeros(iterations)
misclassification_rate = np.zeros(iterations)
for i in range(iterations):
dE1_total = np.zeros(W1.shape)
dE2_total = np.zeros(W2.shape)
guesses = []
count = 0
true_counter = 0
cross_entropy = 0
while count < N:
x = Xtrain[count]
true_class = ttrain[count]
y_true = np.zeros(K)
y_true[true_class] = 1.0
count += 1
y,dE1,dE2 = backprop(x,target_y,M,K,W1,W2)
dE1_total = dE1_total+dE1
dE2_total = dE2_total+dE2
cross_entropy = cross_entropy + log_loss(y_true, y)
guesses.append(y_true.argmax())
if y.argmax() == y_true.argmax():
true_counter += 1
W1 = W1 - eta * dE1_total / N
W2 = W2 - eta * dE2_total / N
misclassification_rate[i] = 1 - true_counter / N
Etotal[i] = cross_entropy / N
return W1, W2, Etotal, misclassification_rate, guesses
Это мои результаты, так как вы можете видеть, что процент неправильной классификации только увеличивается, а также ошибка:
(array([[-0.26141859, 0.95582288, 0.57300376, 0.14235966, -0.68768852,
-0.64548708],
[-0.93809538, 1.02883521, 0.79944328, 0.12326179, -0.95742711,
1.18626673],
[ 0.62943067, -0.40277124, -0.28537253, -0.80926262, -0.39064817,
0.17429355],
[-0.15704319, -0.2594925 , 0.53201345, -0.85378359, -0.41510531,
-0.09040425],
[-0.09284687, 0.61627985, -0.50984544, -0.00669883, 0.18498267,
-0.84636893]]), array([[ 2.1677012 , -1.77665969, -2.88314017],
[ 0.94825079, 0.90245525, 0.56396907],
[ 1.46204349, -1.88697652, -1.54105275],
[ 1.5410913 , -1.79412746, -1.71918361],
[-0.83947178, 0.78059373, -0.57768823],
[ 0.32557882, -0.37690341, 0.0395823 ],
[ 1.93819337, -1.71160783, -0.96436068]]), array([2.99467737, 2.99560225, 2.9965253 , 2.99744653, 2.99836593,
2.99928353, 3.00019932, 3.00111332, 3.00202552, 3.00293595,
3.00384459, 3.00475147, 3.00565659, 3.00655995, 3.00746157,
3.00836144, 3.00925958, 3.01015599, 3.01105067, 3.01194365,
3.01283491, 3.01372448, 3.01461234, 3.01549852, 3.01638302,
3.01726584, 3.01814699, 3.01902648, 3.01990431, 3.02078049,
3.02165503, 3.02252792, 3.02339919, 3.02426882, 3.02513684,
3.02600325, 3.02686804, 3.02773123, 3.02859283, 3.02945284,
3.03031126, 3.0311681 , 3.03202337, 3.03287707, 3.03372922,
3.0345798 , 3.03542884, 3.03627633, 3.03712228, 3.0379667 ,
3.03880959, 3.03965096, 3.04049082, 3.04132916, 3.04216599,
3.04300133, 3.04383517, 3.04466752, 3.04549839, 3.04632778,
3.04715569, 3.04798213, 3.04880712, 3.04963064, 3.05045271,
3.05127333, 3.05209251, 3.05291026, 3.05372657, 3.05454145,
3.05535491, 3.05616695, 3.05697758, 3.0577868 , 3.05859461,
3.05940103, 3.06020606, 3.0610097 , 3.06181195, 3.06261283,
3.06341233, 3.06421046, 3.06500722, 3.06580263, 3.06659668,
3.06738938, 3.06818073, 3.06897074, 3.06975942, 3.07054676,
3.07133277, 3.07211746, 3.07290083, 3.07368289, 3.07446363,
3.07524307, 3.07602121, 3.07679805, 3.07757359, 3.07834785,
3.07912083, 3.07989252, 3.08066294, 3.08143208, 3.08219996,
3.08296657, 3.08373193, 3.08449603, 3.08525888, 3.08602048,
3.08678084, 3.08753996, 3.08829785, 3.0890545 , 3.08980993,
3.09056413, 3.09131712, 3.09206889, 3.09281945, 3.09356881,
3.09431696, 3.09506391, 3.09580966, 3.09655423, 3.0972976 ,
3.0980398 , 3.09878081, 3.09952064, 3.1002593 , 3.1009968 ,
3.10173313, 3.10246829, 3.1032023 , 3.10393515, 3.10466685,
3.10539741, 3.10612682, 3.10685509, 3.10758223, 3.10830823,
3.1090331 , 3.10975685, 3.11047947, 3.11120098, 3.11192137,
3.11264065, 3.11335882, 3.11407588, 3.11479184, 3.11550671,
3.11622047, 3.11693315, 3.11764474, 3.11835524, 3.11906467,
3.11977301, 3.12048028, 3.12118647, 3.1218916 , 3.12259566,
3.12329866, 3.12400061, 3.12470149, 3.12540132, 3.12610011,
3.12679784, 3.12749454, 3.12819019, 3.12888481, 3.12957839,
3.13027094, 3.13096246, 3.13165296, 3.13234244, 3.1330309 ,
3.13371834, 3.13440477, 3.13509019, 3.13577461, 3.13645802,
3.13714042, 3.13782184, 3.13850225, 3.13918168, 3.13986011,
3.14053756, 3.14121402, 3.14188951, 3.14256402, 3.14323755,
3.14391011, 3.1445817 , 3.14525232, 3.14592198, 3.14659068,
3.14725842, 3.14792521, 3.14859104, 3.14925592, 3.14991986,
3.15058285, 3.15124489, 3.151906 , 3.15256618, 3.15322541,
3.15388372, 3.1545411 , 3.15519755, 3.15585308, 3.15650769,
3.15716138, 3.15781415, 3.15846601, 3.15911696, 3.15976701,
3.16041614, 3.16106438, 3.16171171, 3.16235814, 3.16300368,
3.16364833, 3.16429209, 3.16493495, 3.16557694, 3.16621803,
3.16685825, 3.16749759, 3.16813606, 3.16877365, 3.16941036,
3.17004621, 3.1706812 , 3.17131532, 3.17194858, 3.17258097,
3.17321252, 3.1738432 , 3.17447304, 3.17510202, 3.17573016,
3.17635745, 3.17698389, 3.1776095 , 3.17823427, 3.1788582 ,
3.1794813 , 3.18010356, 3.18072499, 3.1813456 , 3.18196538,
3.18258434, 3.18320248, 3.1838198 , 3.1844363 , 3.18505199,
3.18566686, 3.18628093, 3.18689418, 3.18750663, 3.18811828,
3.18872913, 3.18933917, 3.18994842, 3.19055687, 3.19116453,
3.1917714 , 3.19237747, 3.19298277, 3.19358727, 3.19419099,
3.19479393, 3.1953961 , 3.19599748, 3.19659809, 3.19719793,
3.19779699, 3.19839529, 3.19899282, 3.19958959, 3.20018559,
3.20078083, 3.20137531, 3.20196903, 3.202562 , 3.20315421,
3.20374567, 3.20433639, 3.20492635, 3.20551557, 3.20610404,
3.20669178, 3.20727877, 3.20786502, 3.20845054, 3.20903532,
3.20961937, 3.21020269, 3.21078528, 3.21136714, 3.21194828,
3.21252869, 3.21310838, 3.21368735, 3.2142656 , 3.21484313,
3.21541995, 3.21599606, 3.21657146, 3.21714614, 3.21772012,
3.21829339, 3.21886596, 3.21943783, 3.22000899, 3.22057946,
3.22114922, 3.22171829, 3.22228667, 3.22285436, 3.22342135,
3.22398766, 3.22455328, 3.22511821, 3.22568246, 3.22624602,
3.22680891, 3.22737112, 3.22793264, 3.2284935 , 3.22905368,
3.22961318, 3.23017202, 3.23073018, 3.23128768, 3.23184451,
3.23240068, 3.23295619, 3.23351103, 3.23406521, 3.23461874,
3.23517161, 3.23572382, 3.23627538, 3.23682629, 3.23737655,
3.23792615, 3.23847512, 3.23902343, 3.2395711 , 3.24011813,
3.24066452, 3.24121026, 3.24175537, 3.24229984, 3.24284368,
3.24338688, 3.24392945, 3.24447139, 3.2450127 , 3.24555338,
3.24609343, 3.24663286, 3.24717167, 3.24770985, 3.24824741,
3.24878436, 3.24932068, 3.24985639, 3.25039148, 3.25092596,
3.25145983, 3.25199309, 3.25252574, 3.25305778, 3.25358921,
3.25412004, 3.25465026, 3.25517988, 3.2557089 , 3.25623732,
3.25676514, 3.25729236, 3.25781899, 3.25834503, 3.25887047,
3.25939532, 3.25991957, 3.26044324, 3.26096633, 3.26148882,
3.26201073, 3.26253206, 3.2630528 , 3.26357296, 3.26409255,
3.26461155, 3.26512998, 3.26564783, 3.2661651 , 3.2666818 ,
3.26719793, 3.26771349, 3.26822848, 3.2687429 , 3.26925676,
3.26977004, 3.27028277, 3.27079493, 3.27130652, 3.27181756,
3.27232803, 3.27283795, 3.27334731, 3.27385612, 3.27436437,
3.27487206, 3.27537921, 3.2758858 , 3.27639184, 3.27689733,
3.27740228, 3.27790668, 3.27841053, 3.27891384, 3.27941661,
3.27991883, 3.28042051, 3.28092166, 3.28142227, 3.28192233,
3.28242187, 3.28292087, 3.28341933, 3.28391726, 3.28441466,
3.28491153, 3.28540787, 3.28590369, 3.28639897, 3.28689373,
3.28738797, 3.28788168, 3.28837487, 3.28886754, 3.28935968,
3.28985131, 3.29034242, 3.29083302, 3.29132309, 3.29181266,
3.29230171, 3.29279024, 3.29327827, 3.29376578, 3.29425279,
3.29473928, 3.29522527, 3.29571075, 3.29619573, 3.2966802 ,
3.29716418, 3.29764764, 3.29813061, 3.29861308, 3.29909505,
3.29957652, 3.30005749, 3.30053797, 3.30101795, 3.30149744,
3.30197644, 3.30245494, 3.30293295, 3.30341048, 3.30388751,
3.30436406, 3.30484012, 3.3053157 , 3.30579078, 3.30626539,
3.30673951, 3.30721316, 3.30768632, 3.308159 , 3.3086312 ,
3.30910292, 3.30957417, 3.31004494, 3.31051523, 3.31098506,
3.3114544 , 3.31192328, 3.31239169, 3.31285962, 3.31332709,
3.31379408, 3.31426061, 3.31472668, 3.31519227, 3.3156574 ]), array([0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
0.7, 0.7, 0.7, 0.7, 0.7, 0.7]), [2, 0, 1, 2, 1, 0, 2, 1, 1, 2, 1, ` 1, 2, 1, 0, 2, 0, 1, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 1,
1, 2, 2, 1, 0, 1, 0, 2, 1, 1, 0, 1, 1, 1, 2, 0, 1, 0, 1, 2, 0, 1,
0, 0, 0, 2, 2, 0, 0, 2, 2, 1, 2, 1, 1, 2, 0, 2, 2, 2, 0, 2, 0, 0,
1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 0, 1, 1, 1, 1, 2, 1, 0,
0, 2, 1, 2, 0, 2, 0, 2, 2, 0, 1, 0, 2, 1, 0, 2, 1, 0, 0, 1, 0])`