У меня есть модель pytorch размером 386 МБ, но когда я загружаю модель
state = torch.load(f, flair.device)
Моя память GPU занимает до 900 МБ, почему это происходит, и есть ли способ решить эту проблему??
Так я могу сохранить модель
model_state = self._get_state_dict()
# additional fields for model checkpointing
model_state["optimizer_state_dict"] = optimizer_state
model_state["scheduler_state_dict"] = scheduler_state
model_state["epoch"] = epoch
model_state["loss"] = loss
torch.save(model_state, str(model_file), pickle_protocol=4)