DL4J ZooModels - java.lang.IllegalStateException: Cannot perform backprop: Dropout mask array is absent (already cleared?)
0 votes
/ 05 May 2019

I am trying to do Style Transfer with DarkNet19 (a ZooModel from DL4J), and after a single backpropagation pass through the graph the dropout masks on the layers get cleared. How can I disable this behavior?

// Snippet from the generation class; constants such as ITERATIONS, ALPHA, BETA,
// NOISE, LEARNING_RATE, BETA1, BETA2, EPSILON, CHANNELS, HEIGHT and WIDTH are
// defined elsewhere in the class and omitted here.
protected static final Logger log = LoggerFactory.getLogger(BaseGenerationRepo.class);

    protected INDArray contentImage;
    protected INDArray styleImage;

// DarkNet19's layer names in topological order; backPropagate() walks this array
// backwards from a given layer index.
public static final String[] ALL_LAYERS = new String[]{
            "input_1",
            "conv2d_1",
            "batch_normalization_1",
            "leaky_re_lu_1",
            "max_pooling2d_1",
            "conv2d_2",
            "batch_normalization_2",
            "leaky_re_lu_2",
            "max_pooling2d_2",
            "conv2d_3",
            "batch_normalization_3",
            "leaky_re_lu_3",
            "conv2d_4",
            "batch_normalization_4",
            "leaky_re_lu_4",
            "conv2d_5",
            "batch_normalization_5",
            "leaky_re_lu_5",
            "max_pooling2d_3",
            "conv2d_6",
            "batch_normalization_6",
            "leaky_re_lu_6",
            "conv2d_7",
            "batch_normalization_7",
            "leaky_re_lu_7",
            "conv2d_8",
            "batch_normalization_8",
            "leaky_re_lu_8",
            "max_pooling2d_4",
            "conv2d_9",
            "batch_normalization_9",
            "leaky_re_lu_9",
            "conv2d_10",
            "batch_normalization_10",
            "leaky_re_lu_10",
            "conv2d_11",
            "batch_normalization_11",
            "leaky_re_lu_11",
            "conv2d_12",
            "batch_normalization_12",
            "leaky_re_lu_12",
            "conv2d_13",
            "batch_normalization_13",
            "leaky_re_lu_13",
            "max_pooling2d_5",
            "conv2d_14",
            "batch_normalization_14",
            "leaky_re_lu_14",
            "conv2d_15",
            "batch_normalization_15",
            "leaky_re_lu_15",
            "conv2d_16",
            "batch_normalization_16",
            "leaky_re_lu_16",
            "conv2d_17",
            "batch_normalization_17",
            "leaky_re_lu_17",
            "conv2d_18",
            "batch_normalization_18",
            "leaky_re_lu_18",
            "conv2d_19",
            "globalpooling",
            "softmax",
            "loss"
    };
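    // Each STYLE_LAYERS entry is "layerName,weight"; it is parsed with split(",")
    // in initStyleGramMap() and backPropStyles().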
    public static final String[] STYLE_LAYERS = new String[]{
            "conv2d_4,1.0",
            "conv2d_9,2.0"
    };
    public static final String CONTENT_LAYER_NAME = "conv2d_10";

    public void generate() throws GenerationException {
        try {
            ComputationGraph graph = loadModel(false);
            INDArray generatedImage = initGeneratedImage();
            // Training-mode forward passes so every layer's activations are returned
            Map<String, INDArray> contentActivation = graph.feedForward(contentImage, true);
            Map<String, INDArray> styleActivation = graph.feedForward(styleImage, true);
            HashMap<String, INDArray> styleActivationGram = initStyleGramMap(styleActivation);
            AdamUpdater optim = createAdamUpdater();
            for (int i = 0; i < ITERATIONS; i++) {
                if (i % 5 == 0) log.info("iteration " + i);
                // feedForward(inputs, train = true, clearInputs = false)
                Map<String, INDArray> forwardActivation = graph.feedForward(new INDArray[] { generatedImage }, true, false);
                INDArray styleGrad = backPropStyles(graph, styleActivationGram, forwardActivation);
                INDArray contentGrad = backPropContent(graph, contentActivation, forwardActivation);
                INDArray totalGrad = contentGrad.muli(ALPHA).addi(styleGrad.muli(BETA));
                optim.applyUpdater(totalGrad, i, 0); // Adam transforms totalGrad in place
                generatedImage.subi(totalGrad);
            }
        } catch (Exception e) {
            e.printStackTrace();
            throw new GenerationException();
        }
    }

    protected AdamUpdater createAdamUpdater() {
        AdamUpdater adam = new AdamUpdater(new Adam(LEARNING_RATE, BETA1, BETA2, EPSILON));
        // Adam keeps two state vectors (m and v), hence the factor of 2 in the view size
        adam.setStateViewArray(Nd4j.zeros(1, 2 * CHANNELS * WIDTH * HEIGHT),
                new long[] {1, CHANNELS, HEIGHT, WIDTH}, 'c', true);
        return adam;
    }

    protected ComputationGraph loadModel(boolean logIt) throws IOException {
        ZooModel zooModel = Darknet19.builder().workspaceMode(WorkspaceMode.NONE).build();
        zooModel.setInputShape(new int[][] {{3, 224, 224}});
        ComputationGraph darkNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
        darkNet.initGradientsView();
        if (logIt) log.info(darkNet.summary());
        return darkNet;
    }

    protected INDArray gramMatrix(INDArray activation, boolean normalize) {
        // G = F * F^T over the flattened activation (the 'normalize' flag is unused)
        INDArray flat = flatten(activation);
        return flat.mmul(flat.transpose());
    }

    protected INDArray flatten(INDArray x) {
        // [batch, channels, height, width] -> [batch * channels, height * width]
        long[] shape = x.shape();
        return x.reshape(shape[0] * shape[1], shape[2] * shape[3]);
    }

    protected INDArray initGeneratedImage() {
        int totalEntries = CHANNELS * HEIGHT * WIDTH;
        double[] result = new double[totalEntries];
        for (int i = 0; i < result.length; i++) {
            result[i] = ThreadLocalRandom.current().nextDouble(-20, 20);
        }
        INDArray randomMatrix = Nd4j.create(result, new int[]{1, CHANNELS, HEIGHT, WIDTH});
        // Use mul() rather than muli() on contentImage: muli() would scale it in place
        return randomMatrix.muli(NOISE).addi(contentImage.mul(1 - NOISE));
    }

    protected HashMap<String, INDArray> initStyleGramMap(Map<String, INDArray> styleActivation) {
        HashMap<String, INDArray> gramMap = new HashMap<>();
        for (String s : STYLE_LAYERS) {
            String[] spl = s.split(",");
            String styleLayerName = spl[0];
            INDArray activation = styleActivation.get(styleLayerName);
            gramMap.put(styleLayerName, gramMatrix(activation, false));
        }
        return gramMap;
    }

    protected INDArray backPropStyles(ComputationGraph graph, HashMap<String, INDArray> gramActivations, Map<String, INDArray> forwardActivations) {
        INDArray backProp = Nd4j.zeros(1, CHANNELS, HEIGHT, WIDTH);
        // NOTE: each style layer is backpropagated separately; the second iteration
        // re-enters layers (e.g. conv2d_4) whose dropout masks the first pass cleared.
        for (String s : STYLE_LAYERS) {
            String[] spl = s.split(",");
            String layerName = spl[0];
            double weight = Double.parseDouble(spl[1]);
            INDArray gramActivation = gramActivations.get(layerName);
            INDArray forwardActivation = forwardActivations.get(layerName);
            int index = layerIndex(layerName);
            INDArray derivativeStyle = derivStyleLossInLayer(gramActivation, forwardActivation).transpose();
            backProp.addi(backPropagate(graph, derivativeStyle.reshape(forwardActivation.shape()), index).muli(weight));
        }
        return backProp;
    }

    // Style gradient at one layer: with F the flattened activation of the generated
    // image, G = F * F^T its Gram matrix and A the style image's precomputed Gram
    // matrix, this computes F^T * (G - A) / (N^2 * M^2), masked to where F^T > 0.
    protected INDArray derivStyleLossInLayer(INDArray gramFeatures, INDArray targetFeatures) {
        targetFeatures = targetFeatures.dup();
        double N = targetFeatures.shape()[0];
        double M = targetFeatures.shape()[1] * targetFeatures.shape()[2];
        double styleWeight = 1 / (N * N * M * M);
        INDArray contentGram = gramMatrix(targetFeatures, false);
        INDArray diff = contentGram.sub(gramFeatures);
        INDArray fTranspose = flatten(targetFeatures).transpose();
        INDArray fTmulGA = fTranspose.mmul(diff);
        INDArray derivative = fTmulGA.muli(styleWeight);
        return derivative.muli(checkPositive(fTranspose));
    }

    protected INDArray backPropContent(ComputationGraph graph, Map<String, INDArray> contentActivations, Map<String, INDArray> forwardActivations) {
        INDArray contentActivation = contentActivations.get(CONTENT_LAYER_NAME);
        INDArray forwardActivation = forwardActivations.get(CONTENT_LAYER_NAME);
        INDArray derivativeContent = derivContentLossInLayer(contentActivation, forwardActivation);
        return backPropagate(graph, derivativeContent.reshape(forwardActivation.shape()), layerIndex(CONTENT_LAYER_NAME));
    }

    // Content gradient at the content layer: (F - P) / (2 * C * H * W), masked to
    // where F > 0 (F = generated image's activation, P = content image's).
    protected INDArray derivContentLossInLayer(INDArray contentFeatures, INDArray targetFeatures) {
        targetFeatures = targetFeatures.dup();
        contentFeatures = contentFeatures.dup();
        double C = targetFeatures.shape()[0];
        double W = targetFeatures.shape()[1];
        double H = targetFeatures.shape()[2];
        double contentWeight = 1.0 / (2 * C * H * W);
        INDArray derivative = targetFeatures.sub(contentFeatures);
        return flatten(derivative.muli(contentWeight).muli(checkPositive(targetFeatures)));
    }

    protected INDArray checkPositive(INDArray matrix) {
        // In-place binarization: entries become 1 where positive, 0 otherwise
        BooleanIndexing.applyWhere(matrix, Conditions.lessThan(0.0f), new Value(0.0f));
        BooleanIndexing.applyWhere(matrix, Conditions.greaterThan(0.0f), new Value(1.0f));
        return matrix;
    }

    protected int layerIndex(String layerName) {
        for (int i = 0; i < ALL_LAYERS.length; i++) {
            if (layerName.equalsIgnoreCase(ALL_LAYERS[i])) return i;
        }
        return -1;
    }

    protected INDArray backPropagate(ComputationGraph graph, INDArray dLdA, int startIndex) {
        // Walks the layers backwards by hand, feeding each layer's epsilon into the
        // previous one. This is where the IllegalStateException surfaces: the first
        // backward pass through a layer consumes its dropout mask, so a second pass
        // through the same layer finds the mask already cleared.
        for (int i = startIndex; i > 0; i--) {
            Layer layer = graph.getLayer(ALL_LAYERS[i]);
            dLdA = layer.backpropGradient(dLdA, LayerWorkspaceMgr.noWorkspaces()).getSecond();
        }
        return dLdA;
    }

}
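
In equation form, each loop iteration in generate() computes (with \(\alpha\) = ALPHA, \(\beta\) = BETA and \(x\) the generated image):

\[
\nabla_x L \;=\; \alpha\,\nabla_x L_{\text{content}} + \beta\,\nabla_x L_{\text{style}},
\qquad
x \leftarrow x - \text{Adam}\!\left(\nabla_x L\right)
\]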

So, here is how the backpropagation leads to the exception. First it feeds forward through the graph and backprops STYLE_LAYERS[0], "conv2d_4", and everything is fine. After that, without another feedForward, it starts backpropagating STYLE_LAYERS[1], "conv2d_9", and as soon as it reaches "conv2d_4" the exception is thrown. In other words, after one backward pass through that layer its dropout masks were cleared, and I cannot do a backward pass through it again. How can I solve this?
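One workaround that follows from the error's own wording (a sketch only, not verified against DL4J's internals): since a training-mode forward pass populates the dropout masks and a backward pass consumes them, re-run feedForward before every manual backprop so the shared layers always have a fresh mask when they are re-entered. Applied to backPropStyles above, it could look like this (generatedImage is passed in instead of the cached forwardActivations, at the cost of extra forward passes):

    protected INDArray backPropStyles(ComputationGraph graph, INDArray generatedImage,
                                      HashMap<String, INDArray> gramActivations) {
        INDArray backProp = Nd4j.zeros(1, CHANNELS, HEIGHT, WIDTH);
        for (String s : STYLE_LAYERS) {
            String[] spl = s.split(",");
            String layerName = spl[0];
            double weight = Double.parseDouble(spl[1]);
            // Fresh training-mode forward pass per style layer: repopulates the
            // dropout masks that the previous backward pass cleared.
            Map<String, INDArray> forwardActivations =
                    graph.feedForward(new INDArray[] { generatedImage }, true, false);
            INDArray forwardActivation = forwardActivations.get(layerName);
            INDArray derivativeStyle =
                    derivStyleLossInLayer(gramActivations.get(layerName), forwardActivation).transpose();
            backProp.addi(backPropagate(graph, derivativeStyle.reshape(forwardActivation.shape()),
                    layerIndex(layerName)).muli(weight));
        }
        return backProp;
    }

backPropContent would need the same fresh forward pass before its call to backPropagate. Alternatively, if DarkNet19's dropout is not wanted during style transfer at all, rebuilding the graph without dropout (for example via DL4J's transfer-learning API) should avoid the masks entirely; I have not tried that here.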

...