Я использую нейронную сеть в стиле U-Net для сегментации и потери игральных костей, как определено, хорошо работает
def dice_loss(y_true, y_pred):
smooth = 1.
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = y_true_f * y_pred_f
score = (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
return 1. - score
Тем не менее, я хотел бы обучить эту модель только на пикселях вокруг сегментации правды земли. Один из способов - установить рамку на сегментации, сделать из нее новый образ и использовать такие образы для обучения. Однако в этом случае теряется много фонового контекста. Лучшим способом было бы сформировать функцию потерь, которая наказывает сеть только за ее предсказание вокруг коробки, установленной на основании истинности, игнорируя значения, которые сеть дает нам за пределами этой рамки. Для этого я попробовал следующую функцию потерь, но она не работает с тензорами. Ниже моя неудачная попытка. Есть ли способ настроить это, чтобы оно заработало?
def getbound(a):
min_x=10000
min_y=10000
max_x=-1
max_y=-1
rows=a.shape[0]
cols=a.shape[1]
for y in range(0, rows):
for x in range(0, cols):
if(a[y,x]>0):
if(x>max_x):
max_x=x
if(y>max_y):
max_y=y
if(x<min_x):
min_x=x
if(y<min_y):
min_y=y
return[min_x,min_y,max_x,max_y]
def box_dice_loss(y_true, y_pred):
smooth = 1.
padding=32
print(y_true.shape)
[min_x,min_y,max_x,max_y]=getbound(y_true)
min_x-=padding
min_y-=padding
max_x+=padding
max_y+=padding
rows=y_true.shape[0]
cols=y_true.shape[1]
for y in range(0, rows):
for x in range(0, cols):
if(y<min_y or y>max_y or x<min_x or x>max_x):
y_pred[y,x]=y_true[y,x]
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = y_true_f * y_pred_f
score = (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
return 1. - score
В настоящее время ошибка заключается в том, что при инициализации сети тензоры не имеют заданных форм, поэтому я получаю этот вывод
Использование бэкэнда TensorFlow. (?,?,?,?) в getbound
для y в диапазоне (0, строки): ошибка типа: индекс возвращен не-int (тип NoneType)