numpy frombuffer () и tostring () - чтение из csv - PullRequest
0 голосов
/ 20 февраля 2019

Я занимался этим часами, искал повсюду ТАК без приличного ответа на мою конкретную проблему.

Итак, я строю тренировочный набор для модели обучения с подкреплением, и я бы хотелчтобы сохранить этот набор обучения в CSV-файл.Каждая запись обучающего набора имеет вид

[np.ndarray(shape=(18,8,8)), np.ndarray(shape=(1968,)), int64]

Вход в нейронную сеть представляет собой тензорный / массивный массив 18x8x8, а на выходе - политика плоского массива (1968) + целочисленное значение.

Когда я собираюсь записать это в файл csv, я использую следующую функцию numpy для элементов ввода и политики каждой записи:

    input_bytes = inputs.tobytes() # inputs.tostring() also works
    policy_bytes = policy.tobytes() # same here, policy.tostring()

Когда наступает время обученияМне нужно прочитать эти столбцы из CSV-файла и превратить байты обратно в numpy.ndarray объекты.Я знаю исходные типы данных и формы - np.int32, (18,8,8) для ввода и np.float64, (1968,) для политики.Итак, вы могли бы подумать, что я мог бы просто использовать:

    # need to use *_bytes[1:] because the 'b' character is written
    # to the csv when we save
    inputs = np.reshape(np.fromstring(input_bytes[1:], dtype=np.int32), (18,8,8))
    policy = np.fromstring(policy_bytes[1:], dtype=np.float64)

Это терпит неудачу, выдавая ошибку:

ValueError: string size must be a multiple of element size

Если я попытаюсь преобразовать строку в байты и использовать frombuffer, т.е.

    inputs = np.reshape(np.frombuffer(bytes(input_bytes[1:], 'utf-8'), dtype=np.int32), (18,8,8))
    policy = np.frombuffer(bytes(policy_bytes[1:], 'utf-8'), dtype=np.float64)

По сути, я получаю точно такую ​​же ошибку.

ValueError: buffer size must be a multiple of element size

Я понимаю, что это должно быть какой-то проблемой кодирования, но я не могу точно определить это.Обратите внимание, что следующее прекрасно работает:

test_input = np.zeros(shape=(18,8,8), dtype=np.int32).tobytes()
test_policy = np.zeros(shape=(1968,), dtype=np.float64).tobytes()

np.reshape(np.frombuffer(test_input, dtype=np.int32), (18,8,8))
np.frombuffer(test_policy, dtype=np.float64)

Как я могу записать байты в файл csv и позже загрузить их обратно в объект ndarray, читая из файла?

РЕДАКТИРОВАТЬ: Вот пример CSV:

csv head

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...