Inception V3 переучил модель загрузки весов - PullRequest
0 голосов
/ 10 сентября 2018

Я работаю с переобученной моделью Inception V3. Я пытаюсь использовать свои данные .data (веса) и текущую контрольную точку для прогнозирования новых данных, но я не знаю, куда вставить эту информацию.

Я использовал учебник, чтобы получить прогнозы для базовых классов начальных классов, теперь я хотел бы делать прогнозы для классов, на которых я тренировался.

from datasets import imagenet
from tensorflow.contrib import slim
from nets import inception

def predict(image, version='V3'):
tf.reset_default_graph()

# Process the image 
raw_image, processed_image = process_image(image)
class_names = imagenet.create_readable_names_for_imagenet_labels()

# Create a placeholder for the images
X = tf.placeholder(tf.float32, [None, 299, 299, 3], name="X")

'''
inception_v3 function returns logits and end_points dictionary
logits are output of the network before applying softmax activation
'''

if version.upper() == 'V3':
    model_ckpt_path = INCEPTION_V3_CKPT_PATH
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        # Set the number of classes and is_training parameter  
        logits, end_points = inception.inception_v3(X, num_classes=1001, is_training=False)

predictions = end_points.get('Predictions', 'No key named predictions')
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, model_ckpt_path)
    prediction_values = predictions.eval({X: processed_image})

try:
    # Add an index to predictions and then sort by probability
    prediction_values = [(i, prediction) for i, prediction in enumerate(prediction_values[0,:])]
    prediction_values = sorted(prediction_values, key=lambda x: x[1], reverse=True)

    # Plot the image
    plot_color_image(raw_image)
    plt.show()
    print("Using Inception_{} CNN\nPrediction: Probability\n".format(version))
    # Display the image and predictions 
    for i in range(10):
        predicted_class = class_names[prediction_values[i][0]]
        probability = prediction_values[i][1]
        print("{}: {:.2f}%".format(predicted_class, probability*100))

# If the predictions do not come out right
except:
    print(predictions)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...