Я занимался этим часами, искал повсюду ТАК без приличного ответа на мою конкретную проблему.
Итак, я строю тренировочный набор для модели обучения с подкреплением, и я бы хотелчтобы сохранить этот набор обучения в CSV-файл.Каждая запись обучающего набора имеет вид
[np.ndarray(shape=(18,8,8)), np.ndarray(shape=(1968,)), int64]
Вход в нейронную сеть представляет собой тензорный / массивный массив 18x8x8, а на выходе - политика плоского массива (1968) + целочисленное значение.
Когда я собираюсь записать это в файл csv
, я использую следующую функцию numpy для элементов ввода и политики каждой записи:
input_bytes = inputs.tobytes() # inputs.tostring() also works
policy_bytes = policy.tobytes() # same here, policy.tostring()
Когда наступает время обученияМне нужно прочитать эти столбцы из CSV-файла и превратить байты обратно в numpy.ndarray
объекты.Я знаю исходные типы данных и формы - np.int32
, (18,8,8)
для ввода и np.float64
, (1968,)
для политики.Итак, вы могли бы подумать, что я мог бы просто использовать:
# need to use *_bytes[1:] because the 'b' character is written
# to the csv when we save
inputs = np.reshape(np.fromstring(input_bytes[1:], dtype=np.int32), (18,8,8))
policy = np.fromstring(policy_bytes[1:], dtype=np.float64)
Это терпит неудачу, выдавая ошибку:
ValueError: string size must be a multiple of element size
Если я попытаюсь преобразовать строку в байты и использовать frombuffer
, т.е.
inputs = np.reshape(np.frombuffer(bytes(input_bytes[1:], 'utf-8'), dtype=np.int32), (18,8,8))
policy = np.frombuffer(bytes(policy_bytes[1:], 'utf-8'), dtype=np.float64)
По сути, я получаю точно такую же ошибку.
ValueError: buffer size must be a multiple of element size
Я понимаю, что это должно быть какой-то проблемой кодирования, но я не могу точно определить это.Обратите внимание, что следующее прекрасно работает:
test_input = np.zeros(shape=(18,8,8), dtype=np.int32).tobytes()
test_policy = np.zeros(shape=(1968,), dtype=np.float64).tobytes()
np.reshape(np.frombuffer(test_input, dtype=np.int32), (18,8,8))
np.frombuffer(test_policy, dtype=np.float64)
Как я могу записать байты в файл csv
и позже загрузить их обратно в объект ndarray, читая из файла?
РЕДАКТИРОВАТЬ: Вот пример CSV:
![csv head](https://i.stack.imgur.com/oI3yb.png)