까먹으면 적어두자

PyTorch 모델과 state_dict - 1 본문

인공지능

PyTorch 모델과 state_dict - 1

whiteglass 2021. 4. 6. 15:37

개요

Pytorch는 모델의 실제 내용 (가중치, 양자화 방법, 옵티마이저 등)을 저장할 때 state_dict라는 파이썬 딕셔너리에 저장한다.

따라서 Pytorch의 가중치를 직접 조작하거나 살펴볼때는 state_dict을 통해 가중치에 접근해야 한다.

 

여기서 조금 헷갈리는데 모델 그 자체와 state_dict에 접근하는 것은 조금 다른 의미를 가진다는 것이다.

 

예시

예를 들어 다음과 같은 모델을 정의했다.

import torch
import torch.nn as nn

class DNN(nn.Module):
    def __init__(self):
        super(DNN, self).__init__()
        self.w1 = nn.Linear(2, 10)
        self.bias1 = torch.zeros([10])

        self.w2 = nn.Linear(10, 3)
        self.bias2 = torch.zeros([3])
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        y = self.w1(x) + self.bias1
        y = self.relu(y)

        y = self.w2(y) + self.bias2
        return y


model = DNN()

DNN이라는 모델의 객체인 model을 print(model) 을 통해 출력하면

DNN(
  (w1): Linear(in_features=2, out_features=10, bias=True)
  (w2): Linear(in_features=10, out_features=3, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=0)
)

이렇게 모델의 구조에 대해서 summary한 내용이 출력된다. 

 

만약 직접 가중치를 보고 좀더 자세한 내용을 위해서는 print(model.state_dict()) 를 사용해야한다.

OrderedDict([('w1.weight', tensor([[-0.1524,  0.4417],
        [ 0.2150,  0.1317],
        [ 0.4589,  0.6667],
        [-0.0104,  0.3577],
        [ 0.5059, -0.0157],
        [ 0.1874, -0.2470],
        [-0.5674, -0.5179],
        [ 0.1378, -0.3285],
        [ 0.5782,  0.7016],
        [-0.1318, -0.0728]])), ('w1.bias', tensor([ 0.3405, -0.5623,  0.3785, -0.0829,  0.2551, -0.1665,  0.4087,  0.5271,
        -0.6479,  0.1198])), ('w2.weight', tensor([[ 0.2526, -0.2523,  0.1249,  0.0892,  0.2619, -0.0909, -0.0945, -0.2930,
          0.0442, -0.0053],
        [-0.2775,  0.2231, -0.0841, -0.0164, -0.0358,  0.1823, -0.0198,  0.1489,
          0.1366,  0.3035],
        [ 0.1344,  0.0355, -0.2670, -0.2763, -0.2737,  0.0446,  0.2192, -0.2894,
         -0.1984,  0.1941]])), ('w2.bias', tensor([-0.2245,  0.1045, -0.2196]))])

그러면 세세한 가중치와 가중치의 변수 이름까지 모두 표시된다.

 

state_dict()는 모델의 state_dict를 반환하는 함수이므로 이 자체를 변형할 수 없다.

 

모델의 가중치를 변경하기 위해서는 다른 함수를 거쳐야 한다.

 

이는 다음 포스팅에서 다루도록 한다.

반응형

'인공지능' 카테고리의 다른 글

nan 혹은 말도 안되는 loss  (0) 2021.04.06
Keras에서 GPU 제대로 쓰기  (0) 2021.04.06
우분투에서 Cuda 버전 다운그레이드  (1) 2021.03.26
Comments