Как реализовать предварительное обучение в Tensorflow?Как частично использовать сохраненные веса из файла контрольных точек? - PullRequest
0 голосов
/ 12 сентября 2018

Для удобства обсуждения следующие модели были упрощены.

Допустим, в моем тренировочном наборе около 40 000 изображений 512x512. Я пытаюсь провести предварительное обучение, и мой план следующий:

1. Обучите нейронную сеть (назовем ее net_1), которая принимает изображения размером 256x256, и сохраните обученную модель в формате файла контрольной точки тензорного потока.

net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense

давайте назовем эту структуру net_1_kernel:

net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d

и вызов оставшейся части other_layers:

other_layers: rmspool -> flatten -> dense

Таким образом, мы можем представить net_1 в следующей форме:

net_1: input -> net_1_kernel -> other_layers

2. Вставьте несколько слоев в структуру net_1, а теперь назовите его net_2. Это должно выглядеть так:

net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers

net_2 будет принимать 512x512 изображений в качестве входных данных.

Когда я тренируюсь net_2, я хотел бы использовать сохраненные веса и смещения в файле контрольных точек net_1 для инициализации части net_1_kernel в net_2. Как я могу это сделать?

Я знаю, что могу загружать контрольные точки для прогнозирования тестовых данных. Но в этом случае он загрузит все (net_1_kernel и other_layers). Я хочу загрузить только net_1_kernel и использовать его для инициализации веса / смещения в net_2.

Я также знаю, что могу печатать содержимое файлов контрольных точек в txt, а также копировать и вставлять для ручной инициализации весов и смещений. Однако в этих весах и смещениях так много цифр, и это был бы мой последний выбор.

1 Ответ

0 голосов
/ 12 сентября 2018

Прежде всего, вы можете использовать следующий код для проверки списка всех контрольных точек в сохраненном вами файле ckpt.

from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="xxx", all_tensors=False, all_tensor_names=True)

Помните, что при восстановлении файла контрольных точек будут восстановлены все переменные в файле контрольных точек. Если вам нужно сохранить и восстановить определенные переменные, вы можете сделать это следующим образом:

  1. Составьте список всех переменных, которые вы хотите сохранить из tf.trainable_variables()

var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]

saverAndRestore = tf.train.Saver(var)

  1. Теперь вы можете легко сохранить или восстановить все переменные в списке переменных следующим образом:

saverAndRestore.save(sess_1,"net_1.ckpt")

Это сохранит только переменные в списке var в net_1.ckpt.

saverAndRestore.restore(sess_1,"net_1.ckpt")

Это восстановит только переменные в списке var из net_1.ckpt.

Кроме вышеперечисленного, все, что вам нужно сделать, - это присвоить имена / область видимости вашим переменным, чтобы вы могли легко выполнить шаг 1 выше.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...