4-слойная нейронная сеть персептрона, классифицирующая набор данных MNIST di git, работает неправильно - PullRequest
0 голосов
/ 17 марта 2020

Я только начал заниматься машинным обучением с небольшим опытом программирования и хотел создать простую нейронную сеть персептрона, которая научится классифицировать набор данных MNIST di git без использования библиотек машинного обучения. В output_data я поместил первое и последнее 80 операторов print, которые выводит моя программа. Кажется, моя программа учится минимизировать функцию стоимости, но она застревает, угадывая одно и то же di git снова и снова. Я испортил скорость обучения и размеры партий, но они, похоже, ничего не улучшают. Если бы кто-то мог привести меня в правильном направлении, это было бы очень ценно. Спасибо

mnist_recognition

Это 4-уровневая нейронная сеть, которая классифицирует рукописные цифры в наборе данных mnist. Используя только чистые python и numpy, эта программа вычисляет градиентное снижение функции стоимости (∑ (фактическая - цель) ^ 2) по отношению к весам и соответственно изменяет веса. После каждой итерации программа печатает ди git обучающих данных, предположение программы и стоимость, связанную с этой итерацией. Импорт данных обучения MNIST

with np.load('mnist.npz') as data:
    training_images = data['training_images']
    training_labels = data['training_labels']

Настройка нейронной сети и определение сигмоидальной функции self.mtrx содержит нейроны на каждом уровне собственный вес, смещение, вес удержания града, значения смещения и градиента между уровнями L и L - 1

class NeuralNetwork:

    def __init__(self, rows, columns=0):
        self.mtrx = np.zeros((rows, 1))
        self.weight = np.random.random((rows, columns)) / columns ** .5
        self.bias = np.random.random((rows, 1)) * -1.0
        self.grad = np.zeros((rows, columns))

    def sigmoid(self):
        return 1 / (1 + np.exp(-self.mtrx))

    def sigmoid_derivative(self):
        return self.sigmoid() * (1.0 - self.sigmoid())

Инициализация уровней нейронной сети

lvl_input = NeuralNetwork(784)
lvl_one = NeuralNetwork(200, 784)
lvl_two = NeuralNetwork(200, 200)
lvl_output = NeuralNetwork(10, 200)

Функции прямого и обратного распространения

def forward_prop():
    lvl_one.mtrx = lvl_one.weight.dot(lvl_input.mtrx) + lvl_one.bias
    lvl_two.mtrx = lvl_two.weight.dot(lvl_one.sigmoid()) + lvl_two.bias
    lvl_output.mtrx = lvl_output.weight.dot(lvl_two.sigmoid()) + lvl_output.bias


def back_prop(actual):
    val = np.zeros((10, 1))
    val[actual] = 1

    delta_3 = (lvl_output.sigmoid() - val) * lvl_output.sigmoid_derivative()
    delta_2 = np.dot(lvl_output.weight.transpose(), delta_3) * lvl_two.sigmoid_derivative()
    delta_1 = np.dot(lvl_two.weight.transpose(), delta_2) * lvl_one.sigmoid_derivative()

    lvl_output.grad = lvl_two.sigmoid().transpose() * delta_3
    lvl_two.grad = lvl_one.sigmoid().transpose() * delta_2
    lvl_one.grad = lvl_input.sigmoid().transpose() * delta_1

Хранение данных mnist в массиве np.array

def make_image(c): 
    lvl_input.mtrx = training_images[c]

Оценка функции стоимости

def cost(actual):
    val = np.zeros((10, 1))
    val[actual] = 1
    cost_val = (lvl_output.sigmoid() - val) ** 2
    return np.sum(cost_val)

Вычитание градиентов из весов и инициализация скорости обучения

learning_rate = .01

def update():
    lvl_output.weight -= learning_rate * lvl_output.grad
    lvl_two.weight -= learning_rate * lvl_two.grad
    lvl_one.weight -= learning_rate * lvl_one.grad

Обучение нейронной сети iter_1 равно количеству партий iter_2 равно числу итераций в одном пакете

iter_1 = 50
iter_2 = 100

for batch_num in range(iter_1):
    update()
    for batches in range(iter_2):
        make_image(counter)
        num = np.argmax(training_labels[counter])
        counter += 1
        forward_prop()
        back_prop(num)
        print("actual: ", num, "     guess: ", np.argmax(lvl_output.mtrx), "     cost", cost(num))

output_data

FIRST 80 ITERATIONS:

actual:  5      guess:  3      cost 8.967940654671088
actual:  0      guess:  3      cost 8.96727511953835
actual:  4      guess:  3      cost 8.966336311471029
actual:  1      guess:  3      cost 8.964614419297058
actual:  9      guess:  3      cost 8.969134701891605
actual:  2      guess:  3      cost 8.967053265932318
actual:  1      guess:  3      cost 8.964824848818395
actual:  3      guess:  3      cost 8.966473334609903
actual:  1      guess:  3      cost 8.960864501044062
actual:  4      guess:  3      cost 8.966927097539942
actual:  3      guess:  3      cost 8.96602960141387
actual:  5      guess:  3      cost 8.96457467709148
actual:  3      guess:  3      cost 8.966463452568336
actual:  6      guess:  3      cost 8.967170896271007
actual:  1      guess:  3      cost 8.961504554251428
actual:  7      guess:  3      cost 8.970226265002914
actual:  2      guess:  3      cost 8.966534186296752
actual:  8      guess:  3      cost 8.96806492904598
actual:  6      guess:  3      cost 8.963241663267867
actual:  9      guess:  3      cost 8.967891094208154
actual:  4      guess:  3      cost 8.968165257872185
actual:  0      guess:  3      cost 8.967495671691166
actual:  9      guess:  3      cost 8.967110016262358
actual:  1      guess:  3      cost 8.964392716554022
actual:  1      guess:  3      cost 8.965993742374005
actual:  2      guess:  3      cost 8.967551426336762
actual:  4      guess:  3      cost 8.963912501397779
actual:  3      guess:  3      cost 8.966729854711515
actual:  2      guess:  3      cost 8.967571805901548
actual:  7      guess:  3      cost 8.968031754047926
actual:  3      guess:  3      cost 8.965675580057647
actual:  8      guess:  3      cost 8.968461428875388
actual:  6      guess:  3      cost 8.965166939019545
actual:  9      guess:  3      cost 8.968763204750987
actual:  0      guess:  3      cost 8.967540507250032
actual:  5      guess:  3      cost 8.965545688959857
actual:  6      guess:  3      cost 8.967425028943891
actual:  0      guess:  3      cost 8.967566971035732
actual:  7      guess:  3      cost 8.969754175784066
actual:  6      guess:  3      cost 8.96702539598315
actual:  1      guess:  3      cost 8.96299163011006
actual:  8      guess:  3      cost 8.968175816089042
actual:  7      guess:  3      cost 8.966425056294776
actual:  9      guess:  3      cost 8.96796817338183
actual:  3      guess:  3      cost 8.963755408168
actual:  9      guess:  3      cost 8.96926567423336
actual:  8      guess:  3      cost 8.967543729824387
actual:  5      guess:  3      cost 8.967286499095575
actual:  9      guess:  3      cost 8.9677253773608
actual:  3      guess:  3      cost 8.966335253428326
actual:  3      guess:  3      cost 8.962829459684784
actual:  0      guess:  3      cost 8.966443407799728
actual:  7      guess:  3      cost 8.969485491531145
actual:  4      guess:  3      cost 8.964159055804105
actual:  9      guess:  3      cost 8.968054200103934
actual:  8      guess:  3      cost 8.96719386034473
actual:  0      guess:  3      cost 8.966374739396157
actual:  9      guess:  3      cost 8.9673694447568
actual:  4      guess:  3      cost 8.966879451409914
actual:  1      guess:  3      cost 8.963085409100401
actual:  4      guess:  3      cost 8.96659585831308
actual:  4      guess:  3      cost 8.964656458614465
actual:  6      guess:  3      cost 8.965997487130116
actual:  0      guess:  3      cost 8.966455019673488
actual:  4      guess:  3      cost 8.966295463866858
actual:  5      guess:  3      cost 8.964316168401316
actual:  6      guess:  3      cost 8.965707649845031
actual:  1      guess:  3      cost 8.962325088384468
actual:  0      guess:  3      cost 8.965286965834165
actual:  0      guess:  3      cost 8.966383201903987
actual:  1      guess:  3      cost 8.964628836235496
actual:  7      guess:  3      cost 8.968386161233427
actual:  1      guess:  3      cost 8.959224945536565
actual:  6      guess:  3      cost 8.965609436078736
actual:  3      guess:  3      cost 8.964347604784498
actual:  0      guess:  3      cost 8.96611237840382
actual:  2      guess:  3      cost 8.965381094981499
actual:  1      guess:  3      cost 8.963120677127996
actual:  1      guess:  3      cost 8.96405639510175
actual:  7      guess:  3      cost 8.968933638290096

LAST 80 ITERATIONS:

actual:  5      guess:  7      cost 1.1211637067627063
actual:  6      guess:  7      cost 1.0552531331006683
actual:  6      guess:  7      cost 1.0554754137400155
actual:  1      guess:  7      cost 1.1111335796511572
actual:  7      guess:  7      cost 0.39288606314850105
actual:  0      guess:  7      cost 0.9725346172858359
actual:  8      guess:  7      cost 1.1328622780173858
actual:  7      guess:  7      cost 0.3948858827858138
actual:  9      guess:  7      cost 1.1062621497616252
actual:  6      guess:  7      cost 1.0553258043668636
actual:  0      guess:  7      cost 0.9722536736320497
actual:  9      guess:  7      cost 1.104398056849663
actual:  1      guess:  7      cost 1.1107715879340798
actual:  6      guess:  7      cost 1.0548862811967128
actual:  2      guess:  7      cost 1.0591572140435124
actual:  9      guess:  7      cost 1.1061966300353054
actual:  3      guess:  7      cost 1.0898127633513548
actual:  2      guess:  7      cost 1.060007348683317
actual:  4      guess:  7      cost 1.0866881038452423
actual:  5      guess:  7      cost 1.121417891094558
actual:  5      guess:  7      cost 1.120467850157213
actual:  1      guess:  7      cost 1.108298490272241
actual:  6      guess:  7      cost 1.0553087806403634
actual:  1      guess:  7      cost 1.110088687169048
actual:  7      guess:  7      cost 0.39320440442417653
actual:  7      guess:  7      cost 0.393599379050548
actual:  8      guess:  7      cost 1.132814048747761
actual:  5      guess:  7      cost 1.121939678599679
actual:  9      guess:  7      cost 1.106025451446512
actual:  1      guess:  0      cost 0.9622564124884784
actual:  0      guess:  0      cost 0.8593530649130317
actual:  4      guess:  0      cost 0.9393396290138495
actual:  1      guess:  0      cost 0.9622940421810162
actual:  4      guess:  0      cost 0.9404822496691384
actual:  2      guess:  0      cost 0.9178125638868906
actual:  3      guess:  0      cost 0.9426691285338776
actual:  3      guess:  0      cost 0.942690753734257
actual:  1      guess:  0      cost 0.9623297861049644
actual:  4      guess:  0      cost 0.940371478405245
actual:  1      guess:  0      cost 0.9623031268176
actual:  5      guess:  0      cost 0.9712347696568674
actual:  0      guess:  0      cost 0.8590451041238242
actual:  6      guess:  0      cost 0.9144392962122468
actual:  7      guess:  0      cost 0.8587038959856634
actual:  7      guess:  0      cost 0.8595722910216409
actual:  3      guess:  0      cost 0.9426103963458455
actual:  8      guess:  0      cost 0.9818405405702578
actual:  9      guess:  0      cost 0.8595962707758777
actual:  5      guess:  0      cost 0.9714964202087537
actual:  1      guess:  0      cost 0.9624568499960074
actual:  1      guess:  0      cost 0.9618997440804723
actual:  1      guess:  0      cost 0.9616207680615205
actual:  9      guess:  0      cost 0.8592443630955677
actual:  5      guess:  0      cost 0.9713378210566505
actual:  9      guess:  0      cost 0.8594114362413322
actual:  1      guess:  0      cost 0.9618546105579062
actual:  7      guess:  0      cost 0.8588493542347546
actual:  1      guess:  0      cost 0.9618718521886025
actual:  1      guess:  0      cost 0.9617515805652713
actual:  6      guess:  0      cost 0.9136159964857046
actual:  0      guess:  0      cost 0.8592350961109494
actual:  8      guess:  0      cost 0.9818162514673088
actual:  9      guess:  0      cost 0.8595943721384651
actual:  7      guess:  0      cost 0.8594020275011919
actual:  0      guess:  0      cost 0.8594831625021346
actual:  2      guess:  0      cost 0.9166105269490908
actual:  5      guess:  0      cost 0.9715520240455258
actual:  3      guess:  0      cost 0.9426616801200285
actual:  9      guess:  0      cost 0.859581449204677
actual:  6      guess:  0      cost 0.9142265327401935
actual:  7      guess:  0      cost 0.8588630292850988
actual:  8      guess:  0      cost 0.981801914298837
actual:  1      guess:  0      cost 0.9618275889195599
actual:  0      guess:  0      cost 0.8587978160230371
actual:  7      guess:  0      cost 0.8591755187382748
actual:  3      guess:  0      cost 0.9427927022073027
actual:  2      guess:  0      cost 0.9175924937090146
actual:  1      guess:  0      cost 0.9608604565831466
actual:  2      guess:  0      cost 0.9171708777221477
--- 3.403271198272705 seconds ---
...