Проблема с предсказаниями модели Keras - PullRequest
0 голосов
/ 02 марта 2020

Я обучил двунаправленную модель GRU в керасе для многоклассовой классификации текста. При прогнозировании я получаю несколько ярлыков вместо одного целевого ярлыка. Как я могу исправить эту проблему.

загрузка модели BiGRU

def load_biGRU():
    """
        Loads the GRU Model
    """

    fp = os.path.join(GOOGLE_DRV_PATH,MODELS_PATH,BEST_GRU_MODEL_FILE)
    best_model_path = str(np.load(fp))    

    json_file = open(GOOGLE_DRV_PATH + best_model_path + '.json', 'r')
    loaded_model_json = json_file.read()
    json_file.close()
    print(f"Model Info {loaded_model_json}")
    model = model_from_json(loaded_model_json)

    model.load_weights(GOOGLE_DRV_PATH + best_model_path+'.h5')

    model.summary()
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
    return model

Вывод:

Model Info {"class_name": "Model", "config": {"name": "model_1", "layers": [{"name": "input_1", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 390], "dtype": "int32", "sparse": false, "name": "input_1"}, "inbound_nodes": []}, {"name": "embedding_1", "class_name": "Embedding", "config": {"name": "embedding_1", "trainable": true, "batch_input_shape": [null, 390], "dtype": "float32", "input_dim": 40000, "output_dim": 100, "embeddings_initializer": {"class_name": "RandomUniform", "config": {"minval": -0.05, "maxval": 0.05, "seed": null}}, "embeddings_regularizer": null, "activity_regularizer": null, "embeddings_constraint": null, "mask_zero": false, "input_length": 390}, "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"name": "bidirectional_1", "class_name": "Bidirectional", "config": {"name": "bidirectional_1", "trainable": true, "layer": {"class_name": "GRU", "config": {"name": "gru_1", "trainable": true, "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "units": 298, "activation": "tanh", "recurrent_activation": "hard_sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.3679508555724801, "recurrent_dropout": 0.3679508555724801, "implementation": 1, "reset_after": false}}, "merge_mode": "concat"}, "inbound_nodes": [[["embedding_1", 0, 0, {}]]]}, {"name": "dropout_1", "class_name": "Dropout", "config": {"name": "dropout_1", "trainable": true, "rate": 0.22713711911220283, "noise_shape": null, "seed": null}, "inbound_nodes": [[["bidirectional_1", 0, 0, {}]]]}, {"name": "dense_1", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "units": 18, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dropout_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["dense_1", 0, 0]]}, "keras_version": "2.2.0", "backend": "tensorflow"}
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 390)               0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 390, 100)          4000000   
_________________________________________________________________
bidirectional_1 (Bidirection (None, 596)               713412    
_________________________________________________________________
dropout_1 (Dropout)          (None, 596)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 18)                10746     
=================================================================
Total params: 4,724,158
Trainable params: 4,724,158
Non-trainable params: 0
_________________________________________________________________
Loaded GRU model from disk
Built Model Graph
/usr/local/lib/python3.6/dist-packages/sklearn/base.py:318: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.20.3 when using version 0.22.1. This might lead to breaking code or invalid results. Use at your own risk.
  UserWarning)

Код прогноза

MAX_SEQUENCE_LENGTH = 390

test_seq = 'The toxicological analysis of Ni and Cr in prescription food for special medical purposes and modified milk products for babies in infancy available in pharmacies.'
test_sentence_token = nlp.tokenizer(test_seq)
test_sentence_token = [token.text for token in test_sentence_token if not token.is_stop]    
print(f"Number of token {len(test_sentence_token)}")        
test_sentence_seq = tokenizer.texts_to_sequences(test_sentence_token)
test_sentence_pad = pad_sequences(test_sentence_seq, maxlen=MAX_SEQUENCE_LENGTH)

prediction = model.predict(test_sentence_pad)
print("Model Prediction")
for i in prediction:                
    print(lbl_encodr.inverse_transform([np.argmax(i)]))

Вывод прогноза: получение нескольких меток, равных номер test_sentence_token. Правильный ответ здесь: «Молоко и молочные продукты».

Number of token 17
Model Prediction
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Milk and dairy products']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
['Packaging']
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...