На самом деле это была проблема, связанная с Pytorch
, в которой весь этот беспорядок был вызван этой простой командой при загрузке модели внутри F_V.py
:
if self.device == 'cpu':
checkpoint = torch.load(self.model_checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(self.model_checkpoint_path)
# the culprit!
self.model = checkpoint['model'].module
Этот способ хранения и загрузка модели очень плохая, предварительно обученные модели изначально были обернуты nn.DataParallel
, и человек, который их спас, сделал это так:
def save_checkpoint(epoch, epochs_since_improvement, model, metric_fc, optimizer, acc, is_best):
print('saving checkpoint ...')
state = {'epoch': epoch,
'epochs_since_improvement': epochs_since_improvement,
'acc': acc,
'model': model,
'metric_fc': metric_fc,
'optimizer': optimizer}
# filename = 'checkpoint_' + str(epoch) + '_' + str(loss) + '.tar'
filename = 'checkpoint.tar'
torch.save(state, filename)
# If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
if is_best:
torch.save(state, 'BEST_checkpoint.tar')
Как видите, он использовал всю модель ('model': model
) в state_dict
и сохраните это. Он должен был использовать state_dict()
.
Это плохо, поскольку одна и та же иерархия файлов / структуры dir / все должно сохраняться постоянно, везде, где будет использоваться эта модель.
и на нас повлияло это так же, как вы можете видеть. Все клиентские службы полагались на models.py
и нуждались в том, чтобы он был рядом с ними, хотя они даже не использовали его.
Изначально я думал, что для решения этой проблемы нам нужно самим создать модели и затем загрузить весы вручную.
if self.model_name == 'r18':
self.model = resnet18(pretrained=False, use_se=use_se)
elif self.model_name == 'r50':
self.model = resnet50(pretrained=False, use_se=use_se)
elif self.model_name == 'r101':
self.model = resnet101(pretrained=False, use_se=use_se)
else:
raise Exception(f"Model name: '{self.model_name}' is not recognized.")
# load the model weights
self.model.load_state_dict(checkpoint['model'].module.state_dict())
Обратите внимание, что, поскольку модель изначально была моделью nn.DataParallel
, для доступа к самой модели мы используем свойство .module
, а затем используем модели state_dict()
для инициализации модели и надеюсь, это решит проблему.
Однако, похоже, что это не так, и поскольку модель сохраняется таким образом, кажется, что таким способом избавиться от таких зависимостей невозможно. вместо этого преобразуйте вашу модель в сценарий факела, а затем сохраните модель.
Таким образом, вы можете избавиться от всех неприятностей.
Решение 1:
Попробуйте преобразовать вашу модель в torch script
, а затем используйте это вместо:
def convert_model(model, input=torch.tensor(torch.rand(size=(1,3,112,112)))):
model = torch.jit.trace(self.model, input)
torch.jit.save(model,'/home/Rika/Documents/models/model.tjm')
, а затем загрузите эту версию вместо:
# load the model
self.model = torch.jit.load('/home/Rika/Documents/models/model.tjm')
Решение 2:
просто сохраните модель state_dict () снова и используйте ее вместо этого: я сам в итоге сделал:
self.model = checkpoint['model'].module
# create the new checkpoint based on what you need
torch.save({'state_dict' : self.model.state_dict(), 'use_se':True},
'/home/Rika/Documents/BEST_checkpoint_r18_2.tar')
and started using the new checkpoint and so far everything has been good