Я хочу вращать мои изображения параллельно, используя 'map', на этапе предварительной обработки.
Проблема состоит в том, что каждое изображение поворачивается в одном направлении (после генерирования одного случайного числа). Но я хочу, чтобы у каждого изображения была разная степень поворота.
Это мой код:
import tensorflow_addons as tfa
import math
import random
def rotate_tensor(image, label):
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
Я пытался менять начальное число при каждом вызове функции:
import tensorflow_addons as tfa
import math
import random
seed_num = 0
def rotate_tensor(image, label):
seed_num += 1
random.seed(seed_num)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
Но я получаю:
UnboundLocalError: local variable 'seed_num' referenced before assignment
Я использую tf2, но я не думаю, что это имеет большое значение (кроме кода для поворота изображения).
Править Я попробовал то, что предложил @Mehraban, но похоже, что функция rotate_tensor вызывается только один раз:
import tensorflow_addons as tfa
import math
import random
num_seed = 1
def rotate_tensor(image, label):
global num_seed
num_seed += 1
print(num_seed) #<---- print num_seed
random.seed(num_seed)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
Но она печатает «2» только один раз. Поэтому я думаю, что rotate_tensor вызывается один раз.
Редактировать 2 - эта функция показывает повернутые изображения:
plt.figure(figsize=(12, 10))
for X_batch, y_batch in rotated_test_set.take(1):
for index in range(9):
plt.subplot(3, 3, index + 1)
plt.imshow(X_batch[index])
plt.title("Predict: {} | Actual: {}".format(class_names[y_test_proba_max_index[index]], class_names[y_batch[index]]))
plt.axis("off")
plt.show()