Загрузка numpy весов в TensorFlow 2.0 - PullRequest
0 голосов
/ 11 января 2020

У меня есть архитектура нейронной сети для набора данных MNIST следующим образом -

def create_nn():
    """
    Function to create NN model for MNIST
    classification using 300 100 architecture
    """
    model = Sequential()
    model.add(l.InputLayer(input_shape = (784, )))
    model.add(Flatten())
    model.add(Dense(units = 300, activation='relu', kernel_initializer = tf.initializers.GlorotUniform()))
    # model.add(l.Dropout(0.2))
    model.add(Dense(units = 100, activation='relu', kernel_initializer = tf.initializers.GlorotUniform()))
    # model.add(l.Dropout(0.1))
    model.add(Dense(units = num_classes, activation='softmax'))

    # Compile designed NN-
    model.compile(
        loss = tf.keras.losses.categorical_crossentropy,
        # optimizer = 'adam',
        optimizer = tf.keras.optimizers.Adam(lr = 0.001),
        metrics = ['accuracy'])

    return model

# Insantiate a new NN model instance-
orig_model = create_nn()


# Load original weights from when designed model was initialized-
orig_model.load_weights("300_100_MNIST.h5")

type(orig_model.trainable_weights), len(orig_model.trainable_weights)
# (list, 6)

# Insantiate a new NN model instance-
pruned_model = create_nn()

# Load pruned weights AFTER pruning algorithm was applied to prune NN-
pruned_model.load_weights("300_100_Pruned_Model.h5")

Теперь я создаю список, в котором я обрабатываю весовые коэффициенты согласно некоторому критерию следующим образом -

# List to store extracted weights-
weight_extracted = []

for orig_wts, pruned_wts in zip(orig_model.trainable_weights, pruned_model.trainable_weights):
    c = np.where(pruned_wts == 0, pruned_wts, orig_wts)
    weight_extracted.append(c)
    del c


len(weight_extracted)
# 6

Как я могу использовать веса / смещения в списке numpy массивов 'weight_extracted' для загрузки весов в NN, как определено выше?

Спасибо!

1 Ответ

1 голос
/ 11 января 2020

1-Вы можете преобразовать список weight_extracted = [] в массив

2- Сохранить этот массив в виде файла h5 с модулем h5py

3 - Загрузите извлеченные веса снова и обучите свою сеть с вашими новыми весами!

Это мои шаги, если есть заблуждение или недоразумение, пожалуйста, дайте мне знать.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...