Я нашел код в Интернете, чтобы получить производную от общих потерь по весам глубокого обучения.Я пытаюсь найти производную весов в отношении потери одного класса вместо всех классов.
Я использовал следующий код для получения градиента входного изображения по отношению к общей потере.Если я это визуализирую, это показывает важность пикселей для всех прогнозов.Но я хотел бы вычислить производную входного изображения по определенному классу (например, «lady_bug»).Это должно показать важность пикселей для предсказания lady_bug.У вас есть идея, как я могу это сделать?
from keras.applications.vgg19 import VGG19
import numpy as np
import cv2
from keras import backend as K
import matplotlib.pyplot as plt
from keras.applications.inception_v3 import decode_predictions
def get_model():
model = VGG19(include_top=True, weights='imagenet')
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
def predict(model, images):
numeric_prediction = model.predict(images)
categorical_prediction = decode_predictions(numeric_prediction, top=1)
return [(x[0][1], x[0][2]) for x in categorical_prediction]
def get_test_image():
# Image
image_path = "lady_bug.jpg"
image = cv2.imread(image_path)
my_image = cv2.resize(image, (224,224))
my_image = np.expand_dims(my_image, axis=0)
return my_image
def visualize_sample(sample, file_path):
plt.figure()
plt.imshow(sample)
plt.savefig(file_path, bbox_inches='tight')
def test_input_gradient():
images = get_test_image()
model = get_model()
prediction = predict(model, images)
print(prediction)
gradients = K.gradients(model.output, model.input) #Gradient of output wrt the input of the model (Tensor)
print(gradients)
sess = K.get_session()
evaluated_gradients = sess.run(gradients[0], feed_dict={model.input:
images})
visualize_sample((evaluated_gradients[0]*(10**9.5)).clip(0,255), "test.png")
if __name__ == "__main__":
test_input_gradient()
Вывод:
[('ladybug', 0.53532666)]
[<tf.Tensor 'gradients/block1_conv1/convolution_grad/Conv2DBackpropInput:0' shape=(?, 224, 224, 3) dtype=float32>]