欲速不達

일을 급히 하고자 서두르면 도리어 이루지 못한다.

Fantastic AI, Fantastic World

DS | Data Science/ML | Machine Learning

[Pytorch] hook & apply

_껀이_ 2022. 9. 28. 17:12
728x90
반응형

1. hook

hook은 패키지화된 코드에서 custom 코드를 중간에 실행시킬 수 있도록 만들어 놓은 인터페이스이다.

주로, 아래와 같은 경우에 사용한다.

  • 프로그램의 실행 로직을 분석
  • 프로그램에 추가적인 기능을 제공
def program_A(x):
    print('program A processing!')
    return x + 3

def program_B(x):
    print('program B processing!')
    return x - 3

class Model_Package(object):
    def __init__(self):
        self.programs = [program_A, program_B]
        #################################################
        self.hooks = []
        #################################################

    def __call__(self, x):
        for program in self.programs:
            x = program(x)
            #################################################
            if self.hooks:
                for hook in self.hooks:
                    output = hook(x)

                    if output:
                        x = output
            #################################################

        return x

model_package = Model_Package()

input = 3
output = model_package(input)

print(f"Process Result! [ input {input} ] [ output {output} ]")
>> program A processing!
   program B processing!
   Process Result! [ input 3 ] [ output 3 ]

위와 같이 사용자가 self.hook 부분에 custom program을 추가하여 사용할 수 있도록 한 것을 말한다. __call__ 부분의 if문은 self.hook 안에 custom program이 있는지를 확인하여 실행할 수 있도록 한다.

 

위의 코드에서는 self.hook이 비어있으므로 program_A와 program_B만 실행하고 종료한다.

 

def custom_program(x):
	print('insert custom program')
    
Model_Package.hooks.append(custom_program)

self.hook에 프로그램을 추가할때는 위와 같은 코드를 통해 추가하게 된다.

 

hook을 어느 위치에 놓느냐에 따라서 결과가 달라지며, 여러 개의 hook을 추가할 수 있다.

 

 

 

pytorch에서의 hook은 다음과 같은 메서드를 따른다.

 

- _backward_hook으로 hook 확인 예시

import torch

tensor = torch.rand(1, requires_grad=True)

def tensor_hook(grad):
    pass

tensor.register_hook(tensor_hook)

tensor._backward_hooks
>> OrderedDict([(0, <function __main__.tensor_hook(grad)>)])

 

- __dict__를 통해 module에 적용되는 parameter를 모두 확인 할 수 있음

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

def module_hook(grad):
    pass

model = Model()
model.register_forward_pre_hook(module_hook)
model.register_forward_hook(module_hook)
model.register_full_backward_hook(module_hook)

model.__dict__
>> {'training': True,
    '_parameters': OrderedDict(),
    '_buffers': OrderedDict(),
    '_non_persistent_buffers_set': set(),
    '_backward_hooks': OrderedDict([(3, <function __main__.module_hook(grad)>)]),
    '_is_full_backward_hook': True,
    '_forward_hooks': OrderedDict([(2, <function __main__.module_hook(grad)>)]),
    '_forward_pre_hooks': OrderedDict([(1, <function __main__.module_hook(grad)>)]),
    '_state_dict_hooks': OrderedDict(),
    '_load_state_dict_pre_hooks': OrderedDict(),
    '_load_state_dict_post_hooks': OrderedDict(),
    '_modules': OrderedDict()}

 

1-1) tensor에 적용하는 hook

- register_backward_hook

: backward_hook은 backward pass 시에 module의 input_grad와 output_grad를 받아 실행된다. 주로 gradient를 다룰때 사용되며, tensor와 module 모두에서 사용 가능하다.

 

-register_full_backward_hook

: backward_hook과 같은 기능을 한다.

 

1-2) module에 적용하는 hook

- forward_hook, forward_pre_hook, backward_hook, full_backward_hook

 

- register_forward_pre_hook(pre_hook_name)

: forward_pre_hook은 forward pass 시에 module 앞에서 실행되는 hook이다. 그렇기 때문에 인자로 input값만을 받아 실행되며 input 값은 module에 전달되는 input 데이터이다.

 

- register_forward_hook(hook_name)

: forward_hook은 forward pass 시에 module 직후에 실행되는 hook이다. 인자로는 module의 input과 output을 받아 실행된다.

 

  • module에서의 backward hook 은 module 기준으로 input, output gradient 값만 가져오기 때문에 module 내부의 tensor gradient 값은 알아 낼 수 없음
  • 그래서, model의 paramrter W의 gradient 값을 알려면 tensor backward hook을 사용해야함
def tensor_hook(grad):
    answer.append(grad)
model.W.register_hook(tensor_hook)

 

 

2. apply

module이나 tensor에 대해서 custom method를 적용시키는 것은 hook을 통해 가능하다. 그러나 다량의 module들이 얽혀있는 modul 단위에 hook을 적용시키려면 일일히 module마다 달아줘야 하기때문에 번거롭다. 이 문제를 해결하기 위해 만들어진 함수가 apply()이며, model 단위에서 custom method를 적용할 수 있다.

 

2-1) 가중치 초기화(Weight Initialization)

import torch
from torch import nn

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
>> Linear(in_features=2, out_features=2, bias=True)
   Parameter containing:
   tensor([[1., 1.],
           [1., 1.]], requires_grad=True)
   Linear(in_features=2, out_features=2, bias=True)
   Parameter containing:
   tensor([[1., 1.],
           [1., 1.]], requires_grad=True)
   Sequential(
     (0): Linear(in_features=2, out_features=2, bias=True)
     (1): Linear(in_features=2, out_features=2, bias=True)
   )
   Sequential(
     (0): Linear(in_features=2, out_features=2, bias=True)
     (1): Linear(in_features=2, out_features=2, bias=True)
   )

apply 함수는 일반적으로 가중치 초기화(Weight Initialization)에 많이 사용되고, parameter로 지정한 tensor의 값을 원하는 값으로 지정해주는 것을 의미한다.

 

apply은 model의 모든 module을 입력으로 받아 처리한다.

위의 코드에서 net은 Sequential로 Linear module이 두 개 연결되어 있고 각각의 module에 init_weight가 적용되어 m.weight.fill_(1.0)으로 가중치를 초기화한 것을 볼 수 있다.

 

 

2-2) 모델 출력 표현 변경 - repr

def print_module(module):
    print(module)
    print("-" * 30)

returned_module = model.apply(print_module)
>> Function_A()
   ------------------------------
   Function_B()
   ------------------------------
   Layer_AB(
     (a): Function_A()
     (b): Function_B()
   )
   ------------------------------
   Function_C()
   ------------------------------
   Function_D()
   ------------------------------
   Layer_CD(
     (c): Function_C()
     (d): Function_D()
   )
   ------------------------------
   Model(
     (ab): Layer_AB(
       (a): Function_A()
       (b): Function_B()
     )
     (cd): Layer_CD(
       (c): Function_C()
       (d): Function_D()
     )
   )
------------------------------

model에 print_module()을 적용시켜보면

각각의 module이 실행될때마다 - 가 30개 출력되는 것을 확인할 수 있다. 이렇게 module 이름을 출력하면서 custom method가 적용된다.

 

module 이름이 출력되면서 동시에 custom method를 출력 하려면 functools의 partial을 import해서 사용하면 좋다.

from functools import partial

def function_repr(self):
    return f'name={self.name}'

def add_repr(module):
    try:
        er = lambda repr:repr
        module.extra_repr = partial(er, function_repr(module))
    except:
        pass

returned_module = model.apply(add_repr)

model_repr = repr(model)

print("모델 출력 결과")
print("-" * 30)
print(model_repr)
print("-" * 30)

>> 모델 출력 결과
   ------------------------------
   Model(
     (ab): Layer_AB(
       (a): Function_A(name=plus)
       (b): Function_B(name=substract)
     )
     (cd): Layer_CD(
       (c): Function_C(name=multiply)
       (d): Function_D(name=divide)
     )
   )
------------------------------
728x90
반응형