Определить пользовательский float8 в python-numpy и конвертировать из / в float16? - PullRequest
1 голос
/ 10 июля 2019

Я пытаюсь определить пользовательский 8-битный формат с плавающей запятой следующим образом:

  • 1 знаковый бит
  • 2 бита для мантиссы
  • 5 бит для показателя степени

Можно ли определить это как тип данных numpy? Если нет, то как проще всего преобразовать массив numpy из dtype float16 в такой формат (для хранения) и преобразовать его обратно (для вычислений в float16), возможно, используя битовые операции numpy?

Почему:

Я пытаюсь оптимизировать нейронную сеть на специальном оборудовании (FPGA). Для этого я играю с различными представлениями с плавающей точкой. Я уже построил инфраструктуру прямого прохода для своей нейронной сети с помощью numpy, поэтому что-то вроде выше поможет мне проверить снижение точности путем сохранения значений в моем пользовательском типе данных.

1 Ответ

1 голос
/ 11 июля 2019

Я ни в коем случае не эксперт в numpy, но мне нравится думать о проблемах представления FP. Размер вашего массива не огромен, поэтому любой достаточно эффективный метод должен подойти. Не похоже, что есть 8-битное представление FP, я думаю, потому что точность не так хороша.

Для преобразования в массив байтов, каждый из которых содержит одно 8-битное значение FP, для одномерного массива все, что вам нужно, это

float16 = np.array([6.3, 2.557])           # Here's some data in an array
float8s = array.tobytes()[1::2]
print(float8s)
>>> b'FAAF'

Это просто берет старшие байты из 16-битного числа с плавающей запятой, отсекая младшую часть, давая 1-битный знак, 5-битную экспоненту и 2-битную значимость. Старший байт всегда является вторым байтом каждой пары на машине с прямым порядком байтов. Я пробовал это на 2D массиве, и он работает так же. Это усекает. Округление в десятичной системе - это еще одна банка червей.

Возвращение к 16 битам - это просто вставка нулей. Я нашел этот метод экспериментально, и, несомненно, есть лучший способ, но он считывает байтовый массив как 8-битные целые числа и записывает новый как 16-битные целые числа, а затем преобразует его обратно в массив с плавающей точкой. Обратите внимание на представление с прямым порядком байтов, преобразующее обратно в байты, поскольку мы хотим, чтобы 8-битные значения были старшими байтами целых чисел.

float16 = np.frombuffer(np.array(np.frombuffer(float8s, dtype='u1'), dtype='>u2').tobytes(), dtype='f2')
print(float16)
>>> array([6. , 2.5, 2.5, 6. ], dtype=float16)

Вы точно можете увидеть потерю точности! Надеюсь, это поможет. Если этого достаточно, дайте мне знать. Если нет, я бы заглянул глубже.

...