Я обучил простую нейронную сеть с использованием Keras и хочу использовать весовые коэффициенты для разработки нейронной сети с прямой связью в MATLAB (у меня нет набора инструментов Deep Leaning Toolbox, и мне нужно использовать некоторые функции обработки изображений из MATLAB). Предсказания, сделанные функцией model.predict () из Keras, верны, однако те, которые сделаны моим кодом MATLAB, не верны.
Мне было интересно, используют ли модели нейросетей Keras дополнительные шаги для прогнозирования, потому что я уверен, что X_test в Python и X_train в MATLAB одинаковы. Кроме того, я уверен, что веса одинаковы в обеих программах.
Python:
# Libraries
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
# Train model
model = tf.keras.Sequential([
tf.keras.layers.Dense(units=200, activation="relu", input_shape=(100,)),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(units=100, activation="relu"),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
model.summary()
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
history = model.fit(X_train, y_train, validation_split=0.20, epochs=100)
# Save weights
model = tf.keras.models.load_model(model_path)
weights = model.get_weights()
w0 = weights[0]; b0 = weights[1]
w1 = weights[0]; b1 = weights[1]
w2 = weights[0]; b2 = weights[1]
np.savetxt('w0.csv', w0, delimiter=',')
np.savetxt('b0.csv', b0, delimiter=',')
np.savetxt('w1.csv', w1, delimiter=',')
np.savetxt('b1.csv', b1, delimiter=',')
np.savetxt('w2.csv', w2, delimiter=',')
np.savetxt('b2.csv', b2, delimiter=',')
# Predict
Y_pred = model.predict(X_test)
MATLAB:
% Load weights
w0 = csvread('w0.csv');
w1 = csvread('w1.csv');
w2 = csvread('w2.csv');
model.weight = {w0, w1, w2};
b0 = csvread('b0.csv');
b1 = csvread('b1.csv');
b2 = csvread('b2.csv');
model.bias = {b0, b1, b2};
% Predict
Z1 = model.weight{1}' * X_test + model.bias{1};
A1 = relu(Z1);
Z2 = model.weight{2}' * A1 + model.bias{2};
A2 = relu(Z2);
Z3 = model.weight{3}' * A2 + model.bias{3};
Y_pred = sigmoid(Z3);
Функции MATLAB
function y = relu(x)
y = x .* (x >= 0);
end
function y = sigmoid(x)
y = 1 ./ (1 + exp(-x));
end