Я пытаюсь построить федеративную модель обучения. В моем сценарии у меня есть 3 рабочих и оркестратор. Рабочие начинают обучение, и в конце каждого цикла обучения модели отправляются в оркестратор, оркестратор вычисляет среднее значение по федерации и отправляет обратно новую модель, рабочие обучаются этой новой модели и т. Д. c. Пользовательская сеть - это AutoEncoder, который я построил с нуля.
К сожалению, я получаю это сообщение об ошибке от рабочих: RuntimeError: forward () отсутствует значение для аргумента 'input'. Объявление: forward (собственный ClassType, входы Tensor, выходы Tensor) -> (Tensor) , что странно, поскольку моя функция forward определена следующим образом внутри класса AE:
class AutoEncoder(nn.Module):
def __init__(self, code_size):
super().__init__()
self.code_size = code_size
# Encoder specification
self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
self.enc_linear_1 = nn.Linear(53 * 53 * 20, 50)
self.enc_linear_2 = nn.Linear(50, self.code_size)
# Decoder specification
self.dec_linear_1 = nn.Linear(self.code_size, 160)
self.dec_linear_2 = nn.Linear(160, IMAGE_SIZE)
def forward(self, images):
code = self.encode(images)
out = self.decode(code)
return out, code
def encode(self, images):
code = self.enc_cnn_1(images)
code = F.selu(F.max_pool2d(code, 2))
code = self.enc_cnn_2(code)
code = F.selu(F.max_pool2d(code, 2))
code = code.view([images.size(0), -1])
code = F.selu(self.enc_linear_1(code))
code = self.enc_linear_2(code)
return code
def decode(self, code):
out = F.selu(self.dec_linear_1(code))
out = F.sigmoid(self.dec_linear_2(out))
out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
return out
Loss function (cross entropy)
```
@torch.jit.script
def loss_fn(inputs, outputs):
return torch.nn.functional.mse_loss(input=inputs, target=outputs)
def set_gradients(model, finetuning):
"""Helper function to exclude all gradients from training
used for transfer learning in feature extract mode; See: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
Args:
model (torch.nn.Module): model object.
finetuning (bool): if true, nothing will be changed; transfer learning will be used in finetuning mode, i.e., all gradients are trained;
if false, all gradients get excluded from training, used in feature extract mode
"""
if not finetuning:
for param in model.parameters():
param.requires_grad = False
```
```
def initialize_model():
model = AutoEncoder(code_size)
set_gradients(model, False)
return model
async def train_model_on_worker(
worker: websocket_client.WebsocketClientWorker,
traced_model: torch.jit.ScriptModule,
dataset_key: str,
batch_size: int,
curr_round: int,
lr: float,
):
traced_model.train()
print("train mode on")
train_config = sy.TrainConfig(
model=traced_model,
loss_fn=loss_fn,
batch_size=batch_size,
shuffle=True,
epochs=1,
optimizer="Adam",
optimizer_args={"lr": lr}
)
logger.info(worker.id + " send trainconfig")
train_config.send(worker)
print("Model sent to the worker")
logger.info(worker.id + " start training")
await worker.async_fit(dataset_key=DATASET_KEY, return_ids=[0])
logger.info(worker.id + " training done")
results = dict()
logger.info(worker.id + " get model")
model = train_config.model_ptr.get().obj
results["worker_id"] = worker.id
results["model"] = model
return results
def validate_model(identifier, model, dataloader, criterion):
model.eval() # changes the mode of the model, in evaluation mode we don't have dropout
loss = []
for i, (inputs,_) in enumerat(dataloader):
print("validation mode on")
#with torch.set_grad_enabled(False):
outputs, code = model(Variable(inputs)) #a tensor with 2 values: one for leak and one for no leak
loss = criterion(outputs, inputs)
loss = loss.sqrt()
loss.append(loss.item())
print("Loss = %.3f" % loss.data)
async def main():
args = define_and_get_arguments()
hook = sy.TorchHook(torch) #with this we can override some pytorch methods with pysyft
# Create WebsocketClientWorkers using IDs, Ports and IP addresses from arguments
worker_instances = []
for i in range(len(args.workers) // 3):
j = i * 3
worker_instances.append(websocket_client.WebsocketClientWorker(
id=args.workers[j], port=args.workers[j + 1], host=args.workers[j + 2], hook=hook, verbose=args.verbose))
model = initialize_model()
# optional loading of predefined model weights (= dictionary):
if args.basic_model:
model.load_state_dict(torch.load(args.basic_model))
# model serialization (creating an object of type ScriptModule):
model.eval()
traced_model = torch.jit.trace(model, torch.rand([1, 3, 224, 224], dtype=torch.float)) #we need to change the form of the model in order to make it
#serialisable and send it to the workers
# Data / picture transformation:
data_transforms = transforms.Compose([
transforms.Resize(INPUT_SIZE),
transforms.CenterCrop(INPUT_SIZE),
transforms.ToTensor()
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Create validation dataset and dataloader
validation_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'val'), data_transforms)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
# Create test dataset and dataloader
test_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'test'), data_transforms)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
# Lists to plot loss and acc after training
train_loss_values = []
train_acc_values = []
val_loss_values = []
val_acc_values = []
np.set_printoptions(formatter={"float": "{: .0f}".format})
for curr_round in range(1, args.training_rounds + 1):
logger.info("Training round %s/%s", curr_round, args.training_rounds)
print("entered training ")
# reduce learn rate every 5 training rounds (adaptive learnrate)
lr = args.lr * (0.1 ** ((curr_round - 1) // 5))
completed, pending = await asyncio.wait(
[
*[
train_model_on_worker(
worker=worker,
traced_model=traced_model,
dataset_key=DATASET_KEY,
batch_size=args.batch_size,
curr_round=curr_round,
lr=lr,
)
for worker in worker_instances
]
],
timeout=40
)
results = []
for entry in completed:
print("entry")
print(entry)
results.append(entry.result())
for entry in pending:
entry.cancel()
new_worker_instances = []
for entry in results:
for worker in worker_instances:
if (entry["worker_id"] == worker.id):
new_worker_instances.append(worker)
worker_instances = new_worker_instances
# Setup the loss function
criterion = torch.nn.functional.mse_loss()
#optimizer = optimizer_cls(autoencoder.parameters(), lr=lr)
# Federate models (note that this will also change the model in models[0]
models = {}
for worker in results:
if worker["model"] is not None:
models[worker["worker_id"]] = worker["model"]
logger.info("aggregation")
traced_model = utils.federated_avg(models)
logger.info("aggregation done")
# Evaluate federated model
logger.info("Validate..")
loss = validate_model("Federated", traced_model, validation_dataloader, criterion)
logger.info("Validation done")
val_loss_values.append(loss)
#val_acc_values.append(acc)
if __name__ == "__main__":
# Logging setup
date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
FORMAT = "%(asctime)s | %(message)s"
logging.basicConfig(filename='logs/orchestrator_' + date_time + '.log', format=FORMAT)
logger = logging.getLogger("orchestrator")
logger.setLevel(level=logging.INFO)
asyncio.get_event_loop().run_until_complete(main())
The code of the workers:
def load_dataset(dataset_path):
"""Helper function for setting up the local datasets.
Args:
dataset_path (string): path to dataset, images must be arranged in this way
dataset_path/train/class1/xxx.jpg
dataset_path/train/class2/yyy.jpg
"""
data_transform = transforms.Compose([
transforms.RandomResizedCrop(INPUT_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
dataset = datasets.ImageFolder(os.path.join(dataset_path, 'train'), data_transform)
return dataset
def start_websocket_server(id, port, dataset, verbose):
"""Helper function for spinning up a websocket server.
Args:
id (str or id): the unique id of the worker.
port (int): the port on which the server should be run.
dataset: dataset, which the worker should provide.
verbose (bool): a verbose option - will print all messages sent/received to stdout.
"""
hook = sy.TorchHook(torch)
server = WebsocketServerWorker(id=id, host="0.0.0.0", port=port, hook=hook, verbose=verbose)
server.add_dataset(dataset, key=DATASET_KEY)
server.start()
return server
def _fit(self, model, dataset_key, loss_fn):
logger = logging.getLogger("worker")
logger.info(dataset_key)
print("dataset key")
model.train()
data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
#data_loader = self._create_data_loader(
#dataset_key=dataset_key, shuffle=self.train_config.shuffle
#)
print("worker")
print(data_loader)
loss = None
iteration_count = 0
for _ in range(self.train_config.epochs):
for data in enumerate(data_loader):
# Set gradients to zero
self.optimizer.zero_grad()
# Update model
output,code = model(data)
logger.info(data)
logger.info(output)
loss = loss_fn(data, output)
loss.backward()
self.optimizer.step()
# Update and check interation count
iteration_count += 1
if iteration_count >= self.train_config.max_nr_batches >= 0:
break
return model
if __name__ == "__main__":
# Parse args
args = define_and_get_arguments()
# Logging setup
date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s"
logging.basicConfig(filename='logs/worker_' + args.id + '_' + date_time + '.log', format=FORMAT)
logger = logging.getLogger("worker")
logger.setLevel(level=logging.INFO)
# Load train dataset
dataset = load_dataset(args.dataset_path)
# Start server
server = start_websocket_server(
id=args.id,
port=args.port,
dataset=dataset,
verbose=args.verbose,
)
Does anybody know what the problem is?