Функция Keras's expand_dims заставляет тензоры терять метаданные - PullRequest
0 голосов
/ 25 января 2019

У меня проблемы с использованием функции Keras's expand_dims.Вот простой пример:

Этот код работает:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1, 1))

out = Lambda(add_fun)([in_1, in_2])

m = Model([in_1, in_2], out)

И этот код не:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))

problem_part = expand_dims(in_2, axis=1)

out = Lambda(add_fun)([in_1, problem_part])

m = Model([in_1, in_2], out)

Как показано здесь Я считаю,что я правильно использую expand_dims, и я не могу понять, почему это вызывает проблему.

Ответы [ 2 ]

0 голосов
/ 25 января 2019

Эту проблему можно решить, обернув вызов функции expand_dims в лямбду:

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.backend import expand_dims

def add_fun(x):
  return tf.add(x[0], x[1])

in_1 = Input(shape=(None, None, 8))
in_2 = Input(shape=(None, 1))

problem_part = Lambda(lambda x: expand_dims(x, axis=1))(in_2)

out = Lambda(add_fun)([in_1, problem_part])

m = Model([in_1, in_2], out)
0 голосов
/ 25 января 2019

Проблема в том, что expand_dims не является слоем Keras.Если вместо этого вы поместите вызов на expand_dims внутри вашего лямбда-слоя, он должен работать.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...