tenorflow.python.framework.errors_impl.InvalidArgumentError: в модели KERAS LSTM - PullRequest
0 голосов
/ 02 октября 2018

У меня есть следующая модель lstm:

class LSTM_model():

def __init__(self):
    w2v_model = gensim.models.Word2Vec(sentences, size=150, window=10, min_count=2, workers=10)
    pretrained_weights = w2v_model.wv.syn0
    vocab_size, emdedding_size = pretrained_weights.shape
    self.w2v_model = w2v_model
    self.keras_lstm_model = Sequential()
    self.keras_lstm_model.add(Embedding(input_dim = vocab_size, output_dim = emdedding_size, weights = [pretrained_weights]))
    self.keras_lstm_model.add(LSTM(units = emdedding_size))
    self.keras_lstm_model.add(Dense(units = vocab_size))
    self.keras_lstm_model.add(Activation('sigmoid'))
    self.keras_lstm_model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics = ['mae','acc'])

Моя цель - предсказать вероятность данного слова в контексте.У меня есть список предложений, и я хочу обучить этой модели:

   def train_lstm_model(self, sentences):

        sentences_as_indexes = []


        filtered_sentences = list(filter(lambda x : all([w in self.w2v_model.wv.vocab for w in x]) , sentences)) #filter to take only sentences with no OOV words
        for sentence in filtered_sentences:
            if all([w in self.w2v_model.wv.vocab for w in sentence]): #todo use filter
                indexes_row = []
                for word in sentence:

                    idx = self.w2v_model.wv.vocab.get(word).index
                    indexes_row.append(idx)

                sentences_as_indexes.append(indexes_row)
        X = pd.DataFrame([sentence[:-1] for sentence in sentences_as_indexes])
        y = pd.DataFrame([sentence[-1] for sentence in sentences_as_indexes])
        print(datetime.datetime.now(), ": Fitting LSTM model , size of X is ", X.shape)

        self.keras_lstm_model.fit(X, y) #HERE the error

Где предложения - это список из примерно 1 млн. Предложений (мой тренировочный набор).Я получаю следующую ошибку в Fit ():

tenorflow.python.framework.errors_impl.InvalidArgumentError: indices [7,38] = -2147483648 не находится в [0, 694415) [[Node: embedding_1 / embedding_lookup = GatherV2 [Taxis = DT_INT32, Tindices = DT_INT32, Tparams = DT_FLOAT, _class = ["loc: @ training / Adam / Assign_2"], _device = "/ job: localhost / replica: 0 / task: 0 / task: 0 /устройство: ЦП: 0 "] (embedding_1 / embeddings / read, embedding_1 / Cast, обучение / Adam / градиенты / embedding_1 / embedding_lookup_grad / concat / axis)]]

Я думаю, это может быть потому, что яу меня есть данные, которые являются OOV, но, как вы можете видеть - я отфильтровал предложения, чтобы взять только те, у которых есть все слова в словаре.

Что может вызвать эту ошибку?

Спасибо!

...