본문 바로가기
Python

Pytorch Quantization

by pnnote 2023. 4. 25.
반응형

Quantization이란?

딥러닝에서의 Quantization(양자화)은 모델의 가중치와 활성화 값을 낮은 정밀도의 숫자로 변환하는 과정입니다. 보통 신경망은 32비트 부동 소수점 숫자로 파라미터를 저장하고 계산합니다. Quantization은 이러한 신경망 모델의 파라미터를 낮은 비트 수로 표현함으로써 모델의 크기를 줄이고 계산을 더 효율적으로 수행할 수 있게 해줍니다. 일반적으로 가중치와 활성화 값들은 32비트 또는 16비트 부동 소수점 숫자로 저장되지만, Quantization을 적용하면 8비트나 4비트와 같이 낮은 비트 수로 표현할 수 있습니다. 32비트 실수 자료형 모델 매개변수를 8비트 정수 자료형으로 전환했을 때 모델의 크기가 줄어드는 것과 속도 향상의 효과가 나타납니다. 보통 모바일 기기나 edge device 같은 자원이 제한된 환경에서 실행할 때 유용합니다.

 

 

Pytorch 모델에 Quantization 적용하는 방법

Quantization을 적용하는 방법은 몇가지가 있습니다.

1. 첫번째로 모델의 weight와 활성함수를 사전에 모두 8비트 정수로 바꾸는 static quantization(정적 양자화) 방법입니다. 우선은 학습된 모델을 로드하고 다음 코드를 실행하면 Quantization된 모델이 나오게 됩니다.

 

backend = "qnnpack"
self.model.qconfig = torch.quantization.get_default_qconfig(backend)
self.model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
model_qat = torch.quantization.prepare_qat(self.model, inplace=False)

model_qat = torch.quantization.convert(model_qat.eval(), inplace=False)
torch.save(model_qat.state_dict(), "./quantization_model.pt")

* 모바일 장비는 일반적으로 ARM 아키텍처를 탑재하는데 모바일 장비에서 모델이 작동하게 하려면, qnnpack 을 backend 로 사용하고, x86 아키텍처를 탑재한 컴퓨터에서 모델이 작동하게 하려면, x86을 backend로 사용

 

위 방식으로 Quantization된 모델을 load해서 사용하려고하면 아래와 같은 에러가 나게됩니다.

Could not run 'quantized::batch_norm' with arguments from the 'CPUTensorId' backend. 'quantized::batch_norm' is only available for these backends: [QuantizedCPUTensorId].

 

이럴 경우 다음 방식으로 수행해야 합니다.

self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()

 

위 코드는 tensor를 실수형에서 양자화된 자료형으로 전환하거나, 반대로 양자화된 자료형에서 실수형으로 전환하는 역할을 합니다.

 

 

2. 학습을 하는 과정에서 Quantization(양자화)을 고려해서 학습하는 방법이 있습니다. 학습 과정에서 모든 가중치와 활성 함수에 양자화를 하게되고, 학습 후에 양자화하는 방법보다 높은 정확도를 가집니다.

 

self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()

 

양자화를 고려한 학습을 하려면, 모델 정의 부분이 있는 __init__ 메소드에서 QuantStub 과 DeQuantStub 을 정의하고, 모델 forward 메소드 시작과 끝부분에서 x = self.quant(x) 와 x = self.dequant(x) 를 사용하면 됩니다.

추가로 양자화를 고려한 학습을 하려면 다음과 같은 코드를 적용하면 됩니다.

model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
model_qat = torch.quantization.prepare_qat(model, inplace=False)

model_qat = torch.quantization.convert(model_qat.eval(), inplace=False)
반응형