NumPy - Возможно ли повысить производительность вычислений взаимно рекурсивных массивов? - PullRequest
0 голосов
/ 09 апреля 2020

Рассмотрим следующий модуль -

""" simple nn """
import numpy as np

TRAINING_SET = np.array([[0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
OUTPUT = np.array([[0, 0, 1, 1]]).T

def train(epochs=100000):
    """ train a single layer neural network """
    rng = np.random.default_rng()
    weights = 2 * rng.random((3, 1)) - 1

    for _ in range(epochs):
        layer = 1 /(1 + np.exp(-np.dot(TRAINING_SET, weights)))
        weights += np.dot(TRAINING_SET.T, (OUTPUT - layer) * layer * (1 - layer))
    print("Number of epochs:", epochs)
    print("layer\n", layer)
    print("weights\n", weights)

, поскольку layer и weights являются взаимно рекурсивными, я вынужден вычислять их, используя for l oop. Есть ли более эффективный способ?

1 Ответ

1 голос
/ 09 апреля 2020

Вот подход, который полностью избавляет от for l oop. Обратите внимание, что оптимальное значение weights необязательно должно быть однозначно определено для этой задачи, но оптимальное значение layer может быть восстановлено.

Я предполагаю, что при epochs уходит в бесконечность, weights и layer сходятся к некоторым предельным значениям (как в случае с этим конкретным c примером). Если это так, мы ожидаем, что приращение weights будет равно нулю. Итак, у нас есть следующая система уравнений:

0 == np.dot(TRAINING_SET.T, (OUTPUT - layer) * layer * (1 - layer))
layer == 1 /(1 + np.exp(-np.dot(TRAINING_SET, weights)))

Используя немного алгебры, мы получаем

OUTPUT == 1/(1 + np.exp(-TRAINING_SET @ weights))

Обратите внимание, что это уравнение не может выполняться точно, так как OUTPUT имеет некоторое ноль элементов. Мы можем получить численное приближение следующим образом:

err = 1e-8
w = np.linalg.pinv(TRAINING_SET) @ -np.log(err + 1/(err + OUTPUT)-1)

Сравнение результатов для layer:

ws, lr = train(100000)

err = 1e-8
ws2 = np.linalg.pinv(TRAINING_SET) @ -np.log(err + 1/(err + OUTPUT)-1)
lr2 = 1 /(1 + np.exp(-np.dot(TRAINING_SET, ws2)))

print(lr - lr2)
# [[ 0.00301781]
#  [ 0.00246096]
#  [-0.00200824]
#  [-0.0024629 ]]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...