Как получить оконный набор данных в tenorflow 2 из массива массивов numpy? - PullRequest
2 голосов
/ 26 сентября 2019

Представьте, что у меня есть некоторые данные:

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])

Это выглядит так:

array([[1, 2, 3, 4],
       [5, 6, 7, 8]])

Каждая строка представляет отдельное наблюдение, поэтому их не следует объединять.Я хочу создать оконный набор данных, каждое окно размером 3, смещенное на 1. Когда я пропускаю одно наблюдение, я получаю то, что хочу, например:

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

Результат:

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]

Но когда я передаю весь массив массивов, я ничего не возвращаю.

dataset = tf.data.Dataset.from_tensor_slices(some_data)
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

Это то, что я ожидал:

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]

Полагаю, я мог бы перебрать some_data и пропустить один массив за раз, а затем объединить наборы данных, но этокажется плохим решением.Какой правильный способ сделать это?

Я использую Tensorflow 2.0.Спасибо!

1 Ответ

1 голос
/ 29 сентября 2019

Каждая строка набора данных имеет только один элемент при использовании dataset = tf.data.Dataset.from_tensor_slices(some_data[0]).

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
for x in dataset:
    print(x.numpy())
1
2
3
4

Но каждая строка набора данных имеет четыре элемента при использовании dataset = tf.data.Dataset.from_tensor_slices(some_data).

dataset = tf.data.Dataset.from_tensor_slices(some_data)
for x in dataset:
    print(x.numpy())
[1 2 3 4]
[5 6 7 8]

Итак, что вам нужно сделать, это преобразовать каждую строку и объединить ее.

import numpy as np
import tensorflow as tf

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])
dataset = tf.data.Dataset.from_tensor_slices(some_data)

def parse_samples(x):
    return tf.data.Dataset.from_tensor_slices(x)\
        .window(size=3, shift=1, drop_remainder=True)\
        .flat_map(lambda window: window.batch(3))

dataset = dataset.flat_map(parse_samples)

for x in dataset:
    print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]
...