У меня есть код на PyTorch, который я хочу перенести на TensorFlow. Я почти закончил, но застрял на цикле while:
PyTorch
keep = scores.new(scores.size(0)).zero_().long()
if boxes.numel() == 0: return keep
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1)
v, idx = scores.sort(0) # sort in ascending order
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()
count = 0
while idx.numel() > 0:
i = idx[-1] # index of current largest val
keep[count] = i
count += 1
pdb.set_trace()
if idx.size(0) == 1: break
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)
# store element-wise max with next highest score
xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = torch.clamp(w, min=0.0)
h = torch.clamp(h, min=0.0)
inter = w*h
# IoU = i / (area(a) + area(b) - i)
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
union = (rem_areas - inter) + area[i]
IoU = inter/union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
pdb.set_trace()
TensorFlow
keep = tf.Variable(tf.zeros(tf.size(scores), tf.int64))
if tf.size(boxes) == 0: return keep
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = tf.multiply(x2 - x1, y2 - y1)
idx = tf.argsort(scores, axis=0)#scores.sort(0) # sort in ascending order
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = tf.placeholder(boxes.dtype, shape=[None])
yy1 = tf.placeholder(boxes.dtype, shape=[None])
xx2 = tf.placeholder(boxes.dtype, shape=[None])
yy2 = tf.placeholder(boxes.dtype, shape=[None])
w = tf.placeholder(boxes.dtype, shape=[None])
h = tf.placeholder(boxes.dtype, shape=[None])
count = 0
loop_vars = [idx, keep, count, x1, x2, y1, y2, xx1, xx2, yy1, yy2, w, h, area, overlap]
def loop_cond(idx, keep, count, x1, x2, y1, y2, xx1, xx2, yy1, yy2, w, h, area, overlap):
return tf.size(idx) > 0
def loop_body(idx, keep, count, x1, x2, y1, y2, xx1, xx2, yy1, yy2, w, h, area, overlap):
i = idx[-1] # index of current largest val
pdb.set_trace()
# m_mask = np.zeros(keep.get_shape().as_list()[0], dtype=np.int64)
# m_mask[count] = i
keep = tf.scatter_update(keep, count, i)
# keep = tf.add(keep, tf.constant(m_mask))
#keep[count] = i
count += 1
if idx.get_shape().as_list()[0] == 1: return keep, count
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
xx1 = tf.gather(x1, idx)
yy1 = tf.gather(y1, idx)
xx2 = tf.gather(x2, idx)
yy2 = tf.gather(y2, idx)
# store element-wise max with next highest score
# xx1 = torch.clamp(xx1, min=x1[i])
xx1 = tf.clip_by_value(xx1, x1[i], tf.reduce_max(xx1))
# yy1 = torch.clamp(yy1, min=y1[i])
yy1 = tf.clip_by_value(yy1, y1[i], tf.reduce_max(yy1))
# xx2 = torch.clamp(xx2, max=x2[i])
xx2 = tf.clip_by_value(xx2, tf.reduce_min(xx2), x2[i])
# yy2 = torch.clamp(yy2, max=y2[i])
yy2 = tf.clip_by_value(yy2, tf.reduce_min(yy2), y2[i])
# w.resize_as_(xx2)
w = tf.reshape(w, xx2.shape)
# h.resize_as_(yy2)
h = tf.reshape(w, yy2.shape)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = tf.clip_by_value(w, 0.0, tf.reduce_max(w))
h = tf.clip_by_value(h, 0.0, tf.reduce_max(h))
inter = w*h
# IoU = i / (area(a) + area(b) - i)
rem_areas = tf.gather(area, idx) # load remaining areas)
union = (rem_areas - inter) + area[i]
IoU = inter / union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
return idx, keep, count, xx1, xx2, yy1, yy2, w, h
# pdb.set_trace()
idx, keep, count, xx1, xx2, yy1, yy2, w, h = tf.while_loop(loop_cond, loop_body, loop_vars)
Я перенёс большую часть кода, но застрял на цикле while. Я не мог использовать обычный цикл Python в TensorFlow, потому что условие цикла было бы тензором.