Как пропустить текущую итерацию tf. while_l oop ()? - PullRequest
2 голосов
/ 16 июня 2020

Я только недавно начал работать с Tensorflow2. Я пытаюсь перепрограммировать скрипт, который случайным образом вырезает квадраты из изображений. Исходный код взят из этого репозитория github: Link . Я терплю неудачу из-за tf. while_for () l oop в Tensorflow2. Но вот код, который я написал до сих пор:

def random_erasing(img, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3):
    '''
    img is a 3-D variable (ex: tf.Variable(image, validate_shape=False) ) and  HWC order

    probability: The probability that the operation will be performed.
    sl: min erasing area
    sh: max erasing area
    r1: min aspect ratio
    mean: erasing value
    '''

    i = tf.constant(0)
    N = tf.constant(100)
    while_condition = lambda i: tf.less(i, N)

    def body(i):

        def calculate_valid_boxes(h, w):

            h_tmp = tf.Variable(tf.shape(img)[1]-h, dtype=tf.dtypes.int32)
            w_tmp = tf.Variable(tf.shape(img)[2]-w, dtype=tf.dtypes.int32)

            # x1 = random.randint(0, img.size()[1] - h)
            # y1 = random.randint(0, img.size()[2] - w)
            x1 = tf.map_fn(lambda x: tf.random.uniform([], minval=0, maxval=x, dtype=tf.dtypes.int32), h_tmp)
            y1 = tf.map_fn(lambda x: tf.random.uniform([], minval=0, maxval=x, dtype=tf.dtypes.int32), w_tmp)

            return x1, y1


        area = tf.shape(img)[1] * tf.shape(img)[2]

        target_area = tf.random.uniform([3], minval=sl, maxval=sh, dtype=tf.dtypes.float64) * tf.cast(area, tf.dtypes.float64)
        aspect_ratio = tf.cast(tf.random.uniform([3], minval=r1, maxval=1/r1), tf.dtypes.float64)

        h = tf.cast(tf.math.round(tf.sqrt(target_area * aspect_ratio)), tf.dtypes.int32)
        w = tf.cast(tf.math.round(tf.sqrt(target_area / aspect_ratio)), tf.dtypes.int32)

        # if condition: w < img.size()[2] and h < img.size()[1]:
        cond_1 = tf.less(w, tf.shape(img)[2])
        cond_2 = tf.less(h,tf.shape(img)[1])
        x1 = tf.cond(tf.cast(tf.logical_and(cond_1, cond_2), tf.int32) == 3, lambda: calculate_valid_boxes(h, w))

        return h, w, x1, y1


    # mask_size= area of cutout, offset= place of cutout, constant_value=pixel value to fill in at cutout
    image = tfa.image.cutout(img, mask_size=(h, w), offset=(x1, y1), constant_values=255)

    return image

Моя проблема заключается в следующей строке:

x1 = tf.cond(tf.cast(tf.logical_and(cond_1, cond_2), tf.int32) == 3, calculate_valid_boxes(h, w))

Я всегда получаю «Произошло исключение: TypeError cond (): false_fn требуется аргумент "сообщения. Я хочу вызвать функцию «calculate_valid_boxes ()» в этой строке, если утверждение истинно или если утверждение ложно. Я хочу перейти к новой итерации.

На простом Python вы можете решить эту проблему с помощью оператора break или continue (в зависимости от реализации), но с Tensorflow2 я не могу найти решение.

Если информация актуальна, функция работает с пакетом изображений.

1 Ответ

1 голос
/ 16 июня 2020

В коде есть несколько проблем, вы не должны использовать для этого объекты tf.Variable, этих tf.map_fn можно избежать, а tf.cond всегда должен иметь две ветви. Вот возможная реализация кода, который вы связали в TensorFlow, адаптированного для работы с пакетами изображений. Каждое изображение в пакете независимо модифицируется с заданной вероятностью в другом ящике. Я разбил logi c на несколько функций для ясности.

import tensorflow as tf

@tf.function
def random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
    '''
    img is a 4-D variable (ex: tf.Variable(image, validate_shape=False) ) and NHWC order

    probability: The probability that the operation will be performed.
    sl: min erasing area
    sh: max erasing area
    r1: min aspect ratio
    mean: erasing value
    '''
    return tf.where(tf.random.uniform([tf.shape(img)[0], 1, 1, 1]) > probability,
                    img,
                    _do_random_erasing(img, sl, sh, r1, mean))

def _do_random_erasing(img, sl, sh, r1, mean):
    s = tf.shape(img, out_type=tf.int32)
    # Sample random h and w values
    def sample_hw(h, w):
        s = tf.shape(img)
        area = s[1] * s[2]
        target_area = tf.random.uniform([s[0]], sl, sh)
        target_area *= tf.dtypes.cast(area, target_area.dtype)
        aspect_ratio = tf.random.uniform([s[0]], r1, 1 / r1)
        h_new = tf.dtypes.cast(tf.math.round(tf.math.sqrt(target_area * aspect_ratio)), tf.int32)
        w_new = tf.dtypes.cast(tf.math.round(tf.math.sqrt(target_area / aspect_ratio)), tf.int32)
        # Only replace values that are still wrong
        m = (h >= s[0]) | (w >= s[1])
        h = tf.where(m, h_new, h)
        w = tf.where(m, w_new, w)
        return h, w
    # Loop
    _, h, w = tf.while_loop(
        # While there are iterations to go and h and w are not good
        lambda i, h, w: (i < 100) & tf.reduce_any((h >= s[1]) | (w >= s[2])),
        # Get new h and w values
        lambda i, h, w: (i + 1, *sample_hw(h, w)),
        [0, tf.fill([s[0]], s[1]), tf.fill([s[0]], s[2])])
    # Erase box if we got valid h and w values
    return tf.cond(tf.reduce_all((h < s[1]) & (w < s[2])),
                   lambda: _erase_random_box(img, h, w, mean),
                   lambda: img)

def _erase_random_box(img, h, w, mean):
    # Make box boundaries
    s = tf.shape(img, out_type=tf.int32)
    # Add extra dimensions for later
    h = tf.reshape(h, [-1, 1, 1])
    w = tf.reshape(w, [-1, 1, 1])
    # Sample random boundaries
    h_max = tf.dtypes.cast(s[1] - h + 1, tf.float32)
    x1 = tf.dtypes.cast(tf.random.uniform(tf.shape(h)) * h_max, h.dtype)
    w_max = tf.dtypes.cast(s[2] - w + 1, tf.float32)
    y1 = tf.dtypes.cast(tf.random.uniform(tf.shape(w)) * w_max, w.dtype)
    # Replacement mask
    _, ii, jj = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), tf.range(s[2]), indexing='ij')
    mask = (ii >= x1) & (ii < x1 + h) & (jj >= y1) & (jj < y1 + w)
    # Replace box
    result = tf.where(tf.expand_dims(mask, axis=-1),
                      tf.dtypes.cast(mean, img.dtype),
                      img)
    # Maybe can use tfa.image.cutout for this function?
    return result

# Test
tf.random.set_seed(100)
# Example batch of three 10x8 single-channel random images
img = tf.random.uniform([3, 8, 10, 1], dtype=tf.float32)
# Apply erasing
erased = random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[-1])
# Check results
with np.printoptions(precision=2, suppress=True):
    erased_np = erased.numpy()
    print(erased_np[0, :, :, 0])
    # [[ 0.25  0.48  0.39  0.82  0.24  0.39  0.96  0.74  0.31  0.78]
    #  [ 0.36  0.44  0.39  0.41 -1.   -1.   -1.    0.99  0.08  0.7 ]
    #  [ 0.3   0.69  0.95  0.65 -1.   -1.   -1.    0.37  0.5   0.66]
    #  [ 0.42  0.64  0.71  0.86 -1.   -1.   -1.    0.78  0.16  0.19]
    #  [ 0.47  0.66  0.97  0.63 -1.   -1.   -1.    0.66  0.41  0.18]
    #  [ 0.56  0.33  0.58  0.03 -1.   -1.   -1.    0.01  0.44  0.29]
    #  [ 0.77  0.63  0.61  0.09  0.77  0.25  0.15  0.18  0.75  0.6 ]
    #  [ 0.74  0.4   0.15  0.18  0.18  0.07  0.53  0.16  0.61  0.42]]
    print(erased_np[1, :, :, 0])
    # [[0.55 0.31 0.67 0.42 0.93 0.31 0.1  0.67 0.11 0.3 ]
    #  [0.99 0.66 0.57 0.51 0.01 0.76 0.69 0.28 0.1  0.6 ]
    #  [0.91 0.63 0.23 0.   0.21 0.7  0.85 0.16 0.35 0.18]
    #  [0.67 0.83 0.66 0.4  0.51 0.84 0.07 0.62 0.8  0.66]
    #  [0.62 0.23 0.29 0.99 0.9  0.7  0.68 0.09 0.92 0.67]
    #  [0.36 0.75 0.51 0.76 0.68 0.56 0.07 0.68 0.57 0.58]
    #  [0.98 0.75 0.22 0.87 0.28 0.55 0.77 0.65 0.8  0.28]
    #  [0.76 0.46 0.11 0.85 0.3  0.35 0.81 0.48 0.24 0.81]]
    print(erased_np[2, :, :, 0])
    # [[ 0.42  0.33  0.44  0.68  0.89  0.88  0.8   0.72  0.5   0.61]
    #  [ 0.54 -1.   -1.   -1.   -1.    0.56  0.33  0.24  0.98  0.89]
    #  [ 0.06 -1.   -1.   -1.   -1.    0.64  0.76  0.26  0.1   0.57]
    #  [ 0.39 -1.   -1.   -1.   -1.    0.09  0.24  0.47  0.92  0.2 ]
    #  [ 0.46 -1.   -1.   -1.   -1.    0.61  0.11  0.5   0.52  0.06]
    #  [ 0.71  0.74  0.03  0.77  0.87  0.51  0.42  0.87  0.73  0.01]
    #  [ 0.18  0.71  0.38  0.17  0.18  0.56  0.58  0.7   0.1   0.87]
    #  [ 0.46  0.19  0.98  0.19  0.19  0.41  0.95  0.    0.82  0.05]]

Одно предостережение при использовании этой функции состоит в том, что tf.while_loop пытается найти хорошие значения h и w для все изображений в пакете, но если ему не удается выбрать хорошую пару значений в итерациях 100 l oop даже для одного из изображений, тогда, если не будет применяться стирание к любому изображение. Вы можете подправить код тем или иным способом, чтобы обойти это, хотя я полагаю, достаточно просто указать разумное количество итераций.

...