QCRD: Quality-guided Contrastive Rationale Distillation for Large Lanauge Models

논문 정보

  • Date: 2024-10-03
  • Reviewer: 전민진
  • Property: Knowledge Distillation

Abstract

  • LLM은 좋은 성능을 갖고 있으나 resource 제한, inference 효율성 등으로 다양한 application에서 사용되기엔 한계가 존재

  • 최근 LLM을 기반으로 한 knowledge distillatinon으로 smaller, task-specific한 모델을 학습하는 여러 방법론이 제안 됨

  • 하지만 기존 연구들은 knowledge의 disversity와 quality에 크게 집중하지 않음.

    • 특히, negative knowledge을 distillation에 사용하지 않음
  • 본 논문에서는 constrative knowledge learning을 통한 reasoning capability 향상을 목표로 하는 quality-guided contrstive rationale distillation(QCRD)을 제안

    • 특히 small lm의 이전 iteration모델에서 rationale을 sampling, negative로 사용

    • discriminator를 같이 학습하여 rationale의 pos, neg를 판별, quality score를 계산해 이를 학습에 반영

  • 실험 결과, 기존의 distillation method보다 뛰어난 성능을 보임

Introduction

  • LLM의 모델 크기가 커지면서 reasoning ability가 발생 → 이를 이용해 작은 모델을 학습하자

    • 하지만 아직도 LLM과 distilled small model간의 성능 차이가 심한 task가 reasoning task
  • 이를 해소하기 위한 다양한 방법이 제안

    • LLM이 생성한 rationale을 생성하도록 small LM을 학습하는 방법(distill step-by-step )

      • L = L{prediction}+\lambda L{generation}

      • 이 방법의 경우 postivie knowledge만 사용, knowledge가 한정적이고 noisy가 있을 수 있음

    • LLM이 생성한 rationale을 golden answer로 보고 작은 모델이 생성한 rationale과 정답의 차이를 줄이도록 하는 연구도 존재

      • 대부분 LLM의 zero-shot/few-shot 결과를 그대로 사용, reaosning step에서 오류가 발생할 확률이 높음

⇒ 하지만 이러한 방법들 모두 negative rationale을 생성해서 학습에 사용하진 않음

  • 본 논문에서는, Quality-guided Contrastive Rationale Disilltation(QCRD)을 제안

    • positive example : LLM이 생성한 rationale 중 self-consisteny O

    • negative example : LLM이 생성한 rationale중 self-consistency X + previous iteration student model이 생성한 rationale (with high temp)

    • discriminator가 rationale의 quality를 계산, 이를 학습에 반영

  • QCRD의 우수성을 입증하기 위해서, T5-base(220M), small(60M)을 student로 사용해 실험 진행, 여러 벤치마크 데이터셋에서 우수한 성능을 보임
  • Multi-task learning with LLM generated ratioanles

    • 기존에 rationale을 바탕으로 하는 여러 KD 방법론이 제안

    • rationale을 학습에 활용하는 것이 효과가 있다는 것이 밝혀져 있음

    • 이전에는 multi-task learning framework방식으로, prefix를 기반으로 모델이 label을 예측하면서 동시에 rationale도 생성할 수 있도록 학습, 내재적으로 rationale에 있는 knowledge를 학습하도록 함

    • 하지만 smalle model의 rationale과 LLM의 rationale이 align 되도록 하나의 loss form에만 집중

  • Contrastive learning for LLMs

    • LLM에 contrastive leraning을 하는 방법론이 기존에 많이 제안되었으나, CoT distillation쪽에선 한번도 차용되지 않음

Methodology

  • QCRD는 3가지 파트로 구성이 되어 있음

    • multi-task learning framework : L = prediction_loss + CL_loss

    • generate contrastive knowledge from LLM and student model

    • quality-guided contrastive learning strategy

      • positive와 negative를 구별하는 online-updated discriminator의 guidance를 사용

Multi-task learning framework for the student model

  • 이전 연구와 유사하게 prefix를 활용해서 small model이 다양한 형태의 output을 생성할 수 있도록 학습

    • for prediction label task,

    • for rationale generatrion task,

Generation of contrastive knowledge

  • Positive sample

    • CoT prompting with sampling using LLM
  • 각 data마다 K개의 rationale을 생성, self-consisteny를 만족하는 rationale을 positive로 사용

  • Negative sample

    • 위에서 self-consistency를 만족하지 못한 rationale을 negative rationale로 봄

      • self-consistency방법론 특성상 negative가 positive에 비해 적음

      • LLM이 생성한 negative는 student가 봤을 땐, positive처럼 보일 수 있음 → 학습 효과가 떨어짐

    • 이전 iteration의 student model에 high temperature를 사용해서 rationale을 sampling

→ low quality rationale이라 판단, negative로 사용

  • \mathbf{x} = [x_1, x_2,..,x_n], S{pos}={r_1^{pos},…,r_m^{pos}}, S{neg} = {r_1^{neg},..,r_k^{neg}}

Constrastive knowledge distillation

  • Train a discriminator to judge rationales

    • 같은 question에 대한 rationale의 quality는 상이, 하지만 학습하면 할수록 student가 생성하는 rationale은 positive rationale과 가까워질 것

      • 무작정 student model이 생성했다고 계속 negative로 보는 것은 합리적이지 않을 수 있음
    • 그래서 효과적으로 positive와 negative rationale를 판별, quality score를 계산해줄 discriminator가 필요

    • Discriminator에 input으로 question과 rationle을 넣고, score를 계산하도록 함

- discriminator로는 encoder architecture를 사용

  - T5-base의 encoder를 사용, max pooling layer랑 2개의 linear layer를 추가

  - 학습 전에 LLM이 생성한 rationale로 pretrain진행(with 500 max step), pos와 neg의 비율을 맞춰주기 위해 word_mask and replacement로 data augmenatation.

  - online-updated during training
  • LLM으로 생성한 positive, negative로 D를 pretrain, 학습 동안에는 D를 regular epoch interval로 업데이트
  • Quality-guided contrastive distillation

    • 위의 단계로 여러 positive, negative sample을 수집

→ many-to-one contrastive distillation loss를 사용

- l(f(\mathbf{x}_i),s_{pos}^i)=min_{r_{j}^{pos,i}\in S^i_{pos}}\{l(f(\mathbf{x}_i),r_j^{pos,i})\}

  - 모델이 생성한 rationale과 가장 유사한 postivie rationale과의 Cross entropy loss

- l(f(\mathbf{x}_i),s_{pos}^i)=min_{r_{j}^{neg,i}\in S^i_{neg}}\{l(f(\mathbf{x}_i),r_j^{neg,i})\}

  - negative는 가장 차이 많이 나는 것 선택

- l(f(\mathbf{x}),r_j^{neg})=min(l(f(\mathbf{x},r^{neg}_j))-\delta,0)

  - 너무 단순한 neg는 거르기 위해 margin사용
  • 또한, student model이 생성했다고 해서 무조건 negative로 볼 경우, 학습이 진행되면서 local optima에 빠질 수 있기 때문에 discriminator를 사용하는 quality-guided distillation을 사용
- Discriminator로 quality score s를 계산, pos의 경우 quality가 높으면 더 학습에 반영되도록, neg의 경우 quality가 높으면 학습에 덜 반영되도록 함
  • Training loss
  • 이 total loss를 보면 최종적으론 student model의 encoder를 똑 떼서 discriminator로 학습한다는 느낌같은데.. 확실히 맞는지 모르겠음

Experiments

Datasets

  • SVAMP : arithmetic word problem solving

  • CQA : commonsense QA

  • e-SNLI, ANLI : NLI

  • rationale을 GPT-3.5-turbo로 생성

Implementation details

  • T5-base(220M),T-small(60M)사용

  • \alpha_1,\alpha_2,\alpha_3은 실험적으로 0.5으로 세팅 \alpha_3은 매 iteration마다 0.9를 곱함

  • \beta=0.2,\delta=3

  • LLM temp 0.7로 5번 샘플링, small model은 5-iteration-before model에서 temp 1.5로 1개 샘플링

Baselines

  • Fintuning

  • Single-supervision : teacher model이 예측한 label을 맞추도록 학습

  • DSS : multi task learning with label prediction and rationale generation

  • MI : DSS기반에 prediction label과 rationale사이의 mutual information이 최대화하도록 task 추가

Experimental result

  • Experiments across four benchmarks
  • 다른 유사한 CoT distillation 방법론과 비교해도 높은 성능을 보임

  • Distillation with LLM labels

    • ground truth를 사용하지 않고, LLM이 생성한 label로 학습했을 때의 실험

      • temperature sampling, self-consistency의 효과로 noisy label을 사용해도 비교적 높은 성능을 보임

        • 이전에 SC로 rationale을 pos랑 neg로 분류했다더니.. 아니었나..?
  • Distillation with smaller datasets
  • Ablation study on QCRD

    • ED : sample the outputs of the LLM and leverage the self-consistency to denoise rationales

    • NK : generator low-quality rationaels as negative rationales

    • QJ : use of discriminator

Discussion

  • Different contrastive distillation schemes

    • min-max방식이 가장 좋더라!
  • influence of negative sample

    • 5번 전 iteration의 모델로 negative를 만들때가 가장 효과적

    • Fixed의 경우 DSS로 학습된 모델이 생성한 rationle을 negative로 썼을 때

  • Assessment for generated rationales

    • GPT-3.5로 DSS와 QCRD로 만든 rationale중 무엇이 더 좋은지 평가

    • DSS win / tie / QCRD win

    • QCRD가 더 좋은 rationale을 생성한다!

  • Distribution of rationale quality scores

    • 확실히 좋은 rationale과 나쁜 rationale은 잘 scoring하는 것을 볼 수 있음

Conclusion

  • Contrastive learning을 CoT distillation에 접목한 최초의 논문!