В коде есть несколько проблем, вы не должны использовать для этого объекты tf.Variable
, этих tf.map_fn
можно избежать, а tf.cond
всегда должен иметь две ветви. Вот возможная реализация кода, который вы связали в TensorFlow, адаптированного для работы с пакетами изображений. Каждое изображение в пакете независимо модифицируется с заданной вероятностью в другом ящике. Я разбил logi c на несколько функций для ясности.
import tensorflow as tf
@tf.function
def random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
'''
img is a 4-D variable (ex: tf.Variable(image, validate_shape=False) ) and NHWC order
probability: The probability that the operation will be performed.
sl: min erasing area
sh: max erasing area
r1: min aspect ratio
mean: erasing value
'''
return tf.where(tf.random.uniform([tf.shape(img)[0], 1, 1, 1]) > probability,
img,
_do_random_erasing(img, sl, sh, r1, mean))
def _do_random_erasing(img, sl, sh, r1, mean):
s = tf.shape(img, out_type=tf.int32)
# Sample random h and w values
def sample_hw(h, w):
s = tf.shape(img)
area = s[1] * s[2]
target_area = tf.random.uniform([s[0]], sl, sh)
target_area *= tf.dtypes.cast(area, target_area.dtype)
aspect_ratio = tf.random.uniform([s[0]], r1, 1 / r1)
h_new = tf.dtypes.cast(tf.math.round(tf.math.sqrt(target_area * aspect_ratio)), tf.int32)
w_new = tf.dtypes.cast(tf.math.round(tf.math.sqrt(target_area / aspect_ratio)), tf.int32)
# Only replace values that are still wrong
m = (h >= s[0]) | (w >= s[1])
h = tf.where(m, h_new, h)
w = tf.where(m, w_new, w)
return h, w
# Loop
_, h, w = tf.while_loop(
# While there are iterations to go and h and w are not good
lambda i, h, w: (i < 100) & tf.reduce_any((h >= s[1]) | (w >= s[2])),
# Get new h and w values
lambda i, h, w: (i + 1, *sample_hw(h, w)),
[0, tf.fill([s[0]], s[1]), tf.fill([s[0]], s[2])])
# Erase box if we got valid h and w values
return tf.cond(tf.reduce_all((h < s[1]) & (w < s[2])),
lambda: _erase_random_box(img, h, w, mean),
lambda: img)
def _erase_random_box(img, h, w, mean):
# Make box boundaries
s = tf.shape(img, out_type=tf.int32)
# Add extra dimensions for later
h = tf.reshape(h, [-1, 1, 1])
w = tf.reshape(w, [-1, 1, 1])
# Sample random boundaries
h_max = tf.dtypes.cast(s[1] - h + 1, tf.float32)
x1 = tf.dtypes.cast(tf.random.uniform(tf.shape(h)) * h_max, h.dtype)
w_max = tf.dtypes.cast(s[2] - w + 1, tf.float32)
y1 = tf.dtypes.cast(tf.random.uniform(tf.shape(w)) * w_max, w.dtype)
# Replacement mask
_, ii, jj = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), tf.range(s[2]), indexing='ij')
mask = (ii >= x1) & (ii < x1 + h) & (jj >= y1) & (jj < y1 + w)
# Replace box
result = tf.where(tf.expand_dims(mask, axis=-1),
tf.dtypes.cast(mean, img.dtype),
img)
# Maybe can use tfa.image.cutout for this function?
return result
# Test
tf.random.set_seed(100)
# Example batch of three 10x8 single-channel random images
img = tf.random.uniform([3, 8, 10, 1], dtype=tf.float32)
# Apply erasing
erased = random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[-1])
# Check results
with np.printoptions(precision=2, suppress=True):
erased_np = erased.numpy()
print(erased_np[0, :, :, 0])
# [[ 0.25 0.48 0.39 0.82 0.24 0.39 0.96 0.74 0.31 0.78]
# [ 0.36 0.44 0.39 0.41 -1. -1. -1. 0.99 0.08 0.7 ]
# [ 0.3 0.69 0.95 0.65 -1. -1. -1. 0.37 0.5 0.66]
# [ 0.42 0.64 0.71 0.86 -1. -1. -1. 0.78 0.16 0.19]
# [ 0.47 0.66 0.97 0.63 -1. -1. -1. 0.66 0.41 0.18]
# [ 0.56 0.33 0.58 0.03 -1. -1. -1. 0.01 0.44 0.29]
# [ 0.77 0.63 0.61 0.09 0.77 0.25 0.15 0.18 0.75 0.6 ]
# [ 0.74 0.4 0.15 0.18 0.18 0.07 0.53 0.16 0.61 0.42]]
print(erased_np[1, :, :, 0])
# [[0.55 0.31 0.67 0.42 0.93 0.31 0.1 0.67 0.11 0.3 ]
# [0.99 0.66 0.57 0.51 0.01 0.76 0.69 0.28 0.1 0.6 ]
# [0.91 0.63 0.23 0. 0.21 0.7 0.85 0.16 0.35 0.18]
# [0.67 0.83 0.66 0.4 0.51 0.84 0.07 0.62 0.8 0.66]
# [0.62 0.23 0.29 0.99 0.9 0.7 0.68 0.09 0.92 0.67]
# [0.36 0.75 0.51 0.76 0.68 0.56 0.07 0.68 0.57 0.58]
# [0.98 0.75 0.22 0.87 0.28 0.55 0.77 0.65 0.8 0.28]
# [0.76 0.46 0.11 0.85 0.3 0.35 0.81 0.48 0.24 0.81]]
print(erased_np[2, :, :, 0])
# [[ 0.42 0.33 0.44 0.68 0.89 0.88 0.8 0.72 0.5 0.61]
# [ 0.54 -1. -1. -1. -1. 0.56 0.33 0.24 0.98 0.89]
# [ 0.06 -1. -1. -1. -1. 0.64 0.76 0.26 0.1 0.57]
# [ 0.39 -1. -1. -1. -1. 0.09 0.24 0.47 0.92 0.2 ]
# [ 0.46 -1. -1. -1. -1. 0.61 0.11 0.5 0.52 0.06]
# [ 0.71 0.74 0.03 0.77 0.87 0.51 0.42 0.87 0.73 0.01]
# [ 0.18 0.71 0.38 0.17 0.18 0.56 0.58 0.7 0.1 0.87]
# [ 0.46 0.19 0.98 0.19 0.19 0.41 0.95 0. 0.82 0.05]]
Одно предостережение при использовании этой функции состоит в том, что tf.while_loop
пытается найти хорошие значения h
и w
для все изображений в пакете, но если ему не удается выбрать хорошую пару значений в итерациях 100 l oop даже для одного из изображений, тогда, если не будет применяться стирание к любому изображение. Вы можете подправить код тем или иным способом, чтобы обойти это, хотя я полагаю, достаточно просто указать разумное количество итераций.