【PyTorch】2値化ニューラルネットワークで認識精度を上げる – 重み更新のアルゴリズム

Python

入力と重みを2値化する

前回の記事では入力を2値に量子化しました。しかし、自分で書いておいてなんですが、あのままでは認識精度が著しく低いです。

より実用的な量子化ニューラルネットワークのために重みの更新手法を工夫します。

【誤差逆伝搬】重みの更新

重みを更新するとき、量子化された重みを誤差逆伝搬のときに更新していました。

W ← W_org – δ×∂L/∂W

ここのW_orgが量子化された数値(2値化では-1,+1)でした。

W_orgを量子化前の浮動小数点精度に変更することで精度が向上します。

PyTorchで実装するには、まず重みを量子化する前に浮動小数点精度の重みを保持しておきます。順伝搬は量子化後の値で積和演算します。

そして誤差逆伝搬で重みを更新するのですが、このときの重みは浮動小数点精度を使います。

PyTorchで実装

実際にPyTorchで作成していきます。先人のGitを元にしています。

modules.pyのBinaryConv2dクラスを書き換えます。

class BinaryConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinaryConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
       
        input.data = binarize(input.data)
 
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data =binarize(self.weight.org)
        
        out = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

        return out

次にmain.py内のtrain関数を工夫します。重みに限らず、バイアスにも同様の更新手法を適応したいのでmodel.parameters()からパラメータを呼び出してます。

def train(args,epoch_index,train_loader,model,optimizer,criterion):
    #global best_acc
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.to('cuda:0'), target.to('cuda:0')
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        
        ############ weight update ###########
        for p in list(model.parameters()):
            if hasattr(p,'org'):
                p.data.copy_(p.org)
        
        optimizer.step()
        
        for p in list(model.parameters()):
            if hasattr(p,'org'):
                p.org.copy_(p.data.clamp_(-1,1))
       #########################################

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch_index, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

ここで、パラメータ更新のあとに(-1,1)にクリップしています。-1~+1の範囲に抑えることで精度が良くなるらしいです。以下のようにクリップされています。

for p in list(model.parameters()):
            if hasattr(p,'org'):
                p.org.copy_(p.data.clamp_(-1,1))

最後に

入力・重みを量子化するとどうしても精度が落ちてしまいますが、重み更新の方法を工夫することで認識精度が向上します。

コメント

タイトルとURLをコピーしました