Как мне найти идентификатор статьи прогнозируемого ярлыка в модели пропуска грамма? - PullRequest
1 голос
/ 20 июня 2020

Мне нужно найти идентификатор статьи в таблицах модели skip-gram. Может ли кто-нибудь помочь мне получить идентификатор статьи из этого кода.

tokenizer = text.Tokenizer()
tokenizer.fit_on_texts(titles)
word2id = tokenizer.word_index

id2word = {v:k for k, v in word2id.items()}
wids = [[word2id[w] for w in text.text_to_word_sequence(doc)] for doc in titles]

vocab_size = len(word2id) + 1
embed_size = 100

print('Vocabulary Size:', vocab_size)
print('Vocabulary Sample:', list(word2id.items())[:30])

skip_grams = [skipgrams(wid, vocabulary_size=vocab_size, window_size=10) for wid in wids]

pairs, labels = skip_grams[0][0], skip_grams[0][1]
for i in range(10):
    print("({:s} ({:d}), {:s} ({:d})) -> {:d}".format(
          id2word[pairs[i][0]], pairs[i][0], 
          id2word[pairs[i][1]], pairs[i][1], 
          labels[i]))

word_model = Sequential()
word_model.add(Embedding(vocab_size, embed_size, embeddings_initializer="glorot_uniform", input_length=1))
word_model.add(Reshape((embed_size, )))

context_model = Sequential()
context_model.add(Embedding(vocab_size, embed_size, embeddings_initializer="glorot_uniform", input_length=1))
context_model.add(Reshape((embed_size,)))

model1 = Sequential()
model1.add(Merge([word_model, context_model], mode="dot"))
model1.add(Dense(1, kernel_initializer="glorot_uniform", activation="sigmoid"))
model1.compile(loss="mean_squared_error", optimizer="rmsprop")

# view model summary
print(model1.summary())

for epoch in range(1, 6):
    loss = 0
    for i, elem in enumerate(skip_grams):
        pair_first_elem = np.array(list(zip(*elem[0]))[0], dtype='int32')
        pair_second_elem = np.array(list(zip(*elem[0]))[1], dtype='int32')
        labels = np.array(elem[1], dtype='int32')

        test.X = [pair_first_elem, pair_second_elem]
        test.Y = labels
        #if i % 10000 == 0:
            #print('Processed {} (skip_first, skip_second, relevance) pairs'.format(i))
        loss += model1.train_on_batch(test.X, test.Y) 

    print('Epoch:', epoch, 'Loss:', loss)

Мне нужно обучить идентификатор статьи для этих предсказанных ярлыков. Как я могу получить из этого идентификаторы статей?

...