Должно быть, я это проспал. Вот как я решил задачу:
def sample_mask(pdf, s, n, replace):
    """Sample ``n`` binary masks whose set positions are drawn from ``pdf``.

    Args:
        pdf: A 4D Tensor of shape (batch_size, height, width, channels=1) used
            as an (unnormalized) sampling distribution over the height*width
            positions of each batch item. ``channels`` must be 1, since the
            tensor is flattened to (batch_size, height*width).
        s: The number of samples drawn per mask. When ``replace`` is False this
            must not exceed height*width (``tf.nn.top_k`` needs at least ``s``
            candidates).
        n: The total number of masks to generate.
        replace: If True, sample with replacement (``tf.multinomial``);
            otherwise sample without replacement (Gumbel-max trick).

    Returns:
        A float32 Tensor of shape (batch_size, height, width, channels=1, n)
        containing values 1 (sampled position) or 0.

    Notes:
        Positions where ``pdf`` is 0 get a logit of -inf. With replacement they
        are never drawn. Without replacement they can still be selected if a row
        has fewer than ``s`` positive entries. With replacement the same
        position may be drawn more than once, so a mask can contain fewer than
        ``s`` ones.
    """
    batch_size, height, width, channels = pdf.shape
    num_positions = height * width

    # Flatten the spatial dims to one row of logits per batch item. The log is
    # loop-invariant, so compute it once rather than once per mask.
    logits = tf.log(tf.reshape(pdf, (batch_size, num_positions)))

    def sample_fun():
        # Both branches return int64 indices of shape (batch_size, s).
        if replace:
            return tf.multinomial(logits, s)
        # top_k yields int32; SparseTensor indices must be int64.
        return tf.cast(sample_without_replacement(logits, s), dtype='int64')

    # Batch index for every sample, e.g. for s=5: [[0 0 0 0 0],[1 1 1 1 1],...]
    batch_idx = tf.tile(tf.expand_dims(tf.range(batch_size, dtype='int64'), 1), [1, s])

    # Loop-invariant pieces of every sparse mask.
    ones = tf.ones(s * batch_size)
    dense_shape = [batch_size, num_positions]

    mask_list = []
    for _ in range(n):
        samples = sample_fun()
        # Pair each sample with its batch row -> (batch_size*s, 2) index list.
        pairs = tf.stack([batch_idx, samples])
        sample_indices = tf.transpose(tf.reshape(pairs, [2, -1]))
        # Set sampled positions to 1. Indices are neither ordered (top_k sorts
        # by value) nor necessarily unique (with replacement), so index
        # validation must stay disabled.
        sparse_mask = tf.SparseTensor(indices=sample_indices, values=ones, dense_shape=dense_shape)
        dense_mask = tf.sparse.to_dense(sparse_mask, default_value=0, validate_indices=False)
        mask_list.append(tf.reshape(dense_mask, [batch_size, height, width, channels]))

    # Shape: (batch_size, height, width, channels=1, number_of_masks)
    return tf.stack(mask_list, axis=-1)
Функция для выборки без возвращения, как предложено здесь: https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125
Используется трюк Gumbel-max: https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
def sample_without_replacement(logits, K):
    """Draw ``K`` distinct indices per row of ``logits`` without replacement.

    Uses the Gumbel-max (Gumbel-top-k) trick. Every logit is perturbed with
    independent Gumbel(0, 1) noise, and the indices of the ``K`` largest
    perturbed values are kept.

    Args:
        logits: Tensor (or value convertible to one) of unnormalized
            log-probabilities. Sampling happens along the last axis.
        K: Number of indices to draw per row. Must not exceed the size of the
            last axis.

    Returns:
        An int32 Tensor of indices with the leading dimensions of ``logits``
        and a last dimension of size ``K``, ordered by decreasing perturbed
        logit.
    """
    import numpy as np  # only needed for the dtype's smallest positive value

    logits = tf.convert_to_tensor(logits)
    # tf.random_uniform samples from [minval, maxval). With minval=0, u can be
    # exactly 0. Then log(u) = -inf, the noise becomes -inf, and that entry
    # could never be chosen. Starting at the smallest positive normal value
    # keeps the noise finite. Matching logits.dtype also lets float64 logits work.
    tiny = np.finfo(logits.dtype.as_numpy_dtype).tiny
    uniform = tf.random_uniform(tf.shape(logits), minval=tiny, maxval=1.0, dtype=logits.dtype)
    gumbel_noise = -tf.log(-tf.log(uniform))
    _, indices = tf.nn.top_k(logits + gumbel_noise, K)
    return indices