Как использовать предварительно обученную модель кераса для вывода в tf.data.Dataset.map? - PullRequest
1 голос
/ 18 марта 2019

У меня есть предварительно обученная модель, и я пытаюсь построить другую модель, которая принимает в качестве входных данных выходные данные предыдущей модели. Я не хочу обучать модели из конца в конец, и хочу использовать первую модель только для вывода. Первая модель была обучена с использованием конвейера tf.data.Dataset, и я сначала хотел объединить модель как еще одну dataset.map() операцию в хвостовой части конвейера, но у меня возникли проблемы с этим. Я столкнулся с 20 различными ошибками в процессе, каждая из которых не связана с предыдущей. Слой нормализации партии, в частности, кажется, больно для этого.

Ниже приведен минимальный начальный пример, иллюстрирующий проблему. Он написан на R, но ответы на Python также приветствуются.

Я использую tenorflow-GPU версии 1.13.1 и Keras из tf.keras

library(reticulate)
library(tensorflow)
library(keras)
library(tfdatasets)
use_implementation("tensorflow")

model_weights_path <- 'model-weights.h5'

arr <- function(...) 
  np_array(array(seq_len(prod(unlist(c(...)))), unlist(c(...))), dtype = 'float32')

new_model <- function(load_weights = TRUE) {
  model <- keras_model_sequential() %>% 
    layer_conv_1d(5, 5, activation = 'relu', input_shape = shape(150, 10)) %>%
    layer_batch_normalization() %>%
    layer_flatten() %>%
    layer_dense(10, activation = 'softmax')
  if (load_weights)
    load_model_weights_hdf5(model, model_weights_path)
  freeze_weights(model)
  model
}

if(!file.exists(model_weights_path)) {
  model <- new_model(FALSE) 
  save_model_weights_hdf5(model, model_weights_path)
}

model <- new_model()

data <- arr(20, 150, 10)
ds <- tfdatasets::tensors_dataset(data) %>% 
  dataset_repeat()

ds2 <- ds %>% 
  dataset_map(function(x) {
    model(x)
  })

try(nb <- next_batch(ds2))

sess <- k_get_session()
it <- make_iterator_initializable(ds2)
sess$run(iterator_initializer(it))
nb <- it$get_next()

try(sess$run(nb))

sess$run(tf$initialize_all_variables())

try(sess$run(nb))
...