Understanding the performance gap between online and offline alignment algorithms
논문 정보
- Date: 2024-05-27
- Reviewer: 전민진
- Property: RLHF
Abstract
-
RLHF은 요즘 LLM alignment를 위한 근본 방법론으로 자리 잡음
- 이에 따라, PPO, DPO, TRPO, self-judge등 엄청나게 많은 variation 방법론이 쏟아지고 있음
-
RLHF를 크게 online, offline alignmnt algorithm으로 나눠서, 둘의 성능 차이가 어떠한 원인에서 발생하는지 분석해보고자 함
Preliminary
-
on-policy, off-policy, online, offline
-
on-policy
- behavior policy(데이터 샘플링할 때 사용한 policy) == target policy(학습할 policy)
-
ex) 직접 롤을 하면서 잘하는 방법을 터득함 ⇒ on-policy
ex) REINFORCE algorithm, PPO
-
off-policy
-
behavior policy(데이터 샘플링할 때 사용한 policy) != target policy(학습할 policy)
-
과거의 자기 모델(파라미터 업데이트하기 전)이 생성한 데이터여도 off-policy
-
ex) 유튜브에서 페이커가 롤하는 영상을 보고 잘하는 방법을 터득함 ⇒ off-policy
ex) Q-learning
-
online
-
agent가 직접 환경과 상호작용
-
policy를 업데이트할 때 지속적으로 환경과 상호작용하면서 샘플을 수집하는 상황을 가정
-
offline과의 가장 큰 차이가 policy optimization하는 동안 데이터를 샘플링하는가
-
online+on-policy(PPO, figure-a), online+off-policy(figure-b)알고리즘 모두 존재
-
-
offline(figure-c)
-
agent가 직접 환경과 상호작용하지 않음
-
환경과 상호작용 없이 고정된 데이터로만 학습하는 경우
-
즉, policy optimization을 수행하는 동안 추가적인 샘플링 과정이나 데이터 수집 과정이 없음 ⇒ 주어진 데이터만으로 학습
-
ex) DPO


Introduction
Is online RL necessary for AI alignment?
-
offline RL(DPO)같은 방법론의 경우 online RLHF에 비해서 훨씬 간단하고 연산량도 적음
-
과연 offline RL로도 충분히 alingment가 가능한지를 분석
-
online vs offline algorithms
-
online이 연산량이 훨신 큼
- due to sampling and training an extra reward model
-
budget(KLD with SFT model)을 기준으로 각 방법론이 어느 정도의 성능을 내는지를 비교
-
일반적인 RL setting에서는 당연히 online이 우세하나, RLHF에서 reward에 bottle neck이 발생하는걸 고려하면 애매
- LLM을RL로 학습할 때는 reward modeling을 human preference data로 한 후에, policy만을 업데이트
-
⇒ policy가 업데이트 되면서 분포가 reward model과 상이해질 수 있음
Comparing online and offline performance under Goodhart’s law
-
우선 online과 offline alignment method가 성능 차이가 나는지를 확인
-
면밀한 비교를 위해 둘다 IPO loss 사용, 차이는 \mu = \pi_\theta(online), \mu = D(offline)

- Online achieves better budget and performance trade-off than offline

- 실험 세팅

- golden policy : 기존의 human preference데이터로 golden preference model(reward model)을 학습, 이를 바탕으로 policy를 학습한 모델(T5X-XXL, 11B 사용)
- 다른 모델은 모두 golden preference model로 학습데이터를 다시 라벨링한 데이터셋 D_{golden}으로 학습(T5X-Large, 770M)
- D_{golden}은 기존 데이터셋에 비해 다소 너프된, 노이즈가 있는 데이터셋이라고 보면 됨
- 각각의 점은 학습된 policy 성능을 의미(미리 정해놓은 hyper-parameter + 학습 정도)
-
실험 결과, online algorithm이 offline에 비해서 더 높은 trade-off performance를 보임
-
online, offline 모두 overoptimization 현상이 발생하는 것을 알 수 있음
- 두 가지 모델 모두 특정 지점에서 peak를 찍고 다시 성능이 하락함
Hypotheses for the performance discrepancy
같은 데이터셋으로 학습하는데 왜 on, offline 모델이 성능 차이가 나는지 확인하기 위해 가설 설정
-
Hypothesis 1 : Data coverage
-
online algorithm are better because the dataset coverage is more diverse than the offline dataset
-
offline은 이미 구축된 데이터셋으로만 학습되는 대신, online의 경우 현재 모델로 response sampling, reward model로 분류한 뒤에 학습
-
당연히 데이터 다양성은 online algorithm이 압살
- online버전은 학습하면서 모델이 좀 더 aligment된 답변을 내놓을 것이기 때문
-
-
-
Hypothesis 2 : sub-optimal offlifne dataset
-
offline algorithms are at a disadvantage because the initial preference dataset is generated by a sub-optimal policy
-
preference dataset을 구축할 때 사용된 모델이 sub-optimal하기 때문에, 이걸로만 학습하는 offline 방식은 더 불리할 것
-
offline 방식은 사실상 SFT의 contrastive version으로 볼 수 있음
- 즉, 데이터의 품질에 더 영향을 받을 수 밖에 없다
-
-

-
Hypothesis 3 : Better classification better performance
-
Offline algorithms typically train policies as classifiers
-
_However, as classifiers they might not be as accurate as proxy preference models (due to effectively different ways to parameterize the classification). _
-
If the accuracy improves, the performance will improve too
-
offline방식은 사실 policy를 preference classifier로 학습하는 것과 같음
-
만약, offline으로 학습한 policy의 classifier성능이 떨어진다면, 당연히 policy 자체의 성능도 향상될 수 없을 것
- 이런 상황이라면, classifier의 accuracy를 올렸을 때, policy 자체의 성능도 올라야 함
-
-
-
Hypothesis 4 : Non-contrastive loss function
-
How much of the performance gap is attributed to the loss function being contrastive, rather than samples being offline?
-
둘의 성능 차이가 sampling아니라 loss function이 contrastive에서 발생하는건지 확인
-
CL의 경우 hard negative와 같은 데이터가 중요한데, 그러면 당연히 계속 데이터를 sampling하는 online방식이 효과적일 것
-
SFT로 loss식을 바꾸면 on,off 성능차가 줄지 않을까?라는 의문에서 시작된 가설
-
-
-
-
Hypothesis 5 : Scaling policy is all you need
-
Scaling policy size upwards is all you need to bridge the gap between online and offline algorithms
- 모델 사이즈를 키우면 on,off 성능 차가 줄지 않을까?
-
** 실험 세팅
-
controlled setting to study KL vs. performance trade-off
-
한정된 세팅에서 실험하기 위해 D_{golden}을 학습 데이터셋으로 사용
-
scaling experiment을 제외하고는 policy와 proxy preference model로 Large T5X(with 770MM params)을 사용
-
learned policy와 SFT policy의 거리를 측정하는 메트릭으로 KLD(sft,learned policy)를 사용
-
-
supervised fine-tuning
-
SFT시, (x,y_w), (x,y_l)을 모두 데이터셋으로 사용
-
모델의 퀄리티를 향상시키는걸 목표가 아니라 RLHF의 시작을 보장하기 위한 용도
-
현실적인 상황을 가정
-
-
-
evaluation
-
2048에 대한 prompt에 대해 fixed policy baseline에 대한 win rate를 평가지표로 사용
-
fixed policy는 golden preference model를 바탕으로 online algorithm으로 학습한 모델
-
win rate는 golden preference model이 결정
-
-
hyper-parameter
-
baseline
- lr : 1e-5, beta : 0.1, gradient step : 4k
-
offline algorithm
- lr : (3e-6, 1e-5, 3e-5), beta : (0.1, 0.5, 1), training step : (4k, 20k steps)
-
Investigating the hypotheses
Hypothesis 1 : Data coverage
: on, off의 성능 차는 데이터의 다양성에서 기인할 것
⇒ online 학습에서 사용되는 데이터셋을 shuffle, 이를 바탕으로 offline 학습을 해보자
(shuffle안하면 online과 똑같음)

-
offline with d_online-shufflfe이 offline에 비해 근소한 성능 향상을 보였으나, 기존과 큰 차이 없었음
-
그나마 Chat arena에서 준수한 성능을 보임 → data coverage가 충분하면 data order는 크게 중요하지 않다..(고 하기엔 사실 online dataset을 활용하는거 자체가 현실성이 없음 ㅎㅎ…)
-
shuffle을 덜하면 성능이 향상됨
-
→ (민진생각) 모델의 상태에 알맞는 학습 데이터를 제공하는 것이 젤 중요한거 같음

- 결론 : data coverage때문은 아님
Hypothesis 2 : Sub-optimal offline dataset
: offline dataset이 sub-optimal해서 offline 성능이 낮은 것이다

-
red dot : online policy로 4k step만큼 학습한 모델로 response sampling, golden preference로 다시 라벨링한 데이터로 학습
- 기존의 golden dataset에 비해 좀 더 tricky한 데이터라고 볼 수 있음
-
실험 결과, 성능에 도움 안됨
Hypothesis 3 : Better classification, better performance
: preference classification을 잘하면, policy의 성능이 높을 것이다
→ 1) proxy preference model이 policy를 classifier로 쓰는 것보다 높은 classification 성능을 보일 것
2) online과 offline의 성능 차이는 이러한 classification accuracy의 차이에서 기인했을 수도 있다
** policy는 preference classifier라고 볼 수 있음

**preference model은 policy를 classifier로 쓰는거보다 더 expressive version이라고 볼 수 있음
왜지??? 그냥 likelihood가 아니라 score를 학습하도록 해서? 잘 모르겠다…
각 response r(x,y)마다 single scalar를 부여

-
실험 세팅
-
proxy preference model : online에서 사용되는 reward model
-
online@4k, offline@4k : policy를 classifier로 사용한 경우
-
각 학습 스텝에서 나오는 online dataset으로 classifier의 성능을 측정
-
-
실험 결과, 학습할수록 on-policy dataset이 기존의 preference dataset과 멀어짐
→ proxy preference model 성능이 하락
→ 맨 처음 실험 장표에서 봤던 online algorithms의 over-optimization 문제를 부분적으로 설명

-
classifier 성능에 따른 policy자체의 성능(win rate)에 관한 장표
-
실험 결과, classifier의 성능과 policy 자체의 성능은 큰 연관성이 없었다..

-
각 방법론의 학습 정도에 따른 classifier 성능과, D_golden에서 y_w의 relative log probs를 측정
-
log probs of the winning reponses from the pairwise dataset, relative to the SFT policy
-
\mathbb{E}{(x,y_w,y_l)\sim D{golden}}[log\pi_{\theta}(y_w x)-log\pi_{sft}(y_w x)]
-
-
(figure 8, top row) online의 classification accuracy가 낮은 것을 확인할 수 있음
- online의 경우 계속 data distribution이 이동
→ 고정된 분포(D_{golden})에 대한 classification은 낮지만 generative는 상대적으로 잘됨
-
(figure 8, bottom row) offline방식의 경우 winning response의 logit을 높이는 방식이 아니라 둘다 logit을 낮추되, losing response의 logit을 훨씬 크게 낮추는 방식으로 학습
-
물론 likelihood가 낮아진다고 해서, offline algorithm이 학습에 실패했다는 것은 아님
- 단지 offline optimization은 offline dataset에서 샘플링 된 winning response쪽으로 더 확률이 커지도록 확률 분포를 움직이는 게 아니라는 뜻
-
→ 우리가 일반적으로 생각하는 것보다 더 간접적으로 optimize..
- online 방식의 경우 현재 정책에서 발생할 가능성이 높은 응답만 관찰, offline과 다른 최적화 과정을 거친다는 것을 알 수 있음
Hypothesis 4 : Non-contrastive loss function
: off가 on보다 성능이 낮은 이유는 loss function이 contrastive 구조이기 때문
→ 데이터의 품질에 더 많은 영향을 받음
⇒ loss fuction을 SFT느낌으로 바꿔서 실험해보자


-
Bo2에서도 on,offline의 성능 차이는 비슷
- 다만 Bo2에선 data coverage를 높이면 성능이 향상됨
-
chat arena sxs에서는 Bo2를 사용한 on, off성능이 유사
→ winning response에 대한 SFT만으로도 성능을 충분히 낼 수 있음
Hypothesis 5 : Non-contrastive loss function
: policy model을 scaling up → 3B, 11B로도 학습 (자원때문에 batch size 낮춤)

-
scaling up하면 model의 peak성능이 높아지긴 함
-
scaling up해서 실험한 결과, 모델이 작을 때와 유사하게 overoptimization문제를 발견

-
모델 크기를 키웠더니 offine with D_online-shuffle 성능이 크게 향상되긴 함
- 모델이 커지면 on,off 성능 차이를 data coverage로 설명할 수도 있다!
Making the dataset more on-policy improves offline learning
어떻게 데이터셋을 구축해야 offline learning의 성능을 향상시킬 수 있을까? 에 대한 ablation study
- 3가지 버전의 데이터셋을 구축
- : 각각의 데이터셋은 D_golden에서 prompt를 샘플링, 버전에 맞는 모델로 response sampling, golden preference model로 라벨링
-
구축한 후 D_golden과 비교
-
D_sft vs 800 : Two sides of the responses are generated by the SFT policy and the online algorithms’ learned policy at 800 step
-
D_800 vs 4k : Two sides of the responses are generated by the online algorithms’ learned policy at 800 step and 4k step
-
D_4k vs 4k : Two sides of the responses are both generated by the online policy at 4k step

-
실험 결과, 학습데이터가 어느정도 SFT와 분포가 유사해야 잘 작동
- 분포가 offline초기 단계와 비슷해야 좀 더 on-policy와 유사하게 작동하기 때문으로 추정
-
응답 간의 퀄 차이만으로는 성능이 향상되지 않았음
Summary
-
Online과 offline의 성능을 budget 측면에서 비교했을 때, online이 성능이 우월
- on-policy data generation이 학습 효율 향상의 핵심인 것으로 추정됨
-
요즘 하도 다양한 RL방법론이 등장하고 있어서, 이들 사이에 차이가 발생해 성능이 차이나는지 분석하는 것도.. 좋을듯