【PyTorch】重みの分布を可視化する – matplotlibでヒストグラムの作成

Python

重みを可視化したいモデル

PyTorchで下記のようなネットワークモデルを作成し、その重みの値をグラフにプロットすることを考えます。

class MyNet(nn.Module):

    def __init__(self):
        super(MyNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1,bias=True),#0
            nn.BatchNorm2d(128),#1

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True),#2
            nn.MaxPool2d(kernel_size=2, stride=2),#3
            nn.BatchNorm2d(128),#4
        
            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True),#5
            nn.BatchNorm2d(256),#6
          
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True),#7
            nn.MaxPool2d(kernel_size=2, stride=2),#8
            nn.BatchNorm2d(256),#9
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True),#10
            nn.BatchNorm2d(512),#11
           
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True),#12
            nn.MaxPool2d(kernel_size=2, stride=2),#13
            nn.BatchNorm2d(512),#14
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024, bias=True),#0
            nn.BatchNorm1d(1024),#1

            nn.Linear(1024, 1024, bias=True),#2
            nn.BatchNorm1d(1024),#3

            nn.Linear(1024, 10, bias=True),#4
            nn.BatchNorm1d(10),#5
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512 * 4 * 4)
        x = self.classifier(x)
        return x

重み分布をmatplotlibでヒストグラムにプロットする

以下の関数を作成し、学習後にmain.py内で関数を呼び出します。今回は9層からなるモデルだったため縦3つ×横3つのヒストグラムを作成します。

features.0.weightが第1層目の重みを表します。features.0.biasとするとバイアスを取り出すことができます。

def plothist(model,N=9):
    for n,p in model.to('cpu').named_parameters():
        
        if n in ['features.0.weight','features.2.weight','features.5.weight','features.7.weight','features.10.weight','features.12.weight','classifier.0.weight','classifier.2.weight','classifier.4.weight']:
       
            print(n)
            plt.subplot(3,3,N)
            N=N+1
            tmp0 = p.detach()
            tmp1 = tmp0.numpy()
            w = tmp1.flatten()
            
            plt.hist(w, bins=100,color='deepskyblue')
            plt.gca().yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            plt.gca().yaxis.offsetText.set_fontsize(3)
            plt.gca().ticklabel_format(style="sci",  axis="y",scilimits=(0,0))
            plt.tick_params(labelsize=5)
            
    plt.savefig(hist.png')
    plt.show()

main.py内のどこかで関数を呼び出して重み分布を可視化します。

plothist(model,N=9)

おわりに

PyTorchで作成したモデルの重みやバイアス等パラメータをグラフにプロットする方法を紹介しました。

matplotlibを使えば簡単にヒストグラムを作成できるのでぜひ挑戦してみてください。

コメント

タイトルとURLをコピーしました