foreach関数によるoptimizerの高速化。実装例コード有り。

foreach関数とは

 PyTorchには、torch._foreach_addやtorch._foreach_expといった、foreachの接頭辞を持つ関数が存在する。このforeach関数とは、テンソルのリストを引数として受け取り、その関数の名に含まれる処理をそのリスト内の各テンソルに施す関数である。この関数にはどのような意義があるのだろうか?

 foreach関数の意義は、「多くの」「小さい」テンソルに対する処理を高速化することにある。その一例が、optimizer内での処理である。

 一般的に、optimizerは各パラメータあたり3,4個のCUDAカーネルを呼び出す。複雑なoptimizerであれば、10個以上になることもある。そしてそのようなCUDAカーネルは、数が多いものの、一つ一つのサイズは小さい。そのためoptimizer内では往々にして、計算を行うGPUではなく、CUDAカーネルの呼び出しを行うCPUがボトルネックになる。foreach関数は少ないCUDAカーネル呼び出し回数で同じ処理を行うことで、optimizer内のCPUボトルネックを解消することができる。

通常のAdamWとforeach関数を用いたAdamWの実装例
# 通常のtorch関数を使ったAdamW
import torch

class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step_t"] = (torch.tensor(0.0))
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                step_t = state["step_t"]
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                step_t += 1

                # 重み減衰
                p.mul_(1 - lr * weight_decay)

                # 勾配の指数平均を更新
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step_t
                bias_correction2 = 1 - beta2 ** step_t

                step_size = lr / bias_correction1

                denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
                
                # パラメータを更新
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
# foreach関数を使ったAdamW
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
import math

class AdamW_foreach(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            # foreach関数のためのテンソルのリスト
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step_t"] = (torch.tensor(0.0))
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                params_with_grad.append(p)
                grads.append(p.grad)
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sqs"])
                state_steps.append(state["step_t"])

            if len(params_with_grad) == 0:
                return
            
            # テンソルをデータ型とデバイスによって分類
            grouped_tensors = _group_tensors_by_device_and_dtype([
                params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            ]
            )

            for (
                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,device_state_steps
            ) in grouped_tensors.values():
                
                torch._foreach_add_(device_state_steps, 1)

                # 重み減衰
                torch._foreach_mul_(device_params, 1 - lr * weight_decay)

                # 勾配の指数平均を更新
                torch._foreach_mul_(device_exp_avgs, beta1)
                torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)
        
                torch._foreach_mul_(device_exp_avg_sqs, beta2)
                torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)
        
                bias_correction1 = [1 - beta1 ** step for step in device_state_steps]
                bias_correction2 = [1 - beta2 ** step for step in device_state_steps]
 
                step_size = [lr / bc for bc in bias_correction1]

                exp_avg_sq_sqrts = torch._foreach_div_(torch._foreach_sqrt(device_exp_avg_sqs), [math.sqrt(bc) in bias_correction2])
                denom = torch._foreach_add(exp_avg_sq_sqrts, eps)

                # パラメータを更新
                torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)

        return loss

【2023年3月】sota optimizer(最適化手法)まとめ

2023年3月現在のsota optimizer

MADGRAD, Adahessian, Ali-G, Lion

MADGRAD

momentumとdual averagingを用いた、Adagrad系列の最適化手法

【プラスポイント】

・mirror descentよりも理論的前提条件が簡素なdual averagingを使用。

・dual averagingによって、各ステップに依存し、かつ学習の進行と共に弱まっていく正則化が導入されることを証明。

・Adamが良い成果を収められる問題と収められない問題、どちらにおいても、AdamやSGDと同等以上の精度を達成。

・Adamと違い、sparseなモデルでも使用可能。

・ハイパーパラメータのグリッドサーチをすべての最適化手法について行い比較することで、実際の優位性を報告。

【マイナスポイント】

・勾配の和をもとにモデルパラメータを更新するため、学習の進行と共に新しい情報を活用できなくなっていく。

・パラメータと同サイズのテンソルを3つ保持していなくてはならず、Adamと比較してメモリ使用量が多い(1.5倍)。

Adahessian

適応的に学習率を変化させる二次最適化手法

【プラスポイント】

・実用的な二次最適化手法

・Hutchinson近似の応用により、対角Hessianの計算を高速化。

・複数のタスクにおいてSGD, Adam, AdamWと同等以上の精度を達成。

・AdamWよりも学習率の選択に寛容。

・Hutchinson近似の計算頻度を減少させることによって、計算量を抑えながらもほぼ同等の精度を達成可能。

【マイナスポイント】

・一次最適化手法と比較して約2倍の計算時間とメモリ使用量。

・グリッドサーチを行わず、複数の最適化手法で同じパラメータを使用しているため、ハイパーパラメータをチェリーピッキングしている可能性がある。

Ali-G

1.目的関数の最小値が既知の場合、非確率的勾配の方向の適応的学習率を計算することができる

2.内挿モデルにおいて、そのような最小値は、おおよその値が知られている

この二つの前提から導かれた最適化手法

【プラスポイント】

・AdamW, Adagrad, AMSGrad, Yogi, DFW, L4Adam, L4Mom, SGDの内、各タスクにおいて最も精度が高かった手法と同等以上の精度を達成。

・一つのハイパーパラメータで、学習率パラメータと学習率スケジュールを代替。

・理論的に収束を証明。

SGDとほぼ同じ計算量とメモリ使用量。

【マイナスポイント】

・比較実験において各最適化手法のハイパーパラメータ探索をある程度行っているものの、Ali-Gの優位性を確実に示すには不足している。

・損失の減少と共に更新量が少なくなっていくことで、ローカルミニマから抜け出しづらくなっていく。

Lion

進化的アルゴリズムによって導かれた最適化手法

【プラスポイント】

・複数のモデルとタスクにおいて、AdamWとSGDのどちらよりも優れた精度を達成。

・Adamよりも少ない計算量とメモリ使用量。

【マイナスポイント】

・比較実験において各最適化手法のハイパーパラメータ探索を行ったとは報告しているものの、具体的にどのように探索したのかは不明。

・強いaugmentationや低いバッチサイズを用いた場合、そしていくつかのタスクにおいては既存手法に対する優位性を示せず。

・進化アルゴリズムにおける探索範囲が限られており、AdamやSGDといった主要一次最適化手法への強いバイアスがある。

・更新量は常に一定であるため、学習率スケジュールの使用が必須。