Федеративное обучение с использованием пользовательской модели в Pytorch / Pysyft - PullRequest
2 голосов
/ 02 апреля 2020

Я пытаюсь построить модель федеративного обучения. В моём сценарии есть 3 рабочих узла (воркера) и оркестратор. Воркеры начинают обучение, и в конце каждого раунда обучения их модели отправляются оркестратору; оркестратор выполняет федеративное усреднение и отправляет обратно новую модель, после чего воркеры продолжают обучение уже на этой новой модели, и т. д. Пользовательская сеть — это автоэнкодер (AutoEncoder), который я построил с нуля.

К сожалению, я получаю от воркеров следующее сообщение об ошибке: RuntimeError: forward() is missing value for argument 'inputs'. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor) — что странно, поскольку моя функция forward определена внутри класса AutoEncoder следующим образом:

class AutoEncoder(nn.Module):
    """Convolutional encoder with a fully connected decoder.

    The encoder is two (5x5 convolution -> 2x2 max-pool -> SELU) stages
    followed by two linear layers; the decoder is two linear layers whose
    sigmoid output is reshaped to ``(batch, 3, IMAGE_WIDTH, IMAGE_HEIGHT)``.

    Args:
        code_size (int): number of features in the latent code.
    """

    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
        # 53 * 53 * 20 is the flattened feature size for a 224x224 input:
        # 224 -conv5-> 220 -pool2-> 110 -conv5-> 106 -pool2-> 53, 20 channels.
        self.enc_linear_1 = nn.Linear(53 * 53 * 20, 50)
        self.enc_linear_2 = nn.Linear(50, self.code_size)

        # Decoder specification
        self.dec_linear_1 = nn.Linear(self.code_size, 160)
        self.dec_linear_2 = nn.Linear(160, IMAGE_SIZE)

    def forward(self, images):
        """Encode then decode ``images``.

        Returns:
            tuple: ``(reconstruction, code)`` - callers must unpack both.
        """
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        """Map a batch of images to latent codes of width ``code_size``."""
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))

        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        # Flatten everything except the batch dimension for the linear layers.
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))

        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map latent codes back to image-shaped tensors with values in (0, 1)."""
        out = F.selu(self.dec_linear_1(code))
        # F.sigmoid is deprecated; the tensor method computes the same values
        # without the deprecation warning.
        out = self.dec_linear_2(out).sigmoid()
        out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
        return out


Loss function (mean squared error)

```
@torch.jit.script
def loss_fn(inputs, outputs):
    # Mean squared error between the two tensors; scripted so the function can
    # be serialised together with the train config.
    # NOTE(review): the reported "missing value for argument 'inputs'" error
    # suggests the worker calls this with keyword arguments other than
    # inputs/outputs (presumably target=/pred=) - confirm against the PySyft
    # version in use before renaming the parameters.
    return torch.nn.functional.mse_loss(inputs, outputs)


def set_gradients(model, finetuning):
    """Freeze every parameter of ``model`` unless fine-tuning is requested.

    Intended for transfer learning in feature-extract mode; see
    https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

    Args:
        model (torch.nn.Module): model whose parameters may be frozen.
        finetuning (bool): when true the model is left untouched, so all
            parameters keep their current ``requires_grad`` flag; when false
            ``requires_grad`` is switched off on every parameter.
    """
    if finetuning:
        return

    for weight in model.parameters():
        weight.requires_grad = False
```


```
def initialize_model():
    """Build the auto-encoder that is traced and sent to the workers.

    Returns:
        AutoEncoder: a freshly initialised model with trainable parameters.
    """
    model = AutoEncoder(code_size)
    # The network is trained from scratch, so its parameters must stay
    # trainable. Passing False here froze every parameter, leaving the
    # optimizer nothing to update.
    set_gradients(model, True)
    return model
async def train_model_on_worker(
    worker: websocket_client.WebsocketClientWorker,
    traced_model: torch.jit.ScriptModule,
    dataset_key: str,
    batch_size: int,
    curr_round: int,
    lr: float,
):
    """Send the model to one worker, train it there and fetch the result.

    Args:
        worker: remote worker that performs the training.
        traced_model: traced model to be trained remotely.
        dataset_key: key of the dataset the worker should train on.
        batch_size: batch size placed in the train config.
        curr_round: index of the current training round (not used here).
        lr: learning rate handed to the Adam optimizer.

    Returns:
        dict: ``{"worker_id": <worker.id>, "model": <trained model>}``.
    """
    traced_model.train()
    print("train mode on")
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        epochs=1,
        optimizer="Adam",
        optimizer_args={"lr": lr},
    )

    # Lazy %-formatting also works when worker.id is not a string.
    logger.info("%s send trainconfig", worker.id)
    train_config.send(worker)
    print("Model sent to the worker")

    logger.info("%s start training", worker.id)
    # Use the key passed by the caller instead of the module-level constant.
    await worker.async_fit(dataset_key=dataset_key, return_ids=[0])
    logger.info("%s training done", worker.id)

    logger.info("%s get model", worker.id)
    model = train_config.model_ptr.get().obj

    return {"worker_id": worker.id, "model": model}
def validate_model(identifier, model, dataloader, criterion):
    """Compute the mean per-batch reconstruction RMSE of ``model``.

    Args:
        identifier: label for the model being validated (currently unused).
        model: callable module returning ``(outputs, code)`` for a batch.
        dataloader: iterable yielding ``(inputs, label)`` pairs; the label is
            ignored because the inputs are the reconstruction target.
        criterion: callable ``criterion(outputs, inputs)`` returning a scalar
            tensor; its square root is taken per batch.

    Returns:
        float: mean of the per-batch values, or ``nan`` when ``dataloader``
        yields no batches.
    """
    model.eval()  # evaluation mode, e.g. disables dropout

    batch_losses = []
    # No gradients are needed for validation.
    with torch.no_grad():
        for inputs, _ in dataloader:
            print("validation mode on")
            outputs, _code = model(inputs)
            batch_losses.append(criterion(outputs, inputs).sqrt().item())

    if batch_losses:
        mean_loss = sum(batch_losses) / len(batch_losses)
    else:
        mean_loss = float("nan")

    print("Loss = %.3f" % mean_loss)
    return mean_loss
async def main():
    """Run the federated training rounds from the orchestrator side.

    Connects to the workers listed in the arguments, traces the initial
    model, and then for every round: trains the model on all workers
    concurrently, keeps only the workers that answered in time, averages the
    returned models and validates the averaged model on the validation set.
    """
    args = define_and_get_arguments()
    hook = sy.TorchHook(torch)  # lets PySyft override some PyTorch methods

    # Create WebsocketClientWorkers; args.workers is read as a flat sequence
    # of (id, port, host) triples.
    worker_instances = []
    for i in range(len(args.workers) // 3):
        j = i * 3
        worker_instances.append(websocket_client.WebsocketClientWorker(
            id=args.workers[j], port=args.workers[j + 1], host=args.workers[j + 2], hook=hook, verbose=args.verbose))

    model = initialize_model()

    # optional loading of predefined model weights (= dictionary):
    if args.basic_model:
        model.load_state_dict(torch.load(args.basic_model))

    # model serialization (creating an object of type ScriptModule): the model
    # is traced so that it can be serialised and sent to the workers.
    model.eval()
    traced_model = torch.jit.trace(model, torch.rand([1, 3, 224, 224], dtype=torch.float))

    # Data / picture transformation:
    data_transforms = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor()
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create validation dataset and dataloader
    validation_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'val'), data_transforms)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Create test dataset and dataloader
    test_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'test'), data_transforms)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Lists to plot loss and acc after training
    train_loss_values = []
    train_acc_values = []
    val_loss_values = []
    val_acc_values = []

    np.set_printoptions(formatter={"float": "{: .0f}".format})

    # Setup the loss function once. Pass the function itself: calling
    # mse_loss() without arguments raises a TypeError.
    criterion = torch.nn.functional.mse_loss

    for curr_round in range(1, args.training_rounds + 1):
        # asyncio.wait() raises ValueError on an empty collection, so stop
        # once no worker is left.
        if not worker_instances:
            logger.warning("No workers left; stopping before round %s", curr_round)
            break

        logger.info("Training round %s/%s", curr_round, args.training_rounds)
        print("entered training ")
        # reduce learn rate every 5 training rounds (adaptive learnrate)
        lr = args.lr * (0.1 ** ((curr_round - 1) // 5))

        # Wrap the coroutines in futures: asyncio.wait() no longer accepts
        # bare coroutine objects on recent Python versions.
        tasks = [
            asyncio.ensure_future(
                train_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    dataset_key=DATASET_KEY,
                    batch_size=args.batch_size,
                    curr_round=curr_round,
                    lr=lr,
                )
            )
            for worker in worker_instances
        ]
        completed, pending = await asyncio.wait(tasks, timeout=40)

        results = []
        for entry in completed:
            print("entry")
            print(entry)
            results.append(entry.result())

        # Workers that did not finish within the timeout are dropped.
        for entry in pending:
            entry.cancel()

        # Keep only the workers that delivered a result, in result order.
        new_worker_instances = []
        for entry in results:
            for worker in worker_instances:
                if entry["worker_id"] == worker.id:
                    new_worker_instances.append(worker)

        worker_instances = new_worker_instances

        # Federate models (note that this will also change the model in models[0]
        models = {}
        for worker in results:
            if worker["model"] is not None:
                models[worker["worker_id"]] = worker["model"]

        # Nothing to average: keep the previous model and skip validation.
        if not models:
            logger.warning("Round %s returned no models; skipping aggregation", curr_round)
            continue

        logger.info("aggregation")
        traced_model = utils.federated_avg(models)
        logger.info("aggregation done")

        # Evaluate federated model
        logger.info("Validate..")
        loss = validate_model("Federated", traced_model, validation_dataloader, criterion)
        logger.info("Validation done")
        val_loss_values.append(loss)
        #val_acc_values.append(acc)
if __name__ == "__main__":
    # Logging setup: one log file per run, named after the start time.
    FORMAT = "%(asctime)s | %(message)s"
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    log_path = 'logs/orchestrator_' + date_time + '.log'
    logging.basicConfig(filename=log_path, format=FORMAT)

    # Module-level logger used by the orchestrator functions.
    logger = logging.getLogger("orchestrator")
    logger.setLevel(level=logging.INFO)

    # Drive the async entry point to completion on the default event loop.
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(main())

The code of the workers:
def load_dataset(dataset_path):
    """Build the local training dataset from an image folder.

    Args:
        dataset_path (string): root of the dataset; the images are read from
            its ``train`` sub-directory, one folder per class, e.g.
            dataset_path/train/class1/xxx.jpg
            dataset_path/train/class2/yyy.jpg

    Returns:
        The ``datasets.ImageFolder`` over ``dataset_path/train`` with a random
        resized crop, a random horizontal flip and tensor conversion applied.
    """
    train_dir = os.path.join(dataset_path, 'train')
    # Normalisation ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) is
    # intentionally not applied.
    augmentations = [
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    return datasets.ImageFolder(train_dir, transforms.Compose(augmentations))


def start_websocket_server(id, port, dataset, verbose):
    """Create a websocket server worker, register the dataset and start it.

    Args:
        id (str or id): the unique id of the worker.
        port (int): the port on which the server should be run.
        dataset: dataset registered on the worker under ``DATASET_KEY``.
        verbose (bool): a verbose option - will print all messages
            sent/received to stdout.

    Returns:
        The started ``WebsocketServerWorker``.
    """
    server = WebsocketServerWorker(
        id=id,
        host="0.0.0.0",
        port=port,
        hook=sy.TorchHook(torch),
        verbose=verbose,
    )
    server.add_dataset(dataset, key=DATASET_KEY)
    server.start()
    return server

    def _fit(self, model, dataset_key, loss_fn):
        """Train ``model`` on the worker's local dataset.

        Args:
            model: module returning ``(output, code)`` for a batch of inputs.
            dataset_key: key of the requested dataset (only logged here; the
                module-level ``dataset`` is what is actually iterated).
            loss_fn: callable ``loss_fn(inputs, output)`` returning the loss.

        Returns:
            The trained ``model``.
        """
        logger = logging.getLogger("worker")
        logger.info(dataset_key)
        print("dataset key")
        model.train()
        # Use the batch size and shuffle flag sent in the train config rather
        # than the test batch size from the worker's command line.
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.train_config.batch_size,
            shuffle=self.train_config.shuffle,
            num_workers=4,
        )
        print("worker")
        print(data_loader)

        iteration_count = 0

        for _ in range(self.train_config.epochs):
            # Each batch is an (inputs, labels) pair; the labels are unused
            # because the inputs are the reconstruction target. Iterating
            # enumerate(data_loader) would hand an (index, batch) tuple to
            # the model instead of a tensor.
            for inputs, _labels in data_loader:
                # Set gradients to zero
                self.optimizer.zero_grad()

                # Update model
                output, _code = model(inputs)
                logger.info(inputs)
                logger.info(output)
                loss = loss_fn(inputs, output)
                loss.backward()
                self.optimizer.step()

                # Update and check iteration count
                iteration_count += 1
                if iteration_count >= self.train_config.max_nr_batches >= 0:
                    break

        return model



if __name__ == "__main__":

    # Command-line arguments of the worker process.
    args = define_and_get_arguments()

    # Logging setup: one log file per worker and run.
    FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s"
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    log_path = 'logs/worker_' + args.id + '_' + date_time + '.log'
    logging.basicConfig(filename=log_path, format=FORMAT)
    logger = logging.getLogger("worker")
    logger.setLevel(level=logging.INFO)

    # Load the local training data and serve it over a websocket.
    dataset = load_dataset(args.dataset_path)
    server = start_websocket_server(
        id=args.id, port=args.port, dataset=dataset, verbose=args.verbose
    )
Does anybody know what the problem is?
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...