Я реализую модель скип-граммы в системе федеративного обучения. Я получаю входные данные и метку следующим образом:
train_inputs_embed = tf.nn.embedding_lookup(variables.weights, batch['target_id'])
train_labels = tf.reshape(batch['context_id'], [-1, 1])
Когда я определяю потери следующим образом
loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(weights=variables.nce_weights,
biases=variables.bias,
inputs=train_inputs_embed,
labels=train_labels,
num_sampled=5,
num_true=1,
num_classes=vocab_size))
Я получаю следующую ошибку
ValueError: Shape must be rank 2 but is rank 3 for 'sampled_softmax_loss/concat_4' (op: 'ConcatV2') with input shapes: [?,1], [?,?,5], [].
Но следующий код (взят из раздела eval функции sampled_softmax_loss) работает для тех же входов и меток !!
logits = tf.matmul(train_inputs_embed, tf.transpose(variables.nce_weights))
logits = tf.nn.bias_add(logits, variables.bias)
labels_one_hot = tf.one_hot(train_labels, vocab_size)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_one_hot, logits=logits))
Как решить эту проблему?