У меня есть следующая модель lstm:
class LSTM_model():
def __init__(self):
w2v_model = gensim.models.Word2Vec(sentences, size=150, window=10, min_count=2, workers=10)
pretrained_weights = w2v_model.wv.syn0
vocab_size, emdedding_size = pretrained_weights.shape
self.w2v_model = w2v_model
self.keras_lstm_model = Sequential()
self.keras_lstm_model.add(Embedding(input_dim = vocab_size, output_dim = emdedding_size, weights = [pretrained_weights]))
self.keras_lstm_model.add(LSTM(units = emdedding_size))
self.keras_lstm_model.add(Dense(units = vocab_size))
self.keras_lstm_model.add(Activation('sigmoid'))
self.keras_lstm_model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['mae','acc'])
Моя цель - предсказать вероятность данного слова в контексте.У меня есть список предложений, и я хочу обучить этой модели:
def train_lstm_model(self, sentences):
sentences_as_indexes = []
filtered_sentences = list(filter(lambda x : all([w in self.w2v_model.wv.vocab for w in x]) , sentences)) #filter to take only sentences with no OOV words
for sentence in filtered_sentences:
if all([w in self.w2v_model.wv.vocab for w in sentence]): #todo use filter
indexes_row = []
for word in sentence:
idx = self.w2v_model.wv.vocab.get(word).index
indexes_row.append(idx)
sentences_as_indexes.append(indexes_row)
X = pd.DataFrame([sentence[:-1] for sentence in sentences_as_indexes])
y = pd.DataFrame([sentence[-1] for sentence in sentences_as_indexes])
print(datetime.datetime.now(), ": Fitting LSTM model , size of X is ", X.shape)
self.keras_lstm_model.fit(X, y) #HERE the error
Где предложения - это список из примерно 1 млн. Предложений (мой тренировочный набор).Я получаю следующую ошибку в Fit ():
tenorflow.python.framework.errors_impl.InvalidArgumentError: indices [7,38] = -2147483648 не находится в [0, 694415) [[Node: embedding_1 / embedding_lookup = GatherV2 [Taxis = DT_INT32, Tindices = DT_INT32, Tparams = DT_FLOAT, _class = ["loc: @ training / Adam / Assign_2"], _device = "/ job: localhost / replica: 0 / task: 0 / task: 0 /устройство: ЦП: 0 "] (embedding_1 / embeddings / read, embedding_1 / Cast, обучение / Adam / градиенты / embedding_1 / embedding_lookup_grad / concat / axis)]]
Я думаю, это может быть потому, что яу меня есть данные, которые являются OOV, но, как вы можете видеть - я отфильтровал предложения, чтобы взять только те, у которых есть все слова в словаре.
Что может вызвать эту ошибку?
Спасибо!