Сталкивались ли вы с подобной проблемой, как потеря дрожания во время тренировки? - PullRequest
0 голосов
/ 20 марта 2020

Справочная информация: Речь идет о дрожании потерь, которое генерируется на начальной стадии каждой тренировочной эпохи. Когда загрузчик данных загружает данные первой партии для подачи в сеть, значение потерь всегда внезапно возрастает, затем возвращается ко второй норме и продолжает уменьшаться. Кривая такая странная. Мне нужна твоя помощь!

image
    for epoch in range(begin_epoch, end_epoch):
        print('PROGRESS: %.2f%%' % (100.0 * epoch / end_epoch))

        # set epoch as random seed of sampler while distributed training
        if train_sampler is not None and hasattr(train_sampler, 'set_epoch'):
            train_sampler.set_epoch(epoch)

        # reset metrics
        metrics.reset()

        # set net to train mode
        net.train()

        # clear the paramter gradients
        # optimizer.zero_grad()

        # init end time
        end_time = time.time()

        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            name, value = validation_monitor.metrics.get()
            val = value[name.index(validation_monitor.host_metric_name)]
            lr_scheduler.step(val, epoch)

        # training
        train_loader_iter = iter(train_loader)
        for nbatch in range(total_size):
            try:
                batch = next(train_loader_iter)
            except StopIteration:
                print('reset loader .. ')
                train_loader_iter = iter(train_loader)
                batch = next(train_loader_iter)
            global_steps = total_size * epoch + nbatch

            os.environ['global_steps'] = str(global_steps)

            # record time
            data_in_time = time.time() - end_time

            # transfer data to GPU
            data_transfer_time = time.time()
            batch = to_cuda(batch)
            data_transfer_time = time.time() - data_transfer_time

            # forward
            forward_time = time.time()
            outputs, loss = net(*batch)
            loss = loss.mean()
            if gradient_accumulate_steps > 1:
                loss = loss / gradient_accumulate_steps
            forward_time = time.time() - forward_time

            # backward
            backward_time = time.time()
            if fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            backward_time = time.time() - backward_time

            optimizer_time = time.time()
            if (global_steps + 1) % gradient_accumulate_steps == 0:

                # clip gradient
                if clip_grad_norm > 0:
                    if fp16:
                        total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                                    clip_grad_norm)
                    else:
                        total_norm = torch.nn.utils.clip_grad_norm_(net.parameters(),
                                                                    clip_grad_norm)
                    if writer is not None:
                        writer.add_scalar(tag='grad-para/Total-Norm',
                                        scalar_value=float(total_norm),
                                        global_step=global_steps)

                optimizer.step()

                # step LR scheduler
                if lr_scheduler is not None and not isinstance(lr_scheduler,
                                                            torch.optim.lr_scheduler.ReduceLROnPlateau):
                    lr_scheduler.step()

                # clear the parameter gradients
                optimizer.zero_grad()
            optimizer_time = time.time() - optimizer_time

            # update metric
            metric_time = time.time()
            metrics.update(outputs)
            if writer is not None and nbatch % 50 == 0:
                with torch.no_grad():
                    for group_i, param_group in enumerate(optimizer.param_groups):
                        writer.add_scalar(tag='Initial-LR/Group_{}'.format(group_i),
                                        scalar_value=param_group['initial_lr'],
                                        global_step=global_steps)
                        writer.add_scalar(tag='LR/Group_{}'.format(group_i),
                                        scalar_value=param_group['lr'],
                                        global_step=global_steps)
                    writer.add_scalar(tag='Train-Loss',
                                    scalar_value=float(loss.item()),
                                    global_step=global_steps)
                    name, value = metrics.get()
                    for n, v in zip(name, value):
                        if 'Logits' in n:
                            writer.add_scalar(tag='Train-Logits/' + n,
                                            scalar_value=v,
                                            global_step=global_steps)
                        else:
                            writer.add_scalar(tag='Train-' + n,
                                            scalar_value=v,
                                            global_step=global_steps)
                    for k, v in outputs.items():
                        if 'score' in k:
                            writer.add_histogram(tag=k,
                                                 values=v,
                                                 global_step=global_steps)

            metric_time = time.time() - metric_time

1 Ответ

0 голосов
/ 20 марта 2020

В вашем наборе данных есть партия с большими потерями, вот и все.

Нередко люди сохраняют метрики для каждого пакета, обычно это среднее значение за эпоху (или среднее значение за несколько шагов пакета). Вы не увидите таких пиков, если будете хранить средние значения.

Вы также можете уменьшить эти пики, перетасовывая свои данные, чтобы проблемный пакет c распространился по всей эпохе. В целом, это хорошая практика в начале каждой эпохи.

...