Преобразовать строку в байт для загрузчика Pytorch - PullRequest
0 голосов
/ 11 октября 2019

Метод загрузки пути к модели Pytorch не находится под моим контролем, и я пытаюсь найти способ конвертировать загруженные строковые данные в байтовые данные. Приведенный ниже код загружает мою сохраненную модель из Dropbox и использует байты с кодировкой utf-8 для кодирования строки. Проблема в том, что когда я использую torch.load с BytesIO, я получаю UnpicklingError с недопустимым ключом загрузки, '<'. </p>

    data = bytes(self.Download("https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"), 'utf-8')

    self.agent.local.load_state_dict(torch.load(BytesIO(data ), map_location=lambda storage, loc: storage))

Код ниже работал отлично, пока запросы не были отключены, и теперь я пытаюсь использоватьМетод выше.

    dropbox_url = "https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"

    data = requests.get(dropbox_url )

    self.agent.local.load_state_dict(torch.load(BytesIO(data.content), map_location=lambda storage, loc: storage))

Мне просто нужно найти способ, как правильно преобразовать строку в байтовые данные.

1 Ответ

0 голосов
/ 15 октября 2019

Мне пришлось преобразовать байтовые данные в base64 и сохранить файл в этом формате. После загрузки в Dropbox и загрузки с использованием встроенного метода я преобразовал файл base64 обратно в байты, и он заработал!

import base64
from io import BytesIO

with open("checkpoint.pth", "rb") as f:
    byte = f.read(1)

# Base64 Encode the bytes
data_e = base64.b64encode(byte)

filename ='base64_checkpoint.pth'

with open(filename, "wb") as output:
    output.write(data_e)

# Save file to Dropbox

# Download file on server
b64_str= self.Download('url')

# String Encode to bytes
byte_data = b64_str.encode("UTF-8")

# Decoding the Base64 bytes
str_decoded = base64.b64decode(byte_data)

# String Encode to bytes
byte_decoded = str_decoded.encode("UTF-8")

# Decoding the Base64 bytes
decoded = base64.b64decode(byte_decoded)

torch.load(BytesIO(decoded))
...