Python тензорный поток, создающий tfrecord с несколькими функциями массива - PullRequest
1 голос
/ 29 января 2020

Я следую за TensorFlow документами , чтобы сгенерировать tf.record из трех NumPy массивов, однако при попытке сериализации данных я получаю сообщение об ошибке. Я хочу, чтобы полученный tfrecord содержал три функции.

import numpy as np
import pandas as pd
# some random data
x = np.random.randn(85)
y = np.random.randn(85,2128)
z = np.random.choice(range(10),(85,155))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def serialize_example(feature0, feature1, feature2):
    """
    Creates a tf.Example message ready to be written to a file.
    """
    # Create a dictionary mapping the feature name to the tf.Example-compatible
    # data type.
    feature = {
      'feature0': _float_feature(feature0),
      'feature1': _float_feature(feature1),
      'feature2': _int64_feature(feature2)
    }
    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

features_dataset = tf.data.Dataset.from_tensor_slices((x, y, z))

features_dataset

<TensorSliceDataset shapes: ((), (2128,), (155,)), types: (tf.float64, tf.float32, tf.int64)>

for f0,f1,f2 in features_dataset.take(1):
    print(f0)
    print(f1)
    print(f2)
def tf_serialize_example(f0,f1,f2):
  tf_string = tf.py_function(
    serialize_example,
    (f0,f1,f2),  # pass these args to the above function.
    tf.string)      # the return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar

Тем не менее, при попытке запустить tf_serialize_example(f0,f1,f2)

я получаю сообщение об ошибке:

InvalidArgumentError: TypeError: <tf.Tensor: shape=(2128,), dtype=float32, numpy=
array([-0.5435242 ,  0.97947884, -0.74457455, ...,  has type tensorflow.python.framework.ops.EagerTensor, but expected one of: int, long, float
Traceback (most recent call last):

Я думаю, причина в том, что мои функции являются массивами а не цифры. Как заставить этот код работать для функций, которые являются массивами, а не числами?

1 Ответ

1 голос
/ 01 февраля 2020

Хорошо, я нашел время, чтобы посмотреть поближе. Я заметил, что использование features_dataset и tf_serialize_example происходит из учебника на веб-странице tenorflow. Я не знаю, каковы преимущества этого метода и как это исправить.

Но вот рабочий процесс, который должен работать для вашего кода (я заново открыл сгенерированные файлы tfrecords, и они были в порядке).

import numpy as np
import tensorflow as tf

# some random data
x = np.random.randn(85)
y = np.random.randn(85,2128)
z = np.random.choice(range(10),(85,155))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""

    return tf.train.Feature(int64_list=tf.train.Int64List(value=value.flatten()))

def serialize_example(feature0, feature1, feature2):
    """
    Creates a tf.Example message ready to be written to a file.
    """
    # Create a dictionary mapping the feature name to the tf.Example-compatible
    # data type.
    feature = {
      'feature0': _float_feature(feature0),
      'feature1': _float_feature(feature1),
      'feature2': _int64_feature(feature2)
    }
    # Create a Features message using tf.train.Example.
    return tf.train.Example(features=tf.train.Features(feature=feature))


writer = tf.python_io.TFRecordWriter('TEST.tfrecords')
example = serialize_example(x,y,z)
writer.write(example.SerializeToString())
writer.close()

Основное различие в этом коде состоит в том, что вы подаете numpy массивы в отличие от Tenorsflow Tensors в serialize_example. Надеюсь, это поможет

...