I'm trying to do Style Transfer with DarkNet19 (a ZooModel from DL4J), and after a single backpropagation pass through the graph the dropout masks on the layers get cleared. How can I disable this behaviour?
protected static final Logger log = LoggerFactory.getLogger(BaseGenerationRepo.class);
protected INDArray contentImage;
protected INDArray styleImage;
public static final String[] ALL_LAYERS = new String[]{
"input_1",
"conv2d_1",
"batch_normalization_1",
"leaky_re_lu_1",
"max_pooling2d_1",
"conv2d_2",
"batch_normalization_2",
"leaky_re_lu_2",
"max_pooling2d_2",
"conv2d_3",
"batch_normalization_3",
"leaky_re_lu_3",
"conv2d_4",
"batch_normalization_4",
"leaky_re_lu_4",
"conv2d_5",
"batch_normalization_5",
"leaky_re_lu_5",
"max_pooling2d_3",
"conv2d_6",
"batch_normalization_6",
"leaky_re_lu_6",
"conv2d_7",
"batch_normalization_7",
"leaky_re_lu_7",
"conv2d_8",
"batch_normalization_8",
"leaky_re_lu_8",
"max_pooling2d_4",
"conv2d_9",
"batch_normalization_9",
"leaky_re_lu_9",
"conv2d_10",
"batch_normalization_10",
"leaky_re_lu_10",
"conv2d_11",
"batch_normalization_11",
"leaky_re_lu_11",
"conv2d_12",
"batch_normalization_12",
"leaky_re_lu_12",
"conv2d_13",
"batch_normalization_13",
"leaky_re_lu_13",
"max_pooling2d_5",
"conv2d_14",
"batch_normalization_14",
"leaky_re_lu_14",
"conv2d_15",
"batch_normalization_15",
"leaky_re_lu_15",
"conv2d_16",
"batch_normalization_16",
"leaky_re_lu_16",
"conv2d_17",
"batch_normalization_17",
"leaky_re_lu_17",
"conv2d_18",
"batch_normalization_18",
"leaky_re_lu_18",
"conv2d_19",
"globalpooling",
"softmax",
"loss"
};
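// Each entry is "layerName,weight": the layer whose activations define the style loss and the weight applied to its gradient.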
public static final String[] STYLE_LAYERS = new String[]{
"conv2d_4,1.0",
"conv2d_9,2.0"
};
public static final String CONTENT_LAYER_NAME = "conv2d_10";
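// Optimisation loop: forward the generated image, backpropagate the style and content losses by hand, and apply an Adam step directly to the image.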
public void generate() throws GenerationException {
try {
ComputationGraph graph = loadModel(false);
INDArray generatedImage = initGeneratedImage();
Map<String, INDArray> contentActivation = graph.feedForward(contentImage, true);
Map<String, INDArray> styleActivation = graph.feedForward(styleImage, true);
HashMap<String, INDArray> styleActivationGram = initStyleGramMap(styleActivation);
AdamUpdater optim = createAdamUpdater();
for (int i = 0; i < ITERATIONS; i++) {
if (i % 5 == 0) log.info("iteration " + i);
Map<String, INDArray> forwardActivation = graph.feedForward(new INDArray[] { generatedImage }, true, false);
INDArray styleGrad = backPropStyles(graph, styleActivationGram, forwardActivation);
INDArray contentGrad = backPropContent(graph, contentActivation, forwardActivation);
INDArray totalGrad = contentGrad.muli(ALPHA).addi(styleGrad.muli(BETA));
optim.applyUpdater(totalGrad, i, 0);
generatedImage.subi(totalGrad);
}
} catch (Exception e) {
e.printStackTrace();
throw new GenerationException();
}
}
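// Standalone Adam updater; the state view holds both the m and v accumulators, hence 2 * CHANNELS * WIDTH * HEIGHT entries.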
protected AdamUpdater createAdamUpdater() {
AdamUpdater adam = new AdamUpdater(new Adam(LEARNING_RATE, BETA1, BETA2, EPSILON));
adam.setStateViewArray(Nd4j.zeros(1, 2* CHANNELS * WIDTH * HEIGHT),
new long[] {1, CHANNELS, HEIGHT, WIDTH}, 'c', true);
return adam;
}
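// Loads the pretrained ImageNet DarkNet19 from the DL4J model zoo with workspaces disabled.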
protected ComputationGraph loadModel(boolean logIt) throws IOException {
ZooModel zooModel = Darknet19.builder().workspaceMode(WorkspaceMode.NONE).build();
zooModel.setInputShape(new int[][] {{3, 224, 224}});
ComputationGraph darkNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
darkNet.initGradientsView();
if (logIt) log.info(darkNet.summary());
return darkNet;
}
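// Gram matrix of the flattened activations; note the normalize flag is currently unused.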
protected INDArray gramMatrix(INDArray activation, boolean normalize) {
INDArray flat = flatten(activation);
return flat.mmul(flat.transpose());
}
protected INDArray flatten(INDArray x) {
long[] shape = x.shape();
return x.reshape(shape[0] * shape[1], shape[2] * shape[3]);
}
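// Initial image: uniform noise in [-20, 20] blended with the content image according to NOISE.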
protected INDArray initGeneratedImage() {
int totalEntries = CHANNELS * HEIGHT * WIDTH;
double[] result = new double[totalEntries];
for (int i = 0; i < result.length; i++) {
result[i] = ThreadLocalRandom.current().nextDouble(-20, 20);
}
INDArray randomMatrix = Nd4j.create(result, new int[]{1, CHANNELS, HEIGHT, WIDTH});
return randomMatrix.muli(NOISE).addi(contentImage.mul(1 - NOISE)); // mul (not muli) so contentImage itself is left unchanged
}
protected HashMap<String, INDArray> initStyleGramMap(Map<String, INDArray> styleActivation) {
HashMap<String, INDArray> gramMap = new HashMap<>();
for (String s : STYLE_LAYERS) {
String[] spl = s.split(",");
String styleLayerName = spl[0];
INDArray activation = styleActivation.get(styleLayerName);
gramMap.put(styleLayerName, gramMatrix(activation, false));
}
return gramMap;
}
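// Accumulates the weighted style gradient w.r.t. the generated image over all style layers.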
protected INDArray backPropStyles(ComputationGraph graph, HashMap<String, INDArray> gramActivations, Map<String, INDArray> forwardActivations) {
INDArray backProp = Nd4j.zeros(1, CHANNELS, HEIGHT, WIDTH);
for (String s : STYLE_LAYERS) {
String[] spl = s.split(",");
String layerName = spl[0];
double weight = Double.parseDouble(spl[1]);
INDArray gramActivation = gramActivations.get(layerName);
INDArray forwardActivation = forwardActivations.get(layerName);
int index = layerIndex(layerName);
INDArray derivativeStyle = derivStyleLossInLayer(gramActivation, forwardActivation).transpose();
backProp.addi(backPropagate(graph, derivativeStyle.reshape(forwardActivation.shape()), index).muli(weight));
}
return backProp;
}
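// dL_style/dF for one layer: F^T (G - A), scaled by 1 / (N^2 * M^2) and masked by the ReLU derivative.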
protected INDArray derivStyleLossInLayer(INDArray gramFeatures, INDArray targetFeatures) {
targetFeatures = targetFeatures.dup();
double N = targetFeatures.shape()[0];
double M = targetFeatures.shape()[1] * targetFeatures.shape()[2];
double styleWeight = 1 / (N * N * M * M);
INDArray contentGram = gramMatrix(targetFeatures, false);
INDArray diff = contentGram.sub(gramFeatures);
INDArray fTranspose = flatten(targetFeatures).transpose();
INDArray fTmulGA = fTranspose.mmul(diff);
INDArray derivative = fTmulGA.muli(styleWeight);
return derivative.muli(checkPositive(fTranspose));
}
protected INDArray backPropContent(ComputationGraph graph, Map<String, INDArray> contentActivations, Map<String, INDArray> forwardActivations) {
INDArray contentActivation = contentActivations.get(CONTENT_LAYER_NAME);
INDArray forwardActivation = forwardActivations.get(CONTENT_LAYER_NAME);
INDArray derivativeContent = derivContentLossInLayer(contentActivation, forwardActivation);
return backPropagate(graph, derivativeContent.reshape(forwardActivation.shape()), layerIndex(CONTENT_LAYER_NAME));
}
protected INDArray derivContentLossInLayer(INDArray contentFeatures, INDArray targetFeatures) {
targetFeatures = targetFeatures.dup();
contentFeatures = contentFeatures.dup();
double C = targetFeatures.shape()[0];
double W = targetFeatures.shape()[1];
double H = targetFeatures.shape()[2];
double contentWeight = 1.0 / (2 * C * H * W);
INDArray derivative = targetFeatures.sub(contentFeatures);
return flatten(derivative.muli(contentWeight).muli(checkPositive(targetFeatures)));
}
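// In-place ReLU-derivative mask: negative entries become 0, positive entries become 1.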
protected INDArray checkPositive(INDArray matrix) {
BooleanIndexing.applyWhere(matrix, Conditions.lessThan(0.0f), new Value(0.0f));
BooleanIndexing.applyWhere(matrix, Conditions.greaterThan(0.0f), new Value(1.0f));
return matrix;
}
protected int layerIndex(String layerName) {
for (int i = 0; i < ALL_LAYERS.length; i++) {
if (layerName.equalsIgnoreCase(ALL_LAYERS[i])) return i;
}
return -1;
}
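// Chains Layer.backpropGradient manually from the layer at startIndex back to (but not including) the input layer.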
protected INDArray backPropagate(ComputationGraph graph, INDArray dLdA, int startIndex) {
for (int i = startIndex; i > 0; i--) {
Layer layer = graph.getLayer(ALL_LAYERS[i]);
dLdA = layer.backpropGradient(dLdA, LayerWorkspaceMgr.noWorkspaces()).getSecond();
}
return dLdA;
}
}
So let's walk through how the backpropagation runs into the exception.
First it backpropagates through the graph for STYLE_LAYERS[0] ("conv2d_4") and everything is fine. Then, without another feedForward, it starts backpropagating STYLE_LAYERS[1] ("conv2d_9"), and when that pass reaches "conv2d_4" the exception is thrown. So after a single backward pass through that layer its dropout masks have been cleared and I cannot backpropagate through it again. How can I solve this?
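For reference, the only brute-force workaround I can think of is to repeat the forward pass before every style-layer backward pass, so whatever state backpropGradient clears is filled in again. A minimal sketch of that idea is below (backPropStylesWithRefresh is a hypothetical replacement for backPropStyles, and I haven't verified that re-running feedForward is actually enough). It roughly doubles the forward work per iteration, so I'd much rather just turn the clearing off.

// Hypothetical, unverified workaround: re-run the forward pass before each style layer's
// backward pass so the per-layer state consumed by the previous backpropGradient call is repopulated.
protected INDArray backPropStylesWithRefresh(ComputationGraph graph, INDArray generatedImage,
        HashMap<String, INDArray> gramActivations) {
    INDArray backProp = Nd4j.zeros(1, CHANNELS, HEIGHT, WIDTH);
    for (String s : STYLE_LAYERS) {
        String[] spl = s.split(",");
        String layerName = spl[0];
        double weight = Double.parseDouble(spl[1]);
        // fresh forward pass before every backward pass through the graph
        Map<String, INDArray> forwardActivations =
                graph.feedForward(new INDArray[] { generatedImage }, true, false);
        INDArray forwardActivation = forwardActivations.get(layerName);
        INDArray derivativeStyle =
                derivStyleLossInLayer(gramActivations.get(layerName), forwardActivation).transpose();
        backProp.addi(backPropagate(graph, derivativeStyle.reshape(forwardActivation.shape()),
                layerIndex(layerName)).muli(weight));
    }
    return backProp;
}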