Простая свертка из 1 фильтра со всеми единицами, за которой следует максимальное объединение, сделает это.
subArrayX = 3
subArrayY = 3
inputChannels = 1
outputChannels = 1
convFilter = K.ones((subArrayX, subArrayY, inputChannels, outputChannels))
def local_loss(true, pred):
diff = K.abs(true-pred) #you might also try K.square instead of abs
localSums = K.conv2d(diff, convFilter)
localSums = K.batch_flatten(localSums)
#if using more than 1 channel, you might want a different thing here
return K.max(localSums, axis=-1)
model.compile(loss = local_loss, ....)
Для всех возможных форм:
convWeights = []
for i in range(1, maxWidth+1):
for j in range(1, maxHeight+1):
convWeights.append(K.ones((i,j,1,1)))
def custom_loss(true,pred):
diff = true - pred
#sums for each array size
sums = [K.conv2d(diff, w) for w in convWeights]
# I didn't understand if you want the max abs sum or abs of max sum
# add this line depending on the answer:
sums = [K.abs(s) for s in sums]
#get the max sum for each array size
sums = [K.batch_flatten(s) for s in sums]
sums = [K.max(s, axis=-1) for s in sums]
#global sums for all sizes
sums = K.stack(sums, axis=-1)
sums = K.max(sums, axis=-1)
return K.abs(sums)
Попытка чего-то похожего на Кадане (разделите размеры)
Давайте просто сделаем это в отдельных измерениях:
if height >= width:
convFilters1 = [K.ones((1, i, 1, 1)) for i in range(1,width+1)]
convFilters2 = [K.ones((i, 1, 1, 1) for i in range(1,height+1)]
concatDim1 = 2
concatDim2 = 1
else:
convFilters1 = [K.ones((i, 1, 1, 1)) for i in range(1,height+1)]
convFilters2 = [K.ones((1, i, 1, 1) for i in range(1,width+1)]
concatDim1 = 1
concatDim2 = 2
def custom_loss_2_step(true,pred):
diff = true-pred #shape (samp, h, w, 1)
sums = [K.conv2d(diff, f) for f in convFilters1] #(samp, h, var, 1)
#(samp, var, w, 1)
sums = K.concatenate(sums, axis=concatDim1) #(samp, h, superW, 1)
#(samp, superH, w, 1)
sums = [K.conv2d(sums, f) for f in convFilters2] #(samp, var, superW, 1)
#(samp, superH, var, 1)
sums = K.concatenate(sums, axis=concatDim2) #(samp, superH, superW, 1)
sums = K.batch_flatten(sums) #(samp, allSums)
#??? sums = K.abs(sums)
maxSum = K.max(sums, axis-1) #(samp,)
#??? maxSum = K.abs(maxSum)
return maxSum