Случайная логическая маска, выбранная в соответствии с пользовательским PDF в Tensorflow - PullRequest
0 голосов
/ 12 февраля 2019

Я пытаюсь сгенерировать случайную логическую маску, выбранную в соответствии с предопределенным распределением вероятностей.Распределение вероятностей сохраняется в тензоре той же формы, что и полученная маска.Каждая запись содержит вероятность того, что маска будет истинной в этом конкретном месте.

Короче говоря, я ищу функцию, которая принимает 4 входа:

  • pdf : Тензор для использования в качестве PDF
  • s : количество выборок на маску
  • n : общее количество создаваемых масок
  • replace : логическое значение, указывающее, следует ли выполнять выборку с заменой

и возвращает n логические маски

Упрощенный способ сделать это с помощью numpy будет выглядеть следующим образом:

def sample_mask(pdf, s, replace):

    hight, width = pdf.shape
    # Flatten to 1 dimension
    pdf = np.resize(pdf, (hight*width))
    # Sample according to pdf, the result is an array of indices
    samples=np.random.choice(np.arange(hight*width),
                    size=s, replace=replace, p=pdf)

    mask = np.zeros(hight*width)

    # Apply indices to mask
    for s in samples:
        mask[s]=1
    # Resize back to the original shape
    mask = np.resize(mask, (hight, width))
    return mask

Я уже понял, что часть выборки без параметра replace может быть выполнена следующим образом:

    samples = tf.multinomial(tf.log(pdf_tensor), n)

Но я застрял, когда дело доходит до преобразования образцов в маску.

1 Ответ

0 голосов
/ 13 февраля 2019

Я, должно быть, спал, вот как я решил:

def sample_mask(pdf, s, n, replace):
    """Initialize the model.

        Args:    
             pdf: A 3D Tensor of shape (batch_size, hight, width, channels=1) to use as a PDF
             s: The number of samples per mask. This value should be less than hight*width
             n: The total number of masks to generate
             replace: A boolean indicating if sampling should be done with replacement

        Returns:
            A Tensor of shape (batch_size, hight, width, channels=1, n) containing
            values 1 or 0.
    """
    batch_size, hight, width, channels = pdf.shape
    # Flatten pdf
    pdf = tf.reshape(pdf, (batch_size, hight*width))

    if replace:
        # Sample with replacement. Output is a tensor of shape (batch_size, n)
        sample_fun = lambda: tf.multinomial(tf.log(pdf), s)
    else:
        # Sample without replacement. Output is a tensor of shape (batch_size, n).
        # Cast the output to 'int64' to match the type needed for SparseTensor's indices
        sample_fun = lambda: tf.cast(sample_without_replacement(tf.log(pdf), s), dtype='int64')

    # Create batch indices
    idx = tf.range(batch_size, dtype='int64')
    idx = tf.expand_dims(idx, 1)
    # Transform idx to a 2D tensor of shape (batch_size, samples_per_batch)
    # Example: [[0 0 0 0 0],[1 1 1 1 1],[2 2 2 2 2]]
    idx = tf.tile(idx, [1, s])

    mask_list = []
    for i in range(n):
        # Generate samples
        samples = sample_fun()
        # Combine batch indices and samples
        samples = tf.stack([idx,samples])
        # Transform samples to a list of indicies: (batch_index, sample_index)
        sample_indices = tf.transpose(tf.reshape(samples, [2, -1]))
        # Create the mask as a sparse tensor and set sampled indices to 1
        mask = tf.SparseTensor(indices=sample_indices, values=tf.ones(s*batch_size), dense_shape=[batch_size, hight*width]) 
        # Convert mask to a dense tensor. Non-sampled values are set to 0.
        # Don't validate the indices, since this requires indices to be ordered
        # and unique.
        mask = tf.sparse.to_dense(mask, default_value=0,validate_indices=False)
        # Reshape to input shape and append to list of tensors
        mask_list.append(tf.reshape(mask, [batch_size, hight, width, channels]))
    # Combine all masks into a tensor of shape:
    # (batch_size, hight, width, channels=1, number_of_masks)
    return tf.stack(mask_list, axis=-1)

Функция для выборки без замены, как предложено здесь: https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125

Используется трюк Gumble-max: https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/

def sample_without_replacement(logits, K):
    z = -tf.log(-tf.log(tf.random_uniform(tf.shape(logits),0,1)))
    _, indices = tf.nn.top_k(logits + z, K)
    return indices
...