if args.fused_backward_pass:
    # use fused optimizer for backward pass: other optimizers will be supported in the future
    import library.adafactor_fused

    library.adafactor_fused.patch_adafactor_fused(optimizer)
    for param_group in optimizer.param_groups:
        for parameter in param_group["params"]:
            if parameter.requires_grad:

                # Bind the current param_group as a default argument so the closure
                # does not capture the loop variable by reference (late binding).
                def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                    # Apply the optimizer update for this single parameter, then free
                    # its gradient immediately to keep peak memory low.
                    optimizer.step_param(tensor, param_group)
                    tensor.grad = None

                # Runs right after this parameter's gradient has been accumulated
                # during backward(), so the update happens inside the backward pass.
                parameter.register_post_accumulate_grad_hook(__grad_hook)
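
For comparison, below is a minimal, self-contained sketch of the same fused-backward-pass pattern using only stock PyTorch: instead of the patched Adafactor's step_param(), it keeps one small optimizer per parameter (the approach shown in the PyTorch documentation for register_post_accumulate_grad_hook). The model, optimizer choice, and clipping value are illustrative assumptions, not part of the snippet above.

import torch

# Toy model standing in for the real network (assumption for illustration).
model = torch.nn.Linear(128, 128)

# One optimizer instance per parameter, so each parameter can be stepped the
# moment its gradient has finished accumulating during backward().
opt_per_param = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

def fused_step_hook(param: torch.Tensor) -> None:
    # Optional per-parameter gradient clipping, mirroring the max_grad_norm
    # handling in the snippet above (1.0 is an arbitrary illustrative value).
    torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)
    opt_per_param[param].step()                        # update this parameter now
    opt_per_param[param].zero_grad(set_to_none=True)   # free its gradient immediately

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(fused_step_hook)

# A single forward/backward pass now applies the updates inside backward(),
# so there is no separate optimizer.step() over all parameters at the end.
loss = model(torch.randn(4, 128)).sum()
loss.backward()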