Преобразование элементов NamedTuple в тип элемента Tensorflow Dataset - PullRequest
1 голос
/ 16 июня 2020

Когда tf.data.Dataset создается из встроенных / обычных Python типов (int, float, list, ...), листовые узлы внутри каждого элемента объекта набора данных преобразуются в tf.Tensor объекты.

from typing import NamedTuple, List
import tensorflow as tf

class Coord(NamedTuple):
  x: int
  y: int

class Element(NamedTuple):
  coords: List[Coord]
  kind: int

my_element = Element([Coord(1, 2), Coord(3, 4)], 5)

def iterable_to_generator(iter):
  def generator():
    for element in iter:
      yield element
  return generator

dataset = tf.data.Dataset.from_generator(
  iterable_to_generator([my_element]), output_types=Element(tf.int32, tf.int32))

for element in dataset:
  print(element)

# Prints:
# Element(coords=<tf.Tensor: id=111, shape=(2, 2), dtype=int32, numpy=
# array([[1, 2],
#        [3, 4]])>, kind=<tf.Tensor: id=112, shape=(), dtype=int32, numpy=5>)

Есть ли способ преобразовать my_element в тот же результат, что и в приведенном выше фрагменте (объект типа Element с двумя tf.Tensor объектами для .coords и .kind) без создания списка, генератора, tf.data.Dataset и последующего извлечения элемента из набора данных?

...