Загрузка весов Keras в нейронной сети MATLAB дает другие прогнозы, чем функция Keras model.predict () - PullRequest
0 голосов
/ 22 января 2020

Я обучил простую нейронную сеть с использованием Keras и хочу использовать весовые коэффициенты для разработки нейронной сети с прямой связью в MATLAB (у меня нет набора инструментов Deep Leaning Toolbox, и мне нужно использовать некоторые функции обработки изображений из MATLAB). Предсказания, сделанные функцией model.predict () из Keras, верны, однако те, которые сделаны моим кодом MATLAB, не верны.

Мне было интересно, используют ли модели нейросетей Keras дополнительные шаги для прогнозирования, потому что я уверен, что X_test в Python и X_train в MATLAB одинаковы. Кроме того, я уверен, что веса одинаковы в обеих программах.

Python:

# Libraries
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np

# Train model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=200, activation="relu", input_shape=(100,)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(units=100, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(units=1, activation='sigmoid')
])

model.summary()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
history = model.fit(X_train, y_train, validation_split=0.20, epochs=100)

# Save weights    
model = tf.keras.models.load_model(model_path)
weights = model.get_weights()
w0 = weights[0]; b0 = weights[1]
w1 = weights[0]; b1 = weights[1]
w2 = weights[0]; b2 = weights[1]

np.savetxt('w0.csv', w0, delimiter=',')
np.savetxt('b0.csv', b0, delimiter=',')
np.savetxt('w1.csv', w1, delimiter=',')
np.savetxt('b1.csv', b1, delimiter=',')
np.savetxt('w2.csv', w2, delimiter=',')
np.savetxt('b2.csv', b2, delimiter=',')

# Predict
Y_pred = model.predict(X_test)

MATLAB:

% Load weights
w0 = csvread('w0.csv');
w1 = csvread('w1.csv');
w2 = csvread('w2.csv');
model.weight = {w0, w1, w2};

b0 = csvread('b0.csv');
b1 = csvread('b1.csv');
b2 = csvread('b2.csv');
model.bias = {b0, b1, b2};

% Predict
Z1 = model.weight{1}' * X_test + model.bias{1};
A1 = relu(Z1);

Z2 = model.weight{2}' * A1 + model.bias{2};
A2 = relu(Z2);

Z3 = model.weight{3}' * A2 + model.bias{3};
Y_pred = sigmoid(Z3);

Функции MATLAB

function y = relu(x)
y = x .* (x >= 0);
end

function y = sigmoid(x)
y = 1 ./ (1 + exp(-x));
end
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...