Я изо всех сил пытаюсь использовать библиотеку сокращения Tensorflow и не нашел много полезных примеров, поэтому я ищу помощь в сокращении простой модели, обученной на наборе данных MNIST. Если кто-то может помочь исправить мою попытку или привести пример использования библиотеки в MNIST, я был бы очень признателен.
Первая половина моего кода довольно стандартна, за исключением того, что моя модель имеет 2 скрытых слоя шириной 300 единиц, используя layers.masked_fully_connected
для сокращения.
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)
# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))
# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Затем я пытаюсь определить необходимые операции сокращения, но получаю ошибку.
############ Pruning Operations ##############
# Create global step variable
global_step = tf.contrib.framework.get_or_create_global_step()
# Create a pruning object using the pruning specification
pruning_hparams = pruning.get_pruning_hparams()
p = pruning.Pruning(pruning_hparams, global_step=global_step)
# Mask Update op
mask_update_op = p.conditional_mask_update_op()
# Set up the specification for model pruning
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)
Ошибка в этой строке:
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)
InvalidArgumentError (см. Выше для отслеживания): вы должны передать значение для тензора заполнителя 'Placeholder_1' с dtype float и shape [?, 10]
[[Узел: Placeholder_1 = Placeholderdtype = DT_FLOAT, shape = [?, 10], _device = "/ job: localhost / replica: 0 / task: 0 / device: GPU: 0"]]
[[Узел: global_step / _57 = _Recv_start_time = 0, client_terminated = false, recv_device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0", send_device = "/ job: localhost / replica: 0 / task: 0 / device: GPU: 0 ", send_device_incarnation = 1, тензор_имя =" edge_71_global_step ", тензор_тип = DT_INT64, _device =" / job: localhost / реплика: 0 / задача: 0 / устройство: CPU: 0 "]]
Я предполагаю, что вместо train_op требуется другой тип операции, но я не нашел никаких корректировок, которые бы работали.
Опять же, если у вас есть другой рабочий пример, который сокращает модель, обученную на MNIST, я бы посчитал это ответом.