foreach関数によるoptimizerの高速化。実装例コード有り。

foreach関数とは

 PyTorchには、torch._foreach_addやtorch._foreach_expといった、foreachの接頭辞を持つ関数が存在する。このforeach関数とは、テンソルのリストを引数として受け取り、その関数の名に含まれる処理をそのリスト内の各テンソルに施す関数である。この関数にはどのような意義があるのだろうか?

 foreach関数の意義は、「多くの」「小さい」テンソルに対する処理を高速化することにある。その一例が、optimizer内での処理である。

 一般的に、optimizerは各パラメータあたり3,4個のCUDAカーネルを呼び出す。複雑なoptimizerであれば、10個以上になることもある。そしてそのようなCUDAカーネルは、数が多いものの、一つ一つのサイズは小さい。そのためoptimizer内では往々にして、計算を行うGPUではなく、CUDAカーネルの呼び出しを行うCPUがボトルネックになる。foreach関数は少ないCUDAカーネル呼び出し回数で同じ処理を行うことで、optimizer内のCPUボトルネックを解消することができる。

通常のAdamWとforeach関数を用いたAdamWの実装例
# 通常のtorch関数を使ったAdamW
import torch

class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step_t"] = (torch.tensor(0.0))
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                step_t = state["step_t"]
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                step_t += 1

                # 重み減衰
                p.mul_(1 - lr * weight_decay)

                # 勾配の指数平均を更新
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step_t
                bias_correction2 = 1 - beta2 ** step_t

                step_size = lr / bias_correction1

                denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
                
                # パラメータを更新
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
# foreach関数を使ったAdamW
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
import math

class AdamW_foreach(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            # foreach関数のためのテンソルのリスト
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step_t"] = (torch.tensor(0.0))
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                params_with_grad.append(p)
                grads.append(p.grad)
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sqs"])
                state_steps.append(state["step_t"])

            if len(params_with_grad) == 0:
                return
            
            # テンソルをデータ型とデバイスによって分類
            grouped_tensors = _group_tensors_by_device_and_dtype([
                params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            ]
            )

            for (
                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,device_state_steps
            ) in grouped_tensors.values():
                
                torch._foreach_add_(device_state_steps, 1)

                # 重み減衰
                torch._foreach_mul_(device_params, 1 - lr * weight_decay)

                # 勾配の指数平均を更新
                torch._foreach_mul_(device_exp_avgs, beta1)
                torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)
        
                torch._foreach_mul_(device_exp_avg_sqs, beta2)
                torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)
        
                bias_correction1 = [1 - beta1 ** step for step in device_state_steps]
                bias_correction2 = [1 - beta2 ** step for step in device_state_steps]
 
                step_size = [lr / bc for bc in bias_correction1]

                exp_avg_sq_sqrts = torch._foreach_div_(torch._foreach_sqrt(device_exp_avg_sqs), [math.sqrt(bc) in bias_correction2])
                denom = torch._foreach_add(exp_avg_sq_sqrts, eps)

                # パラメータを更新
                torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)

        return loss