BitNet: Scaling 1-bit Transformers for Large Language Models
논문 정보
- Date: 2024-03-11
- Reviewer: 김재희
- Property: LLM, Quantization
The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits
1. Intro
BitNet
-
1-Bit(1 or -1) parameter로 scratch부터 학습
-
기존 LLM 대비 적은 Inference/Train Cost를 가짐
-
기존 Post Quantization 방법론 대비 높은 성능 기록
1-Bit
-
1.58Bit(1,0,-1) parameter로 scratch부터 학습
-
동일 파라미터를 가지는 LLaMA 구조 대비 높거나 비슷한 성능 기록
-
(1,0,-1)의 상태를 가지는 bit 구조를 이용한 하드웨어 설계를 통해 모델 학습/추론 파이프라인 최적화 방향성 제안
결론
-
정말 제대로 동작하는지 잘 모르겠음
-
최근 LLM과 엄밀한 비교 수행 X
-
1-Bit는 결국 더이상의 Quantization 불가
→ FP32/FP16/BF16의 모델들과 정확한 성능 비교가 필요
재밌는 아이디어이지만 이 방법론이 미래인지는 더욱 검증이 필요 ⇒ 70B의 1-bit가 knowledge를 제대로 담을 수 있을까? ⇒ Instruction Tuning과 같이 복잡한 태스크를 학습할 수 있을까?
2. BitNet
Architecture

-
Transformer의 일부 레이어를 1-bit로 quantization하여 사용(original weight가 존재)
-
Linear 레이어: 1-Bit(1,-1) quantization
-
이외 레이어: 8-bit quantization으로 연산 진행(attention, …)
-
Input/Output Embedding: high precision으로 진행(16 or 32 bit)
-
→ Sampling을 위해서는 high-precision이 필요하기 때문
-
BitLinear: 기존 Transformer 구조에서 연산량이 막대한 Linear Layer를 대체
-
Input: 8 bit quantization
-
Input Quantization: AbsMax Quantization 사용(Q_b: quantize할 데이터 범위)
-
⇒ 벡터를 max로 normalizing 후 부호만 남김

-
Input for Non-Linear Function Quantization
- Activation Function(GELU)의 입력의 경우 범위를 [0, Q_b]로 제한

-
Linear: 1 bit quantization
- weight의 평균 대비 크기 비교를 통해 Quantization 실행

-
Matrix Multiplication: Quantized Linear와 Quantized Input은 단순 연산을 통해 계산 가능
- 하지만 이대로 수행한다면 기존 LLM의 Layer Norm이 사라짐
→ Layer Norm: 학습 안정화 및 발산 방지
- Input Quantization 이전에 Layer Norm을 적용

-
연산이 완료된 벡터는 다시 Quantization 시 계산 된 수치를 이용하여 Dequantization 진행 → Precision 복원
-
Pretrain 과정에서 Linear Layer의 연산량을 감소 및 속도 개선 가능
-
Distributed Training:
- 기존 Pretrain과 달리 Input 별로 Quantization 수치를 계산해야 함
⇒ 분산 학습 시 Machine 별로 독립적으로 계산 → Machine 간 통신 비용 감소
-
Mixed Precision Training
-
Forwarding 과정
- Linear Layer(FP16) 및 Sub-Module 별 Input에 대한 Quantization 진행
-
→ Low Precision(1Bit)으로 Fowarding
-
Backwarding 과정
- Gradient와 Optimizer 내 state은 모두 high precision 사용
→ FP16 Linear Layer weight 업데이트
-
High Learning Rate
-
1 bit로 Quantized 하다보니 Learning Rate가 낮을 경우 실제 weight에서 작은 변화가 발생
-
1.24214 → 1.24232
-
1 → 1
-
학습 효과가 forwarding 과정에서 반영 X
-
-
Learning Rate를 대폭 높혀 학습 진행
- 2.4e-3 ~ 4e-4
-
Experiments
-
FP16 Trasnformer와 비교
- 125m ~ 6.7b까지 Transformer와 BitNet을 scratch부터 학습하여 비교 진행
-
Quantization Method와 비교
- 기존 Post Quantization 방법론들과 비교 진행(w:weight precision a: input precision)


-
fp16에 비해서는 낮지만 quantization 방법론 대비 매우 높은 성능 달성
-
Energe Consumption 대비 성능 비교 (zero/few shot)

-
동일 에너지 사용 시 더 높은 성능 달성
-
동일 에너지 사용=fp16 대비 더 큰 모델 사용 가능
3. 1.58bit
Architecture
-
1.58 bit…?
-
bit= 정보량을 표현할 수 있는 이진분류 표기 체계 단위
-
1bit : (-1, 1)
-
0: -1
-
1: 1
-
-
2bit: (0,1,2,3)
-
00: 0
-
01: 1
-
10: 2
-
11: 3
-
-
만약 BitNet에서 0만 추가한다면?
-
weight를 통해 input의 정보를 사용하지 않도록 만들 수 있음
-
log_23 \approx 1.58
-
현재 하드웨어 상 구현: 2 bit 필요
-
00: 0
-
11: 1
-
01: -1
-
-
-
-
1.58bit: fp16과 비슷한 성능을 내면서 inference cost를 줄일 수 있는 방법

-
Modification: BitLinear 구조 거의 그대로 활용
- AbsMean Quantization 사용 (-1, 0, 1로 quantization)

-
non-linear function 입력에 대한 scaling
-
BitNet: [0,Q_b]
-
1.58B: [-Q_b, Q_b]
-
-
모델 구조: LLaMA configuration 사용
Experiments
-
LLaMA Configuration을 이용하여 FP16 Transformer/1.58B scratch부터 학습
-
StableLM-3B에서 사용된 데이터 사용(data recipe)
- 2T token 학습
-
메모리 및 latency와 PPL 간 비교

- 동일 모델 크기 시: 더 적은 메모리 사용 및 Latency
⇒ Quantized Weight를 이용하고 있기 때문
- 비슷한 PPL 기록
-
비슷한 메모리 사용량 비교 시(LLaMA-700m vs BitNet b1.58 3B)
- 더 높은 성능 기록
-
모델 크기에 따른 Memory 및 Latency 경향

-
모델 크기가 커질수록 FP16보다 더 빠르고, 더 적은 메모리 사용
-
BitLinear가 개선시키는 부분은 모델 내 Linear 레이어 관련
→ 모델 크기가 커질수록 해당 파트의 비중이 커짐
- OpenSource LLM과 비교

-
StableLM-3B 모델과 성능 비교
-
모든 태스크에서 성능이 더 좋은 모습을 보임
-
속도가 빠른 건 이해가 되는데, 성능이 좋은 이유에 대한 언급이 없음
-
4. Conclusion
-
BitNet
- 얘네 Transformer 학습할 때 Drop-out 안쓰는데요…
⇒ Transformer 제대로 학습된 게 맞는지 모르겟음…
-
Quantization을 위해 Pretrain부터 Quantization된 학습이 필요하다고 주장
-
175B에서도 유의미할지는 생각해봐야 함
-
학습속도가 빠른지도 중요한 요소
-
→ Fowarding 속도 개선, Not Backward
→ 학습 속도 측면에서는 개선이 안되었을 수 있음
⇒ 자세한 언급 X
-
1.58B
-
지나치게 마케팅된 논문
-
1.58B이 아니라 사실상 2bit quantization
-
post-quantization도 아니고 pretrain quantization
-
논문의 가장 강한 언급: (1,0,-1)이 1 bit에서 가능한 하드웨어가 필요하다
-
-
→ 정말로…?
- FP16 모델(StableLM)보다 높은 성능 달성이 가능한 이유에 대한 언급 X
- Original Weight 사용 시 성능은 어떻게 되는지도 리포팅 X