PyTorchでバッチ正規化をカスタマイズする batch_norm()の使い方

Python

Batch Normalizationをカスタマイズ

PyTorchでバッチ正規化の中身を書き換える方法を紹介します。

バッチ正規化のパラメータは入力、重み、バイアス、平均、分散(ルートを取るので標準偏差)、ε(ゼロ除算を避けるためのごく小さい値)です。

ここでself.running_meanが移動平均、self.running_varが移動分散にあたります。

学習時、バッチごとに入力の平均を取りますが推論時には入力データは画像1枚なので平均が取れません。そこで移動平均を使います。なのでself.running_meanという名前です。

分散も同様に、学習時はバッチごとの分散を求めていますが、推論では移動分散を用います(self.running_var)。

これらパラメータにアクセスし、Batch Normalizationをカスタマイズするには下記のようにMyBatchNorm1dクラスを作成します。1dなので全結合用です。

ここでは試しに移動分散を小数点以下を丸めて整数にしてからバッチ正規化する処理を実装しました。

class MyBatchNorm1d(nn.BatchNorm1d):
    def __init__(self, *kargs, **kwargs):
        super(MyBatchNorm1d, self).__init__(*kargs, **kwargs)

    def forward(self, input):

        self.running_var.data = round(self.running_var.data)
        
        out = F.batch_norm(input, self.running_mean, self.running_var,self.weight, self.bias, self.training, self.momentum, self.eps)
        
        return out

最後のF.batch_norm()でバッチ正規化しています。

out = F.batch_norm(input, self.running_mean, self.running_var,self.weight, self.bias, self.training, self.momentum, self.eps)

畳み込み処理に現れるバッチ正規化は2dのほうを使います。

class MyBatchNorm2d(nn.BatchNorm1d):
    def __init__(self, *kargs, **kwargs):
        super(MyBatchNorm2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):

        self.running_var.data = round(self.running_var.data)
        
        out = F.batch_norm(input, self.running_mean, self.running_var,self.weight, self.bias, self.training, self.momentum, self.eps)
        
        return out

実際にネットワークモデルにカスタムバッチ正規化を入れると下記のようなPyhonコードになります。今回はCIFAR10のCNN画像分類を想定しAlexNetベースで作ってみました。

from modules import *

class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3,padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            MyBatchNorm2d(64),
           
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            MyBatchNorm2d(192),
            
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            MyBatchNorm2d(384),
            
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
          
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            MyBatchNorm2d(256),
           
        )
        self.classifier = nn.Sequential(
            
            nn.Linear(256 * 4 * 4, 1024),
            MyBatchNorm1d(1024),
         
            nn.Linear(1024, 1024),
            MyBatchNorm1d(1024),

            nn.Linear(1024, 10),
            MyBatchNorm1d(10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 256 * 4 * 4)
        x = self.classifier(x)
        return x

ニューラルネットワークのカスタマイズ

バッチ正規化のカスタマイズ方法を紹介しました。

バッチ正規化処理には入力、平均、分散などのパラメータがあります。これらパラメータに自分で処理を加えることができます。

バッチ正規化のクラスを作成します。

MyBatchNorm2d(nn.BatchNorm1d):

そのクラス内でバッチ正規化関数を呼び出すことでバッチ正規化が完了します。

F.batch_norm(input, self.running_mean, self.running_var,self.weight, self.bias, self.training, self.momentum, self.eps)

今回はバッチ正規化を自作する方法を紹介しました。この他にも畳み込み演算や全結合もカスタマイズすることができます。

参考資料

コメント

タイトルとURLをコピーしました