Я пытаюсь сгенерировать подпись к изображению (image caption) с помощью модели на основе трансформеров, обученной на наборе данных COCO. Застрял на следующем месте, поэтому мне действительно нужна помощь с этим кодом:
# Build the caption generator (show-attend-and-tell style hyperparameters:
# 196 spatial feature locations x 512 channels, 16 decoding time steps).
model = CaptionGenerator(
    word_to_idx,
    dim_feature=[196, 512],
    dim_embed=512,
    dim_hidden=1024,
    n_time_step=16,
    prev2out=True,
    ctx2out=True,
    alpha_c=1.0,
    selector=True,
    dropout=True,
)
print("Done with model")

# Configure the training driver: Adam, 20 epochs, batch 128; checkpoints go to
# model/lstm/ and logs to log/.
solver = CaptioningSolver(
    model,
    data,
    val_data,
    n_epochs=20,
    batch_size=128,
    update_rule='adam',
    learning_rate=0.001,
    print_every=1000,
    save_every=1,
    image_path='./image/',
    pretrained_model=None,
    model_path='model/lstm/',
    test_model='model/lstm/model-10',
    print_bleu=True,
    log_path='log/',
)
print("Done with solver")

# Kick off training (builds the TF graph and runs the optimization loop).
solver.train()
Итак, после вызова solver.train() выполнение переходит в solver.py:
# NOTE(review): excerpt pasted from solver.py — the original indentation was lost
# in the paste and the method bodies are elided ("........"), so this fragment is
# not runnable as shown; it documents where the reported error occurs.
class CaptioningSolver(object):
# Constructor body elided in the paste; presumably stores model, data, val_data
# and unpacks kwargs (batch_size, n_epochs, ...) — confirm against solver.py.
def __init__(self, model, data, val_data, **kwargs):
# Builds the training graph (loss) and, with reused variables, the sampling graph.
def train(self):
# train/val dataset
# Changed this because I keep less features than captions, see prepro
# n_examples = self.data['captions'].shape[0]
n_examples = self.data['features'].shape[0]
# Iterations per epoch, rounded up so a final partial batch is still counted.
n_iters_per_epoch = int(np.ceil(float(n_examples)/self.batch_size))
features = self.data['features']
captions = self.data['captions']
image_idxs = self.data['image_idxs']
val_features = self.val_data['features']
n_iters_val = int(np.ceil(float(val_features.shape[0])/self.batch_size))
# build graphs for training model and sampling captions
# This scope fixed things!!
# Entering the current variable scope so reuse_variables() below makes the
# sampler share the weights created by build_model().
with tf.variable_scope(tf.get_variable_scope()):
print("HELLOOOOOOOOOOOOOO")
loss = self.model.build_model() **#ERROR COMES AFTER THIS**
# Reuse the training-graph variables for the sampling graph.
tf.get_variable_scope().reuse_variables()
_, _, generated_captions = self.model.build_sampler(max_len=20)
........
........
Итак, после
print("HELLOOOOOOOOOOOOOO")
loss = self.model.build_model() **#ERROR COMES AFTER THIS**
я получаю следующую ошибку:
c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes must be equal rank, but are 2 and 0
From merging shape 0 with other shapes. for 'encoder/enc_pe/Tile/multiples_1' (op: 'Pack') with input shapes: [196,512], [].