Недавно я пытался понять, как работают нейронные сети.Я нашел библиотеку JavaScript для одного нейрона / персептрона, которую можно найти здесь .Он использует библиотеку JavaScript p5.js для реализации графики.Есть 3 файла, но важный - perceptron.js
, который содержит следующий код.
// Daniel Shiffman
// The Nature of Code
// http://natureofcode.com
// Simple Perceptron Example
// See: http://en.wikipedia.org/wiki/Perceptron
// Perceptron Class
// Perceptron is created with n weights and learning constant
class Perceptron {
constructor(n, c) {
// Array of weights for inputs
this.weights = new Array(n);
// Start with random weights
for (let i = 0; i < this.weights.length; i++) {
this.weights[i] = random(-1, 1);
}
this.c = c; // learning rate/constant
}
// Function to train the Perceptron
// Weights are adjusted based on "desired" answer
train(inputs, desired) {
// Guess the result
let guess = this.feedforward(inputs);
// Compute the factor for changing the weight based on the error
// Error = desired output - guessed output
// Note this can only be 0, -2, or 2
// Multiply by learning constant
let error = desired - guess;
// Adjust weights based on weightChange * input
for (let i = 0; i < this.weights.length; i++) {
this.weights[i] += this.c * error * inputs[i];
}
}
// Guess -1 or 1 based on input values
feedforward(inputs) {
// Sum all values
let sum = 0;
for (let i = 0; i < this.weights.length; i++) {
sum += inputs[i] * this.weights[i];
}
// Result is sign of the sum, -1 or 1
return this.activate(sum);
}
activate(sum) {
if (sum > 0) return 1;
else return -1;
}
// Return weights
getWeights() {
return this.weights;
}
}
Я написал класс C #, пытающийся работать в основном так же, вот код.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace NeuralNetwork
{
class Neuron
{
Random random = new Random();
//Helper functions
private double getRandomDouble(double minimum, double maximum) {
return random.NextDouble() * (maximum - minimum) + minimum;
}
private double getWeightedSum() {
double sum = 0;
//Sums all input values with their matching weights
for (int i = 0; i < inputs.Count; i++) {
double inputSum = inputs[i] * weights[i];
sum += inputSum;
}
return sum;
}
private double getActivationValue(double input) {
//Decides what output the neuron should give
//The current activation function gets the sign of the input
//Change this code to have the neuron run on a different activation function
if (input > 0) {
return 1;
}
return -1;
}
public double getError(double prediction, double expectedOutput) {
//Finds error of prediction
double error = expectedOutput - prediction;
return error;
}
public List<double> getWeights() {
return weights;
}
List<double> inputs = new List<double>();
List<double> weights = new List<double>();
double learningRate = 0.01;
public Neuron(int inputs, double learningRate = 0.01) {
Random rng = new Random();
//Set random weights for inputs between -1 and 1
for (int i = 0; i < inputs; i++) {
double weight = getRandomDouble(-1, 1);
weights.Add(weight);
}
}
public double predict(List<double> neuronInputs) {
//Sets the global input variable so the helper functions can access the inputs
inputs = neuronInputs;
//Gets the weighted sum of the inputs
double weightedSum = getWeightedSum();
//Finds the output of the neuron based on the activation function
//The output value is a double so the neuron supports returning non-boolean values
double output = getActivationValue(weightedSum);
return output;
}
public double train(List<double> trainingInputs, double expectedOutput) {
//Predicts the output given the inputs
double output = predict(inputs);
//Checks how accurate the prediction is
double error = getError(output, expectedOutput);
for (int i = 0; i < weights.Count; i++) {
//Train the neural network by tweaking its weights
double modification = trainingInputs[i] * error * learningRate;
weights[i] += modification;
}
return error;
}
}
}
В настоящее время я обучаю его реализации функции ИЛИ, но когда я запускаю программу, средняя ошибка (из 100 поездов) составляет 1,5 и остается такой же навсегда.Глядя на веса, я вижу, что они просто продолжают расти и, насколько я могу судить, стремятся к бесконечности.Тем не менее, я попытался сделать то же самое с версией JavaScript, и она отлично работает, при этом средняя ошибка снижается до 0,04 в течение 3 поколений.Я просматриваю код уже несколько дней и не могу найти ничего, что могло бы привести к другому результату.Кто-нибудь видит, почему эти классы дают различный вывод?