【PyTorch】入力・重みを量子化する – Binarized Neural Network 畳み込み層・全結合層をカスタマイズ

Python

PyTorch 入力値を-1,+1の2値に量子化する

ニューラルネットワークの入力xに自作の関数を適応する方法を紹介します。ライブラリはPyTorchです。

入力xや重みwを量子化することで画像分類や物体検知を効率化する研究が行われています。

以前書いた記事では重みweightを2値に量子化する方法を説明しましたが今回は重みに加え、入力inputも量子化する方法を解説します。

重みに関数を適応するときは単純に

self.weight.data = Binarize(self.weight.data)

とするだけでした。ここでBinarize()は2値化をするための自作関数です。

self.weight.data = self.weight.data.sign()

これでも似たような2値化ができますが厳密には+1に量子化される値に0を含むかどうかが違うみたいです。あとで調べておきます。

入力も同じように

input.data = Binarize(input.data)

これでいいのではと思うかもしれませんが、単純にこれを加えただけだと上手く画像認識できません。

ここでは誤差逆伝搬(Back Propagation)が具体的にどんな処理なのかを理解する必要があります。

誤差逆伝搬では入力値を微分しています。しかし、2値に量子化する場合、2値化関数を通しているため上手く微分できないです。2値化では0以上の値を+1、0より小さい値を-1としています。この関数を微分してしまうと0付近の微分が無限に発散してしまいます。

そこで誤差逆伝搬のときは2値化関数を別の関数で置き換えます。これがHtanhと呼ばれるものです。

Htanhは-1以下の数値を-1、+1以上の数値を+1にクリッピングする関数です。-1と+1の間の数値はそのまま出力されます。

要は、-1以下の値と+1以下の値は微分すると0になり、-1と+1の間にある数値は微分すると1になるということです。

この処理をPyTorchで実装します。function.pyファイルを作成しそこへbinarize関数を記述します。先人のgitに載っていたクラスをそのままコピペしました。

class BinarizeF(Function):

    @staticmethod
    def forward(cxt, input):
        output = input.new(input.size())
        output[input >= 0] = 1
        output[input < 0] = -1
        return output

    @staticmethod
    def backward(cxt, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# aliases
binarize = BinarizeF.apply

def forward():が順伝搬、def backward():が逆伝搬です。

これだと処理が遅かったのでちょっとアレンジしたのが以下です。結構速くなりました。

def Binarize(tensor):
    th = 0.0
    out = (tensor >= th).type(torch.cuda.FloatTensor) - (tensor < -th).type(torch.cuda.FloatTensor)
    return out

class BinarizeF(Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return Binarize(input.data)
 
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)]=0
        grad_input[input.le(-1)]=0
        return grad_input

# aliases
binarize = BinarizeF.apply

これをmodules.py内のMyConv2dクラスに適応します。入力・重みをbinarize()で2値化します。

modules.pyファイルを作成しました。

from functions import *

class BinaryConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinaryConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):

        input.data = binarize(input.data)
       
        self.weight.data = binarize(self.weight.data)
  
        out = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        
        return out

最後にモデルを作成します。

from modules import *

class MyNet(nn.Module):

    def __init__(self):
        super(MyNet, self).__init__()
        self.features = nn.Sequential(
  
            BinaryConv2d(3, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),

            BinaryConv2d(128, 128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),

            BinaryConv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),

            BinaryConv2d(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),

            BinaryConv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),

            BinaryConv2d(512, 512, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512)
        )
        self.classifier = nn.Sequential(
    
            BinaryLinear(512 * 4 * 4, 1024),
            nn.BatchNorm1d(1024),
        
            BinaryLinear(1024, 1024),
            nn.BatchNorm1d(1024),
          
            BinaryLinear(1024, 10),
            nn.BatchNorm1d(10)
            
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512 * 4 * 4)
        x = self.classifier(x)
        return x

最後に

以上がPyTorchで入力画像を量子化する方法です。重み同様にそのまま入力を量子化関数に通すとバックプロパゲーションがうまい行きません。

順伝搬と誤差逆伝搬を別で記述し、誤差逆伝搬ではHtanhとみなすことで正常に量子化できるようになります。

参考資料を載せておきます。

参考資料

コメント

タイトルとURLをコピーしました