Ahogrammer

Deep Dive Into NLP, ML and Cloud

ブロックごとの量子化を実装する

QLoRAについて少し書く機会があったので、その要素技術であるブロックごとの量子化(block-wise quantization)の解説とその実装をしてみました。実際のところ、bitsandbytesなどのライブラリに実装されているので、自前で実装する必要はまったくないのですが、学習用兼サボってたブログのリハビリ用です。

量子化

量子化とは、浮動小数点数を整数による表現に変換する技術です。たとえば、モデルの重みが32ビットの浮動小数点数として格納されているとします。量子化によって、これらの重みは32ビットからたとえば8ビットの整数に変換できます。量子化は計算時間やメモリ使用量を削減できる一方、モデルの予測の正確さを低下させる可能性があります。

量子化がどのようにして重みを整数で表現するのかを理解するために、簡単な例を見ていきましょう。ここでは、32ビット浮動小数点数(FP32)のテンソルを8ビット整数(INT8)に量子化します。対称的な範囲に変換するとすれば、[-127, 127]の範囲の整数に変換されます。変換は以下のように行われます。


\mathbf{X}^{\text{Int8}} = \text{round} \left( \frac{127}{\text{absmax}(\mathbf{X}^{\text{FP32}})} \mathbf{X}^{\text{FP32}} \right) = \text{round}\left( c^{\text{FP32}} \cdot \mathbf{X}^{\text{FP32}} \right)

X^{\text{FP32}}はFP32形式の元の入力テンソルであり、\text{absmax}(X^{\text{FP32}})テンソルX^{\text{FP32}}内の絶対値の最大値です。出力のX^{\text{Int8}}は、元のテンソルのINT8表現です。c^{\text{FP32}}は、量子化定数と呼ばれ、以下で定義されます。この定数は、あとで逆量子化(元の精度への復元)を行う際に必要になります。


\frac{127}{\text{absmax}(\mathbf{x}^{\text{FP32}})} = c^{\text{FP32}}

それでは、以下のテンソルを使って、量子化の具体例を見てみましょう。

import torch
tensor = torch.tensor([-.59, -.21, -.07, .13, .28])

このデータをINT8形式に変換するために、先ほどの式に基づいて関数を作成します。

def quantize(X_FP32):
    c_FP32 = 127 / torch.max(torch.abs(X_FP32))
    X_Int8 = torch.round(c_FP32 * X_FP32).to(torch.int8)
    return X_Int8, c_FP32

用意したテンソルに対して関数を実行すると、次の結果が得られます。

quantized_tensor, c = quantize(tensor)

print("Original Tensor :", tensor)
print("Quantized Tensor:", quantized_tensor)

# 出力
Original Tensor : tensor([-0.5900, -0.2100, -0.0700,  0.1300,  0.2800])
Quantized Tensor: tensor([-127,  -45,  -15,   28,   60], dtype=torch.int8)

FP32テンソルを効果的にINT8形式に変換できました。これにより、元のFP32形式ではなく、INT8としてテンソルの値を保存できるため、メモリ使用量を大幅に削減できます。

量子化

量子化された重みはメモリ使用量を削減するために保存されますが、学習時には元の精度に変換し直すために逆量子化(dequantization)が行われます。これは、量子化したままの数値だと表現力が不足して学習が困難になるためで、学習時には各層の計算ごとに量子化された重みを元の精度に変換しなおします。

量子化を行うには、以下の式を使用します。


\text{dequant}(c^{\text{FP32}}, \mathbf{X}^{\text{Int8}}) = \frac{\mathbf{X}^{\text{Int8}}}{c^{\text{FP32}}} = \mathbf{X}^{\text{FP32}}

量子化のときと同様、逆量子化をするための関数を定義します。

def dequantize(c_FP32, X_Int8):
    return X_Int8 / c_FP32

用意したテンソルに対して関数を実行すると、次の結果が得られます。

quantized_tensor, c = quantize(tensor)
dequantized_tensor = dequantize(c, quantized_tensor)

print("Original Tensor   :", tensor)
print("Quantized Tensor  :", quantized_tensor)
print("Dequantized Tensor:", dequantized_tensor)

# 出力
Original Tensor   : tensor([-0.5900, -0.2100, -0.0700,  0.1300,  0.2800])
Quantized Tensor  : tensor([-127,  -45,  -15,   28,   60], dtype=torch.int8)
Dequantized Tensor: tensor([-0.5900, -0.2091, -0.0697,  0.1301,  0.2787])

結果を見ると、逆量子化されたテンソルは、元のテンソルと完全には一致しないことがわかります。たとえば、元の値が 0.21 の場合、逆量子化後は 0.2091 となり、0.0009 の差が生じます。とはいえ、計算に量子化された値を使用するよりも正確な値を使用できます。

ブロックごとの量子化

単純な量子化の問題点として、入力テンソルに極端に大きな値や小さな値(外れ値)が含まれる場合、量子化しようとしている値の範囲を歪めてしまう可能性があります。つまり、外れ値以外の値が同じような値に変換され、ほかとの区別が難しくなってしまうのです。

たとえば、前回のテンソルに外れ値として 100 を追加するとしましょう。そのようなテンソルに対して関数を実行すると、次の結果が得られます。結果を見ると、多くの値が狭い範囲にマッピングされてしまい、元のテンソルとの違いが大きくなっていることがわかります。

tensor = torch.tensor([-.59, -.21, -.07, .13, .28, 100])

quantized_tensor, c = quantize(tensor)
dequantized_tensor = dequantize(c, quantized_tensor)

print("Original Tensor   :", tensor)
print("Quantized Tensor  :", quantized_tensor)
print("Dequantized Tensor:", dequantized_tensor)

# 出力
Original Tensor   : tensor([-0.5900, -0.2100, -0.0700,  0.1300,  0.2800, 100])
Quantized Tensor  : tensor([ -1,   0,   0,   0,   0, 127], dtype=torch.int8)
Dequantized Tensor: tensor([ -0.7874,   0.0000,   0.0000,   0.0000,   0.0000, 100.0000])

この問題を解決するための手法として、テンソルを小さな「ブロック」に分割し、それぞれのブロックに対して個別に量子化を適用する手法があります(block-wise quantization)。これにより、あるブロック内の外れ値が他のブロックの値に影響を与えることを防ぐことができます。テンソル𝑛 個のブロックに分割する場合、𝑛 個の固有の量子化定数が生成されます。

ブロックごとの量子化。画像は『8-bit Optimizers via Block-wise Quantization』より引用

では、ブロックごとの量子化を実装してみましょう。ここでは、わかりやすさのためにループを使っています。

class BlockwiseQuantizer:
    def __init__(self, block_size: int):
        self.block_size = block_size
        self.quant_constants = None

    def quantize(self, data):
        data = data.flatten()
        n = data.numel()
        num_blocks = -(n // -self.block_size)

        quantized_data = torch.zeros_like(data, dtype=torch.int8)
        self.quant_constants = torch.zeros(num_blocks, device=data.device, dtype=torch.float32)

        for i in range(num_blocks):
            start = i * self.block_size
            end = min(start + self.block_size, n)
            block = data[start:end]

            # Compute quantization constant
            abs_max = torch.max(torch.abs(block))
            c = 127 / abs_max
            self.quant_constants[i] = c

            # Quantize the block
            quantized_data[start:end] = torch.round(c * block).clamp(-127, 127).to(torch.int8)

        return quantized_data

    def dequantize(self, quantized_data):
        quantized_data = quantized_data.flatten()
        n = quantized_data.numel()
        num_blocks = len(self.quant_constants)

        dequantized_data = torch.zeros_like(quantized_data, dtype=torch.float32)

        for i in range(num_blocks):
            start = i * self.block_size
            end = min(start + self.block_size, n)
            c = self.quant_constants[i]

            # Dequantize the block
            dequantized_data[start:end] = quantized_data[start:end].float() / c

        return dequantized_data

外れ値を含むテンソルをブロックごとに量子化すると、外れ値の影響が他のブロックに広がることがなくなります。これにより、量子化精度が向上します。以下の結果を見ると、先ほどの結果と比べて、逆量子化されたテンソルが元のテンソルに近い値を持っていることがわかります。

quantizer = BlockwiseQuantizer(block_size=2)
quantized_tensor = quantizer.quantize(tensor)
dequantized_tensor = quantizer.dequantize(quantized_tensor)

print("Original Tensor   :", tensor)
print("Quantized Tensor  :", quantized_tensor)
print("Dequantized Tensor:", dequantized_tensor)

# 出力
Original Tensor   : tensor([-0.5900, -0.2100, -0.0700,  0.1300,  0.2800, 100])
Quantized Tensor  : tensor([-127,  -45,  -68,  127,    0,  127], dtype=torch.int8)
Dequantized Tensor: tensor([-5.9000e-01, -2.0906e-01, -6.9606e-02,  1.3000e-01,  0.0000e+00,
         1.0000e+02])

参考資料