Я хочу применить приложение netVlad к набору данных MNist, где я использую Keras 2.1. Однако, когда я использую код ниже, я получаю следующую ошибку:
AttributeError: у объекта 'NoneType' нет атрибута '_inbound_nodes'
Вот мой код, и я знаю, что я делаю неправильно:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense,Dropout, Flatten, Activation, Lambda, concatenate
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K
from keras.losses import categorical_crossentropy as logloss
from keras.metrics import categorical_accuracy
K.set_image_dim_ordering('th')
import tensorflow as tf
def matconvnetNormalize(__inputs, epsilon):
return __inputs / tf.sqrt(tf.reduce_sum(__inputs ** 2, axis=-1, keepdims=True) + epsilon)
def net_VLAD(inputs, num_clusters, assign_weight_initializer=None, cluster_initializer=None, skip_postnorm=False):
_inputs = inputs
K = num_clusters
D = _inputs.get_shape()[-1]
s = tf.layers.conv2d(inputs=_inputs,filters=K, kernel_size=1, use_bias=False, kernel_initializer=None,name='assignment')
a = tf.nn.softmax(s)
a = tf.expand_dims(a, -2)
# VLAD core.
C = tf.get_variable('cluster_centers', [1, 1, 1, D, K],
initializer=None,dtype=inputs.dtype)
v = tf.expand_dims(inputs, -1) + C
v = a * v
v = tf.reduce_sum(v, axis=[1, 2])
v = tf.transpose(v, perm=[0, 2, 1])
if not skip_postnorm:
v = matconvnetNormalize(v, 1e-12)
v = tf.transpose(v, perm=[0, 2, 1])
v = matconvnetNormalize(tf.layers.flatten(v), 1e-12)
return v
from keras.layers import Input
tf.reset_default_graph()
num_classes = 10
num_clusters = 3
skip_postnorm = False
input = Input(shape=(1, 28,28))
x1 = Conv2D(30,(5,5), input_shape = (1,28,28),activation='relu') (input)
print(x1.shape)
x1 = net_VLAD(x1,4)
print(x1.shape)
x1 = tf.nn.l2_normalize(tf.layers.flatten(x1), dim=-1)
print(x1.shape)
x2 = Dense(50, activation = 'relu')(x1)
print(x2.shape)
x2 = Dense(10,activation = 'softmax')(x2)
print(x2.shape)
model = Model(input, x2)
model.summary()
Ошибка возникает, когда есть «модель = Модель (вход, х2)», где х2 имеет форму (?, 10)
Что мне не хватает?
Большое спасибо заранее,
Andi