Я только начал заниматься машинным обучением с небольшим опытом программирования и хотел создать простую нейронную сеть персептрона, которая научится классифицировать набор данных MNIST di git без использования библиотек машинного обучения. В output_data я поместил первое и последнее 80 операторов print, которые выводит моя программа. Кажется, моя программа учится минимизировать функцию стоимости, но она застревает, угадывая одно и то же di git снова и снова. Я испортил скорость обучения и размеры партий, но они, похоже, ничего не улучшают. Если бы кто-то мог привести меня в правильном направлении, это было бы очень ценно. Спасибо
mnist_recognition
Это 4-уровневая нейронная сеть, которая классифицирует рукописные цифры в наборе данных mnist. Используя только чистые python и numpy, эта программа вычисляет градиентное снижение функции стоимости (∑ (фактическая - цель) ^ 2) по отношению к весам и соответственно изменяет веса. После каждой итерации программа печатает ди git обучающих данных, предположение программы и стоимость, связанную с этой итерацией. Импорт данных обучения MNIST
with np.load('mnist.npz') as data:
training_images = data['training_images']
training_labels = data['training_labels']
Настройка нейронной сети и определение сигмоидальной функции self.mtrx содержит нейроны на каждом уровне собственный вес, смещение, вес удержания града, значения смещения и градиента между уровнями L и L - 1
class NeuralNetwork:
def __init__(self, rows, columns=0):
self.mtrx = np.zeros((rows, 1))
self.weight = np.random.random((rows, columns)) / columns ** .5
self.bias = np.random.random((rows, 1)) * -1.0
self.grad = np.zeros((rows, columns))
def sigmoid(self):
return 1 / (1 + np.exp(-self.mtrx))
def sigmoid_derivative(self):
return self.sigmoid() * (1.0 - self.sigmoid())
Инициализация уровней нейронной сети
lvl_input = NeuralNetwork(784)
lvl_one = NeuralNetwork(200, 784)
lvl_two = NeuralNetwork(200, 200)
lvl_output = NeuralNetwork(10, 200)
Функции прямого и обратного распространения
def forward_prop():
lvl_one.mtrx = lvl_one.weight.dot(lvl_input.mtrx) + lvl_one.bias
lvl_two.mtrx = lvl_two.weight.dot(lvl_one.sigmoid()) + lvl_two.bias
lvl_output.mtrx = lvl_output.weight.dot(lvl_two.sigmoid()) + lvl_output.bias
def back_prop(actual):
val = np.zeros((10, 1))
val[actual] = 1
delta_3 = (lvl_output.sigmoid() - val) * lvl_output.sigmoid_derivative()
delta_2 = np.dot(lvl_output.weight.transpose(), delta_3) * lvl_two.sigmoid_derivative()
delta_1 = np.dot(lvl_two.weight.transpose(), delta_2) * lvl_one.sigmoid_derivative()
lvl_output.grad = lvl_two.sigmoid().transpose() * delta_3
lvl_two.grad = lvl_one.sigmoid().transpose() * delta_2
lvl_one.grad = lvl_input.sigmoid().transpose() * delta_1
Хранение данных mnist в массиве np.array
def make_image(c):
lvl_input.mtrx = training_images[c]
Оценка функции стоимости
def cost(actual):
val = np.zeros((10, 1))
val[actual] = 1
cost_val = (lvl_output.sigmoid() - val) ** 2
return np.sum(cost_val)
Вычитание градиентов из весов и инициализация скорости обучения
learning_rate = .01
def update():
lvl_output.weight -= learning_rate * lvl_output.grad
lvl_two.weight -= learning_rate * lvl_two.grad
lvl_one.weight -= learning_rate * lvl_one.grad
Обучение нейронной сети iter_1 равно количеству партий iter_2 равно числу итераций в одном пакете
iter_1 = 50
iter_2 = 100
for batch_num in range(iter_1):
update()
for batches in range(iter_2):
make_image(counter)
num = np.argmax(training_labels[counter])
counter += 1
forward_prop()
back_prop(num)
print("actual: ", num, " guess: ", np.argmax(lvl_output.mtrx), " cost", cost(num))
output_data
FIRST 80 ITERATIONS:
actual: 5 guess: 3 cost 8.967940654671088
actual: 0 guess: 3 cost 8.96727511953835
actual: 4 guess: 3 cost 8.966336311471029
actual: 1 guess: 3 cost 8.964614419297058
actual: 9 guess: 3 cost 8.969134701891605
actual: 2 guess: 3 cost 8.967053265932318
actual: 1 guess: 3 cost 8.964824848818395
actual: 3 guess: 3 cost 8.966473334609903
actual: 1 guess: 3 cost 8.960864501044062
actual: 4 guess: 3 cost 8.966927097539942
actual: 3 guess: 3 cost 8.96602960141387
actual: 5 guess: 3 cost 8.96457467709148
actual: 3 guess: 3 cost 8.966463452568336
actual: 6 guess: 3 cost 8.967170896271007
actual: 1 guess: 3 cost 8.961504554251428
actual: 7 guess: 3 cost 8.970226265002914
actual: 2 guess: 3 cost 8.966534186296752
actual: 8 guess: 3 cost 8.96806492904598
actual: 6 guess: 3 cost 8.963241663267867
actual: 9 guess: 3 cost 8.967891094208154
actual: 4 guess: 3 cost 8.968165257872185
actual: 0 guess: 3 cost 8.967495671691166
actual: 9 guess: 3 cost 8.967110016262358
actual: 1 guess: 3 cost 8.964392716554022
actual: 1 guess: 3 cost 8.965993742374005
actual: 2 guess: 3 cost 8.967551426336762
actual: 4 guess: 3 cost 8.963912501397779
actual: 3 guess: 3 cost 8.966729854711515
actual: 2 guess: 3 cost 8.967571805901548
actual: 7 guess: 3 cost 8.968031754047926
actual: 3 guess: 3 cost 8.965675580057647
actual: 8 guess: 3 cost 8.968461428875388
actual: 6 guess: 3 cost 8.965166939019545
actual: 9 guess: 3 cost 8.968763204750987
actual: 0 guess: 3 cost 8.967540507250032
actual: 5 guess: 3 cost 8.965545688959857
actual: 6 guess: 3 cost 8.967425028943891
actual: 0 guess: 3 cost 8.967566971035732
actual: 7 guess: 3 cost 8.969754175784066
actual: 6 guess: 3 cost 8.96702539598315
actual: 1 guess: 3 cost 8.96299163011006
actual: 8 guess: 3 cost 8.968175816089042
actual: 7 guess: 3 cost 8.966425056294776
actual: 9 guess: 3 cost 8.96796817338183
actual: 3 guess: 3 cost 8.963755408168
actual: 9 guess: 3 cost 8.96926567423336
actual: 8 guess: 3 cost 8.967543729824387
actual: 5 guess: 3 cost 8.967286499095575
actual: 9 guess: 3 cost 8.9677253773608
actual: 3 guess: 3 cost 8.966335253428326
actual: 3 guess: 3 cost 8.962829459684784
actual: 0 guess: 3 cost 8.966443407799728
actual: 7 guess: 3 cost 8.969485491531145
actual: 4 guess: 3 cost 8.964159055804105
actual: 9 guess: 3 cost 8.968054200103934
actual: 8 guess: 3 cost 8.96719386034473
actual: 0 guess: 3 cost 8.966374739396157
actual: 9 guess: 3 cost 8.9673694447568
actual: 4 guess: 3 cost 8.966879451409914
actual: 1 guess: 3 cost 8.963085409100401
actual: 4 guess: 3 cost 8.96659585831308
actual: 4 guess: 3 cost 8.964656458614465
actual: 6 guess: 3 cost 8.965997487130116
actual: 0 guess: 3 cost 8.966455019673488
actual: 4 guess: 3 cost 8.966295463866858
actual: 5 guess: 3 cost 8.964316168401316
actual: 6 guess: 3 cost 8.965707649845031
actual: 1 guess: 3 cost 8.962325088384468
actual: 0 guess: 3 cost 8.965286965834165
actual: 0 guess: 3 cost 8.966383201903987
actual: 1 guess: 3 cost 8.964628836235496
actual: 7 guess: 3 cost 8.968386161233427
actual: 1 guess: 3 cost 8.959224945536565
actual: 6 guess: 3 cost 8.965609436078736
actual: 3 guess: 3 cost 8.964347604784498
actual: 0 guess: 3 cost 8.96611237840382
actual: 2 guess: 3 cost 8.965381094981499
actual: 1 guess: 3 cost 8.963120677127996
actual: 1 guess: 3 cost 8.96405639510175
actual: 7 guess: 3 cost 8.968933638290096
LAST 80 ITERATIONS:
actual: 5 guess: 7 cost 1.1211637067627063
actual: 6 guess: 7 cost 1.0552531331006683
actual: 6 guess: 7 cost 1.0554754137400155
actual: 1 guess: 7 cost 1.1111335796511572
actual: 7 guess: 7 cost 0.39288606314850105
actual: 0 guess: 7 cost 0.9725346172858359
actual: 8 guess: 7 cost 1.1328622780173858
actual: 7 guess: 7 cost 0.3948858827858138
actual: 9 guess: 7 cost 1.1062621497616252
actual: 6 guess: 7 cost 1.0553258043668636
actual: 0 guess: 7 cost 0.9722536736320497
actual: 9 guess: 7 cost 1.104398056849663
actual: 1 guess: 7 cost 1.1107715879340798
actual: 6 guess: 7 cost 1.0548862811967128
actual: 2 guess: 7 cost 1.0591572140435124
actual: 9 guess: 7 cost 1.1061966300353054
actual: 3 guess: 7 cost 1.0898127633513548
actual: 2 guess: 7 cost 1.060007348683317
actual: 4 guess: 7 cost 1.0866881038452423
actual: 5 guess: 7 cost 1.121417891094558
actual: 5 guess: 7 cost 1.120467850157213
actual: 1 guess: 7 cost 1.108298490272241
actual: 6 guess: 7 cost 1.0553087806403634
actual: 1 guess: 7 cost 1.110088687169048
actual: 7 guess: 7 cost 0.39320440442417653
actual: 7 guess: 7 cost 0.393599379050548
actual: 8 guess: 7 cost 1.132814048747761
actual: 5 guess: 7 cost 1.121939678599679
actual: 9 guess: 7 cost 1.106025451446512
actual: 1 guess: 0 cost 0.9622564124884784
actual: 0 guess: 0 cost 0.8593530649130317
actual: 4 guess: 0 cost 0.9393396290138495
actual: 1 guess: 0 cost 0.9622940421810162
actual: 4 guess: 0 cost 0.9404822496691384
actual: 2 guess: 0 cost 0.9178125638868906
actual: 3 guess: 0 cost 0.9426691285338776
actual: 3 guess: 0 cost 0.942690753734257
actual: 1 guess: 0 cost 0.9623297861049644
actual: 4 guess: 0 cost 0.940371478405245
actual: 1 guess: 0 cost 0.9623031268176
actual: 5 guess: 0 cost 0.9712347696568674
actual: 0 guess: 0 cost 0.8590451041238242
actual: 6 guess: 0 cost 0.9144392962122468
actual: 7 guess: 0 cost 0.8587038959856634
actual: 7 guess: 0 cost 0.8595722910216409
actual: 3 guess: 0 cost 0.9426103963458455
actual: 8 guess: 0 cost 0.9818405405702578
actual: 9 guess: 0 cost 0.8595962707758777
actual: 5 guess: 0 cost 0.9714964202087537
actual: 1 guess: 0 cost 0.9624568499960074
actual: 1 guess: 0 cost 0.9618997440804723
actual: 1 guess: 0 cost 0.9616207680615205
actual: 9 guess: 0 cost 0.8592443630955677
actual: 5 guess: 0 cost 0.9713378210566505
actual: 9 guess: 0 cost 0.8594114362413322
actual: 1 guess: 0 cost 0.9618546105579062
actual: 7 guess: 0 cost 0.8588493542347546
actual: 1 guess: 0 cost 0.9618718521886025
actual: 1 guess: 0 cost 0.9617515805652713
actual: 6 guess: 0 cost 0.9136159964857046
actual: 0 guess: 0 cost 0.8592350961109494
actual: 8 guess: 0 cost 0.9818162514673088
actual: 9 guess: 0 cost 0.8595943721384651
actual: 7 guess: 0 cost 0.8594020275011919
actual: 0 guess: 0 cost 0.8594831625021346
actual: 2 guess: 0 cost 0.9166105269490908
actual: 5 guess: 0 cost 0.9715520240455258
actual: 3 guess: 0 cost 0.9426616801200285
actual: 9 guess: 0 cost 0.859581449204677
actual: 6 guess: 0 cost 0.9142265327401935
actual: 7 guess: 0 cost 0.8588630292850988
actual: 8 guess: 0 cost 0.981801914298837
actual: 1 guess: 0 cost 0.9618275889195599
actual: 0 guess: 0 cost 0.8587978160230371
actual: 7 guess: 0 cost 0.8591755187382748
actual: 3 guess: 0 cost 0.9427927022073027
actual: 2 guess: 0 cost 0.9175924937090146
actual: 1 guess: 0 cost 0.9608604565831466
actual: 2 guess: 0 cost 0.9171708777221477
--- 3.403271198272705 seconds ---