Обучение распределению Tensorflow 2.3: как создать переменные PerReplica из распределенного набора данных? - PullRequest
0 голосов
/ 02 августа 2020

С помощью tensorflow MirroredStrategy я хотел бы создать переменные PerReplica. Однако с помощью следующего кода:

strategy = tf.distribute.MirroredStrategy()

x = tf.cast( np.linspace(1, 10, 10), dtype=tf.float64 )

x_dataset = tf.data.Dataset.from_tensor_slices(x).batch(10)
x_dataset_dist = strategy.experimental_distribute_dataset(x_dataset)

iterator = iter(x_dataset_dist)

def func(x):
    
    x_ = tf.Variable(x)
    
    return x, x_

strategy.run(func, args=(next(iterator),))

переменные создаются как «Зеркально»:

(PerReplica:{
   0: <tf.Tensor: shape=(5,), dtype=float64, numpy=array([1., 2., 3., 4., 5.])>,
   1: <tf.Tensor: shape=(5,), dtype=float64, numpy=array([ 6.,  7.,  8.,  9., 10.])>
 },
 MirroredVariable:{
   0: <tf.Variable 'Variable:0' shape=(5,) dtype=float64, numpy=array([1., 2., 3., 4., 5.])>,
   1: <tf.Variable 'Variable/replica_1:0' shape=(5,) dtype=float64, numpy=array([1., 2., 3., 4., 5.])>
 })

Вместо «MirroredVariable» можно ли преобразовать набор данных «PerReplica» в "PerReplica" tf.Varibles?

Спасибо!

Wang Zhe

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...