Вы можете определить метод для инициализации весов в соответствии с каждым слоем:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
А затем просто примените его к своей сети:
model = create_your_model()
model.apply(weights_init)