У меня некоторое время была проблема. Я реализовал нейронную сеть в python. Но когда есть тест, он всегда дает один и тот же прогноз с почти одинаковыми выходными активациями нейронов. Я добавил свой код здесь. Стоит отметить одну вещь; я не учел добавлять веса смещения к своим нейронам, может ли такое упущение вызвать такое поведение? Любая помощь в исправлении этой ошибки будет принята с благодарностью. Следует также отметить, что при печати результатов, полученных во время обучения, я вижу, что сеть учится предсказывать класс, но каким-то образом компенсирует каждый раз, когда изучается новый класс.
Используемая мной функция активации:
def sigmoid(x, d=False):
if(d):
try:
ans = exp(x)/(exp(x)+1)**2
except OverflowError:
ans = 0
return ans
else:
try:
ans = 1/(1+exp(-x))
except OverflowError:
ans = 0
return ans
Мой класс нейронов:
class Neuron:
def __init__ (self, layer, posInLayer,fromNeurons=[]):
self.layer = layer
self.posInLayer = posInLayer
self.weights = {}
self.activation = 0
self.input = 0
self.activationFunc = sigmoid
for i in range(len(fromNeurons)):
self.weights[fromNeurons[i]] = 1
def __str__(self, printWeights = False):
string = str(self.layer) + "," + str(self.posInLayer)
if printWeights:
string += "\n"
for neuronPair, weight in self.weights.items():
string += str(neuronPair) + ": " + str(weight) + "\n"
return string
Мой фактический класс нейронных сетей:
class NeuralNetwork:
def __init__(self, hiddenLayers, neurons, learningRate, tolerance, epochs):
self.hiddenLayers = hiddenLayers
self.neurons = neurons
self.learningRate = learningRate
self.tolerance = tolerance
self.epochs = epochs
def BackPropFit(self, X, Y):
layers = list()
amountOfLayers = self.hiddenLayers + 2
for l in range(amountOfLayers):
layers.append(list())
if l == 0:
for n in range(len(X[0])):
layers[l].append(Neuron(l,n))
#if layer is hidden layer
elif l >= 1 and l < amountOfLayers - 1:
for n in range(self.neurons):
layers[l].append(Neuron(l,n,layers[l-1]))
#If layer is output layer
elif l == amountOfLayers - 1:
for n in range(len(set(Y))):
layers[l].append(Neuron(l,n,layers[l-1]))
#return layers
deltas = dict()
for epoch in range(self.epochs):
sum_error = 0
for layer in layers:
for neuron in layer:
for fromNeuron, weight in neuron.weights.items():
neuron.weights[fromNeuron] = random.uniform(0,1)
for x, y in zip(X, Y):
y = [1 if y == list(set(Y))[j] else 0 for j in range(len(set(Y)))]
for i, neuron in enumerate(layers[0]):
neuron.activation = x[i]
for layer in layers[1:]:
for neuron in layer:
neuron.input = numpy.dot(list(neuron.weights.values()),[prevNeuron.activation for prevNeuron in neuron.weights.keys()])
neuron.activation = neuron.activationFunc(x=neuron.input)
# print([neuron.activation for neuron in layers[-1]])
errors = []
for i, neuron in enumerate(layers[-1]):
error = (y[i] - neuron.activation)
deltas[neuron] = neuron.activationFunc(neuron.input,True) * error
errors.append(error**2)
sum_error = sum(errors)
for l in reversed(range(len(layers))[:len(layers)-1]):
for i, neuron in enumerate(layers[l]):
propogatedDeltas = sum([list(nextNeuron.weights.values())[i]*deltas[nextNeuron] for nextNeuron in layers[l+1]])
deltas[neuron] = neuron.activationFunc(neuron.input, True) * propogatedDeltas
for layer in layers:
for neuron in layer:
for fromNeuron, weight in neuron.weights.items():
neuron.weights[fromNeuron] += self.learningRate*fromNeuron.activation * deltas[neuron]
return layers
def predict(self, X, network):
printNetwork(network)
for x in X:
for i, neuron in enumerate(network[0]):
neuron.activation = x[i]
for layer in network[1:]:
for neuron in layer:
neuron.input = numpy.dot(list(neuron.weights.values()),[prevNeuron.activation for prevNeuron in neuron.weights.keys()])
neuron.activation = sigmoid(x=neuron.input)
print("prediction:",[neuron.activation for neuron in network[-1]])
Код, который я использую для проверки:
ann = NeuralNetwork(1,2,5,1,100)
network = ann.BackPropFit(iris["data"], iris["target"])
print(iris["data"][2])
ann.predict(iris["data"], network)
Выход из теста:
prediction: [0.024255859717782637, 0.03532975366223915,
0.9639015160152697]
prediction: [0.024393568293905234, 0.035502898993352856, 0.963720974247098]
prediction: [0.02451004336498253, 0.0356457712280952, 0.9635678802735339]
prediction: [0.024465239588270328, 0.03560214892471153, 0.9636283635781644]
prediction: [0.024279667671360996, 0.03536047196784654, 0.9638704031327368]
prediction: [0.02400901157253274, 0.03502626346845873, 0.9642264950283018]
prediction: [0.024425974738286334, 0.035545109257721266, 0.9636787172791237]
prediction: [0.024262238879338063, 0.03534249519380545, 0.963893818183035]
prediction: [0.02466773635720441, 0.0358540310767315, 0.9633628802689409]
prediction: [0.024375300130734513, 0.03548751946024018, 0.9637459912803531]
prediction: [0.024115317053070864, 0.03515606989292284, 0.9640863566542127]
prediction: [0.024294344150729553, 0.03538910142450513, 0.9638525988901198]
prediction: [0.024485731010517063, 0.035622259176694336, 0.9636007216825342]
prediction: [0.025029059850895416, 0.03628822544159749, 0.9628877952469731]
prediction: [0.024081668913163073, 0.0351045619133855, 0.964129227546738]
prediction: [0.02397027762795376, 0.034973398110216214, 0.9642767902032843]
prediction: [0.02410858011107248, 0.03513967701721254, 0.964094078031946]
prediction: [0.024223021553935794, 0.03528709372577303, 0.9639443960433707]
prediction: [0.02397953427688219, 0.03498903824130283, 0.9642651953628664]
prediction: [0.024160109056276687, 0.03521258067028777, 0.9640275933703838]
prediction: [0.024083521145228434, 0.035121640954813856, 0.9641288851919902]
prediction: [0.02414152531046206, 0.03518780503191443, 0.9640517833240132]
prediction: [0.02468691546062378, 0.03584807849682689, 0.9633335262649361]
prediction: [0.02409092107927921, 0.03512871013250146, 0.9641188530815591]
prediction: [0.02418636614981571, 0.03526329808112867, 0.9639956450220543]
prediction: [0.024267459772099324, 0.03535264321381683, 0.9638874785573647]
prediction: [0.02416396533598955, 0.03521941006562164, 0.9640228132529653]
prediction: [0.024187589958478434, 0.03524744188979909, 0.9639915756651218]
prediction: [0.024233121870813345, 0.03530047524463949, 0.9639312430263501]
prediction: [0.024358045457862187, 0.035470619338355, 0.963769286344032]
prediction: [0.024330330709885894, 0.035434248249062814, 0.9638054013264938]
prediction: [0.024090347539108534, 0.0351221538198101, 0.9641187795033697]
prediction: [0.02416142488344722, 0.03521696853048011, 0.9640262537655548]
prediction: [0.024080888057424747, 0.035110459577033334, 0.9641312274738191]
prediction: [0.02433434663508933, 0.03543362073762344, 0.9637993323512524]
prediction: [0.024422364693548006, 0.03552783306382556, 0.9636816422214397]
prediction: [0.024167823485810685, 0.0352141055733865, 0.9640163114152102]
prediction: [0.02435597960589166, 0.0354591198996296, 0.9637707331737857]
prediction: [0.02471713999260132, 0.035909080167326635, 0.9632972506841793]
prediction: [0.02422862460732356, 0.03529947650703185, 0.9639378020816225]
prediction: [0.024296744817078754, 0.03537541752333326, 0.9638470846053074]
prediction: [0.024730672830243335, 0.035919008540631216, 0.9632785462889063]
prediction: [0.024677519877481423, 0.03585982177571434, 0.9633491577330772]
prediction: [0.02410312577355621, 0.03514062352476337, 0.9641023440537351]
prediction: [0.024029982472530884, 0.035058660710511995, 0.9641998044256463]
prediction: [0.02439361895553987, 0.03550167233172319, 0.9637207250087128]
prediction: [0.024156585374043745, 0.03521301454402581, 0.9640329067413494]
prediction: [0.024503985258918632, 0.0356453912451017, 0.9635768395827203]
prediction: [0.024141300488670046, 0.03518912019761061, 0.9640523047561518]
prediction: [0.024314691441718307, 0.03540376902138404, 0.9638243731160379]
prediction: [0.02372552400352548, 0.0346770160493557, 0.9646002105328491]
prediction: [0.02373031557071889, 0.03468344203505816, 0.9645939625660097]
prediction: [0.023724720057590926, 0.03467602012178222, 0.9646012705001369]
prediction: [0.02376226384556137, 0.03472644853026213, 0.9645523310648307]
prediction: [0.02372971117381163, 0.03468268261434423, 0.9645947578987908]
prediction: [0.023743722802670144, 0.034702227205844635, 0.9645765950953863]
prediction: [0.023728576231065727, 0.03468123332696621, 0.964596248101018]
prediction: [0.023842733395069367, 0.034832658453214274, 0.9644472094143852]
prediction: [0.02373042055268845, 0.03468372805926829, 0.9645938462365139]
prediction: [0.023767606248058563, 0.034733460268436926, 0.9645453447786919]
prediction: [0.02382973165430536, 0.034816631762861164, 0.964464351253853]
prediction: [0.023740148432312955, 0.034696513879326034, 0.9645811252965748]
prediction: [0.02375774841405245, 0.03472039790715519, 0.9645582185071547]
prediction: [0.0237335499058697, 0.034688221136712986, 0.9645898077611403]
prediction: [0.023764972739380295, 0.03472863260852244, 0.9645485944886801]
prediction: [0.02372966125521753, 0.034682434754927445, 0.964594797378473]
prediction: [0.023740777709890696, 0.0346980246575686, 0.9645803992136832]
prediction: [0.023756809123586447, 0.03471960530215498, 0.9645595091907951]
prediction: [0.023735568140569537, 0.03469061541453586, 0.964587131969361]
prediction: [0.023766729026892, 0.034732455121653336, 0.9645465126016353]
prediction: [0.023729658752626222, 0.034682802004743364, 0.9645948531063304]
prediction: [0.02374436650427925, 0.034701901057295176, 0.96457558744992]
prediction: [0.023730001861871606, 0.034683400161220006, 0.9645944252514723]
prediction: [0.023736526954411568, 0.03469248019688715, 0.9645859637381398]
prediction: [0.023735308768573625, 0.03469008610063925, 0.9645874444725212]
prediction: [0.023730799224378245, 0.03468399146054118, 0.9645933178720807]
prediction: [0.023726940311122345, 0.034679043074788314, 0.9645983817836591]
prediction: [0.023724383440550213, 0.03467558263383218, 0.9646017114161664]
prediction: [0.023735486337211342, 0.03469059751542944, 0.964587251628381]
prediction: [0.023779428491116632, 0.03474796479720819, 0.9645297423667241]
prediction: [0.023774644003516383, 0.03474287920981778, 0.9645361676007138]
prediction: [0.02378272128278718, 0.0347535874633472, 0.964525620956143]
prediction: [0.023755898307353942, 0.0347174730356531, 0.9645605677153195]
prediction: [0.023729776988529103, 0.034683290824218574, 0.9645947456858103]
prediction: [0.02374426502368594, 0.03470293459949354, 0.9645758853361956]
prediction: [0.023732590915310697, 0.03468659488641728, 0.9645910100457628]
prediction: [0.02372689136795545, 0.034678890564995644, 0.9645984333054987]
prediction: [0.023737795072963728, 0.0346936230621232, 0.9645842313043707]
prediction: [0.023751641066468036, 0.03471227627900076, 0.9645661906269275]
prediction: [0.023760484060681413, 0.03472404812676005, 0.9645546494204638]
prediction: [0.02375333425919188, 0.03471555320482372, 0.9645641255533826]
prediction: [0.023734172168252583, 0.03468896179055658, 0.9645889831039745]
prediction: [0.0237543590937047, 0.03471567377043775, 0.9645626119455891]
prediction: [0.02383695401105701, 0.03482464780948338, 0.9644547034408611]
prediction: [0.023751684859054063, 0.03471257333161061, 0.9645661672682535]
prediction: [0.023749784887358058, 0.034710056278325796, 0.964568648776404]
prediction: [0.0237481343236713, 0.03470766975525208, 0.9645707762674685]
prediction: [0.023737937030808825, 0.034693724469876064, 0.9645840336146474]
prediction: [0.02383878991268152, 0.03482361378058185, 0.9644518162534279]
prediction: [0.02375069769578546, 0.03471091494544744, 0.964567406943979]
prediction: [0.023720039883220946, 0.03466978245020838, 0.9646073789364076]
prediction: [0.023728767224089068, 0.03468179189591337, 0.9645960418694743]
prediction: [0.02371966243565111, 0.03466923396087005, 0.9646078651467044]
prediction: [0.023723630797320202, 0.03467484040572994, 0.964602730667874]
prediction: [0.023720851330585276, 0.03467090188948923, 0.9646063252176186]
prediction: [0.023718413388652403, 0.0346675339753494, 0.964609490409128]
prediction: [0.023753727650669514, 0.0347160742587419, 0.9645636117480275]
prediction: [0.02371941717623683, 0.03466896214549964, 0.9646081930533433]
prediction: [0.023721934346867198, 0.03467246999519504, 0.9646049293355077]
prediction: [0.023718573070236632, 0.034667710211575886, 0.9646092768127832]
prediction: [0.023723213645058503, 0.03467397869945961, 0.96460323183917]
prediction: [0.023724028333076297, 0.03467522557425642, 0.9646021913406588]
prediction: [0.02372097307207111, 0.03467098817754113, 0.9646061555680777]
prediction: [0.023729883468048434, 0.03468322959264031, 0.9645945779601794]
prediction: [0.023725071894882776, 0.034676504737134156, 0.9646008135202075]
prediction: [0.02372182081003962, 0.03467209628427972, 0.9646050460382083]
prediction: [0.023722984479787424, 0.03467386912008089, 0.9646035586635962]
prediction: [0.02371805951243077, 0.03466703124724607, 0.9646099478897259]
prediction: [0.02371811135131814, 0.03466711147728538, 0.9646098818062666]
prediction: [0.023733035488396666, 0.034687795581113096, 0.964590515932891]
prediction: [0.02371981915998572, 0.03466940854245421, 0.9646076557354973]
prediction: [0.02373074333081823, 0.03468435601126866, 0.9645934529719334]
prediction: [0.023718407261591398, 0.034667533836148796, 0.9646094995425389]
prediction: [0.02372708534539599, 0.034679254100438366, 0.964598195003677]
prediction: [0.023720605725436522, 0.03467053581420134, 0.9646066402943239]
prediction: [0.02371981470218485, 0.03466948787345506, 0.9646076736255174]
prediction: [0.023728197379544364, 0.03468072228605853, 0.9645967416839549]
prediction: [0.023727872953395065, 0.034680380772658885, 0.9645971779662906]
prediction: [0.02372209037755073, 0.03467260263698762, 0.9646047150262003]
prediction: [0.023720711438961664, 0.03467073649647976, 0.9646065107825088]
prediction: [0.02371937065704421, 0.0346688575901199, 0.9646082477456422]
prediction: [0.023718209658743057, 0.034667230503704585, 0.9646097517964425]
prediction: [0.023721742390424204, 0.034672099379889715, 0.9646051636239896]
prediction: [0.023727974793327326, 0.03468075064743645, 0.9645970781984098]
prediction: [0.02372788091576389, 0.03468103323236982, 0.9645972584360801]
prediction: [0.023718473311688415, 0.034667568022497264, 0.9646094057110476]
prediction: [0.023720928259411658, 0.034670947712577053, 0.964606216783534]
prediction: [0.02372327852709839, 0.03467429356197021, 0.964603179489309]
prediction: [0.023729310959277147, 0.03468229643759765, 0.9645953010664292]
prediction: [0.02372085790665672, 0.034670796561876285, 0.9646063004830066]
prediction: [0.023720212878373732, 0.03466993161748542, 0.9646071416213993]
prediction: [0.023720992252433487, 0.034670867582526715, 0.9646061098429933]
prediction: [0.023728767224089068, 0.03468179189591337, 0.9645960418694743]
prediction: [0.023719710964299968, 0.03466929809277085, 0.9646078017300216]
prediction: [0.023719746671022557, 0.03466929909541947, 0.9646077485306644]
prediction: [0.023721364052458862, 0.03467141417485969, 0.9646056318048679]
prediction: [0.02372623302959114, 0.03467811360743693, 0.9645993067511291]
prediction: [0.023723107328123223, 0.03467387400459404, 0.9646033758388185]
prediction: [0.023722005460505638, 0.03467240379137924, 0.9646048137299832]
prediction: [0.02372813999256747, 0.03468094436462871, 0.9645968588469407]