NLP/📚 paper

논문 리뷰 💬 [Chain-of-Thought Reasoning without Prompting]

MINAIR 2024. 11. 21. 23:26

📚 Paper

https://arxiv.org/abs/2402.10200

 

Chain-of-Thought Reasoning Without Prompting

In enhancing the reasoning capabilities of large language models (LLMs), prior research primarily focuses on specific prompting techniques such as few-shot or zero-shot chain-of-thought (CoT) prompting. These methods, while effective, often involve manuall

arxiv.org

 

Abstract


LLM의 reasoning ability를 향상시키는 기존의 연구는 few-shot이나 zero-shot CoT (Chain-of-Thought) prompting에 초점이 맞춰져 있었다. 그러나, 이러한 prompt engineering은 manually intensive, 즉 수고스럽다.  본 논문은 LLM이 prompting 없이 효과적으로 reasoning할 수 있는가? 에 대한 새로운 접근법을 제시한다. 

 

pre-trained LLM의 decoding process에서 greedy decoding, 즉 하나의 토큰만을 고려하지 않고 top-k alternative tokens을 explore하여 CoT reasoning path를 도출할 수 있다. 

 

Introduction


  • 기존의 reasoning ability elict 방법
    • few-shot prompting: 질문-답 쌍으로 이루어진 몇 가지 예시를 prompt에 포함시킴. 그러나 질문에서 어떻게 답이 도출되었는지는 주어지지 않음. 
    • zero-shot prompting: 예시 없이 바로 어떤 instruction을 하는 prompting. 
    • CoT: 질문-답 쌍이 주어지지만, 어떻게 질문에서 답이 유도되었는지 그 과정도 함께 주어짐으로써 모델이 사고 과정을 모방할 수 있도록 유도함. zero-shot CoT는 질문-답 쌍이 주어지지 않기 때문에 prompt에 "Let's think step by step"을 포함시킴. 
    • 상당한 양의 CoT reasoning data로 training / instruction tuning시킨 모델 이용. 

 

이상적으로 사람이 prompt를 비틀거나 답변을 반복적으로 수정하게 하는 것 없이 모델은 독립적으로 optimal response를 제공할 수 있어야 한다. 이때 prompting techniques는 모델이 task-specific human priors를 크게 반영하도록 하여 모델의 intrinsic reasoning ability를 제대로 평가하지 못하도록 한다. 그렇다면, (상당히 인위적인) prompting 없이도 모델이 효과적으로 reason하게 할 수 있는가? 할 수 있다면 그 능력은 어느 정도인가?

 

본 논문은 decoding process에서 greedy decoding 대신 alternative top-k tokens를 고려함으로써 prompting 없이도 모델의 CoT ability를 확인할 수 있다고 주장한다. 

 

Figure 1은 그 예시다. 모델의 prompt로 "Q: [question]\nA:"를 사용했다. 첫 번째 decoding step (Decoding step 0)에서 top-k alternative tokens를 explore하고 그 뒤로는 greedy decoding으로 이어나가는 것이다. 예시에서는 greedy decoding이 아닌 (not top-1) path 2와 path 4 (CoT path)에서 옳은 답이 도출되었다. 

 

LLM이 greedy decoding만을 사용하면 reasoning에 상당히 어려움을 겪지만, top-k alternative path를 고려하면 decoding process에서 CoT reasoning path가 자연스럽게 출현한다. 또한, decoding process에 CoT reasoning path가 있다면 그 답변에 대한 모델의 confidence가 높아진다. Figure 1에서 오답인 5에 대한 certainty는 낮지만 정답인 (CoT reasoning path를 따른) 8에 대한 certainty는 높다. CoT decoding이란 k-paths 중에서 CoT reasoning path를 포함하는 답변을 선택하는 decoding 방법이다.  

 

  • 논문의 contributions
    • We present a novel finding that LLMs can reason by simple decoding changes, without the use of prompting. (prompting 없이 decoding만으로 모델이 reason 가능)
    • Our method enables a better understanding of LLM's instrinsic reasoning capabilities without imposing human priors. (prompting, 즉 사람의 이래라 저래라 없이 모델의 own independent intrinsic reasoning ability 확인 가능)
      • 모델에겐 intrinsic reasoning ability가 있고, prompting은 이런 ability를 이끌어내 top decoding path에 표면화시킨다. 따라서 여러 alternative decoding을 보면 prompting 없이도 reasoning ability가 잘 나타난 response를 뽑아낼 수 있다.
    • We further propose CoT-decoding that reliably selects CoT-paths based on answer confidence. : CoT path가 존재하면 answer confidence도 높다. 

Chain-of-Thought (CoT) Decoding


Pre-trained Langauge Models Can Reason without Prompting

  • 실험 환경
    • math problem dataset (GSM8K)
    • commonsense reasoning dataset (year parity): 연도가 짝수인지 홀수인지 맞히는 문제
    • PaLM-2 large model 사용
    • first decoding step의 k번째 토큰 선택
    • greedy decoding path: k = 0
    • alternative top-k decoding path: k > 0
  • 실험 결과
    • LLMs indeed cannot reason if we only consider the greedy decoding path: 모델이 대부분 간단한 문제들로만 pre-trained 됐기 때문에 CoT path 없이 direct problem-solving을 한다. 이러한 경향성은 reasoning task에서 상당히 낮은 accuracy를 보이게 한다. 
    • LLMs can reason if we consider the alternative decoding paths: 반대로, first decoding step에서 alternative top-k (k > 0) tokens를 선택하고 이후에는 (second decoding step부터는) greedy decoding을 하면 CoT reasoning path가 출현한다. 이는 모델이 pre-training 과정에서 inherent reasoning ability를 습득했음을 의미하고 이러한 능력이 greedy decoding으로 인해 잘 드러나지 않았던 것이다. 따라서 alternative path를 탐색함으로써 이러한 능력을 발현시킬 수 있다.

파란색은 모델의 final answer에 대한 confidence (average of probability differences of tokens consisting of k-th path)

CoT-Decoding for Extracting CoT Paths

그렇다면 어떻게 이런 CoT path를 효과적으로 뽑아낼 수 있는가?

 

Table 1을 보면 CoT path가 non-CoT path보다 항상 높게 rank하는 것도 아니고 CoT path가 predominant answer인 것도 아니다. (모델로 하여금 추론 과정 + 답을 여러 번 출력하게 한 다음, 가장 빈번하게 등장한 답을 최종 답으로 사용하는 self-consistency가 CoT path 추출에 적용될 수 없음.)

 

모델의 logit을 조사하다 CoT path는 답에 대해 상당히 높은 confidence를 보인다는 것을 발견해 냈고 이는 top token과 secondary token의 probability 차이로 나타낼 수 있다.

k는 k번째 path를 의미하고 x_t^1과 x_t^2는 k번째 path에서 answer에 속하는 t번째 토큰을 생성할 때 top probability와 secondary probability 단어을 의미한다. 본 논문에서는 모델의 final answer에 대한 confidence를 각 토큰 x_t의 delta (probability difference)의 평균으로 정의한다. 예를 들어 Table 1의 answer "60"의 confidence는 "6"자리에 올 top 단어와 secondary 단어의 확률 차이와 "0"자리에 올 top 단어와 secondary 단어의 확률 차이의 평균으로 구할 수 있다. 

 

delta (probability difference)를 이용해서 CoT path를 추출하는 방법을 CoT-decoding이라고 한다. CoT path는 확실하게 non-CoT answer보다 높은 confidence (delta)를 가지므로 이 사실을 이용해 CoT path를 뽑아낼 수 있다.

 

GSM8K에 있는 100개의 질문에 대해 양적 조사를 했을 때, 각 질문에 대해 highest answer confidence를 갖는 decoding path (answer)를 뽑았더니 88개가 CoT path였다. 이를 통해 모델의 answer confidence score를 통해 CoT path를 추출할 수 있음을 알 수 있다. (아하~ 조사해보니 CoT path는 높은 confidence score를 가지니까 높은 confidence score를 가지는 답을 추출하면 CoT path, 즉 correct answer일 것이다...는 아이디어 이용)

 

  • Comparing different CoT-path extraction approaches: 아래 4가지 방법을 이용해 10개의 path를 decoding하고 그 결과를 이용한 task accuracy는 Table 2와 같다. 모델의 확률 자체 또는 length-normalized 확률 자체는 CoT path를 잘 뽑아내지 못해 성능이 낮고 delta (probability difference)를 이용한 CoT-decoding이 CoT path를 가장 잘 추출하여 성능 개선이 대폭 일어났다. (이때, CoT path는 길이가 길다는 직관을 이용해 length-normalized log prob.을 이용했지만 성능 개선이 크지 않았다...)
    • greedy decoding은 각 단계마다 가장 높은 prob.을 가지는 토큰을 선택해서 final answer generate.
    • highest log-prob은 first decoding step에서 10개의 path (1 greedy + alternative top-9 tokens)를 생성한 뒤 이후 greedy decoding을 통해 10개의 paths 완성. 이후 각 path의 log-prob를 계산하고 가장 높은 highest log-prob를 가지는 answer를 final answer로 사용.
    • highest LN log-prob은 first decoding step에서 10개의 path (1 greedy + alternative top-9 tokens)를 생성한 뒤 이후 greedy decoding을 통해 10개의 paths 완성. 이후 각 path의 LN log-prob를 계산하고 가장 높은 LN log-prob를 가지는 answer를 final answer로 사용.
    • CoT-decoding은 first decoding step에서 10개의 path (1 greedy + alternative top-9 tokens)를 생성한 뒤 이후 greedy decoding을 통해 10개의 paths 완성. 이후 각 path의 confidence score를 계산하고 가장 높은 confidence score (delta)를 가지는 answer를 final answer로 사용. => 이렇게 해서 구한 response와 GT와의 accuracy

  • Identify the answer spans: delta를 계산하려면 answer span을 찾아야 한다. 그 방법으로 math reasoning task에서는 response의 last numerical value를 answer span으로 사용할 수 있다. 또는 prompt에 "So the answer is"를 추가해 이 continuation을 answer span으로 사용할 수 있다. 
  • Sampling under the standard QA format:  CoT-decoding은 first decoding step에서 alternative tokens를 탐색한다. 그렇다면 sampling도 비슷하게 CoT reasoning path를 추출할 수 있는가? sampling w/o few-shot CoT prompt의 성능은 좋지 않았다...sampling은 first decoding step에서 direct answer를 제공하는 모델의 성향에 영향을 크게 받기 때문에 first decoding step에서 CoT-decoding에 비해 diversity가 낮다. (완벽히 이해하진 못했지만...sampling은 top-k words 중에서 랜덤으로 하나를 뽑는 거 + 모델의 direct answer tendency에 영향을 받음 but CoT-decoding은 그냥 top-k words의 path를 모두 생성한 다음에 가장 높은 confidence score를 갖는 path를 final answer로 쓰는 거니까 first decoding step에서 diversity가 더 높다는 말 같음...)

  • Branching at other decoding steps: 또다른 질문은, first decoding step에서뿐 아니라 뒤의 decoding step에서도  branching이 가능하는 것이 좋냐는 것이다. Figure 2에서 second decoding step에서도 분기를 하고 있다. 확실히 first decoding step (step 0)에서의 alternative top-k tokens가 답변의 diversity를 높여주는 것을 확인할 수 있다. 이에 반해 first token 뒤에 생성되는 단어들은 상당히 앞서 생성된 단어들에 영향을 받는다. 예를 들어 "5" 다음에 생성되는 토큰은 오답을 고치지 않고 바로 direct answer를 내버린다. 물론~~~ optimal branching은 어떤 테스크냐에 따라 달라진다. 예를 들어 year parity task는 중간에 분기를 하는 것이 correct CoT path를 출력하게 한다.
  • Aggregation of the decoding paths: top-k paths를 생성했을 때 이들을 합치는 하나의 방법으로 가장 빈번하게 등장한 answer a를 final answer로 사용하는 self-consistency w/o prompts가 있다. aggregation은 모델의 logit 값의 변화에 따른 sensitivity를 완화시켜주고 maximum delta를 갖는 path에만 의존하는 것을 막기 위해서다. 그러나 dominant answers가 항상 correct answer는 아니다. 따라서, 본 논문에서는 weighted aggregation module을 사용한다. answer = a인 k-th path의 delta를 모두 더했을 때 가장 큰 값이 나오는 answer a를 final answer로 사용한다. 이 aggregation approach는 결과에 대한 stability를 높여준다. 

 

Experiments


Experiment Setup

default input은 "Q: [question]\nA:"이고 모델에게 prefix의 continuation을 생성하라고 한다. k = 10 as default이고 first decoding step 이후에는 greedy decoding으로 각 10개의 path에 대해 responses를 생성한다. 

Datasets

mathematical reasoning에는 GSM8K (Grade-school math problems) 데이터셋과 MultiArith를 사용했다. commonsense reasoning에는 year parity task를 수행시켰다. 이 테스크는 심지어 SoTA 모델인 GPT-4도 direct prompt에 대해선 어려움을 겪는다고 한다. 추가로, Big-Bench-Hard dataset을 사용해 symbolic reasoning task에 대한 성능을 조사했다. 

Models

PaLM-2 (X-small, small, medium, large), Mistral-7B, Gemma-7B (all pre-trained models)

CoT-Decoding Effectively Elicits Reasoning from LMs

  • CoT-decoding is the only ecoding strategy that effectively improves LM reasoning
    • Mistral-7B + GSM8K dataset 사용했을 때의 acc

  • CoT-decoding effectively elicits reasoning across LMs

  • CoT-decoding elicits reasoning across model scales

  • CoT-decoding partially closes the reasoning gap between pre-trained and isntruction-tuned models, without using any supervised data: CoT-decoding은 pre-trained model이 instruction-tuned model과 비슷한 성능을 가질 수 있도록 한다. Figure 4에서 pre-trained PaLM-2 Large model의 성능 63.2%는 같은 크기의 instruction-tuned model의 성능 67.8%와 비슷한 성능을 갖는다. 이는 곧 많은 양의 CoT data로 instruction-tuning한 모델의 성능을, decoding 방법을 살짝 바꾼 pre-trained model로 achieve할 수 있음을 의미한다. 더욱 흥미로운 것은, CoT decoding이 instruction-tuned model의 성능도 향상시킬 수 있다는 것이다. 조사 결과, 대량의 CoT data로 instruction-tuned된 모델이 사실은 CoT를 이용하지 않고 direct answer를 하고자 하는 경향이 있음을 밝혀냈다. 이때 CoT decoding을 사용해주면 성능이 올라간다. 

  • Choice of K: k가 커질수록 성능이 좋아졌다. 이는 곧 모델이 decoding할 때 correct CoT path가 이미 존재하지만 paths들이 상당히 낮게 ranked되었다는 것을 의미한다. instruction-tuned model에 대해서는 k의 설정이 덜 중요했는데, 그 이유는 이미 instruction-tuning이 decoding process에서 CoT-paths를 first decoding path에서 높게 rank했기 때문이다. 

 

CoT-decoding Enables a Better Understanding of Model's Intrinsic Reasoning Abilities

본 논문의 approach에서 중요한 점은, human-provided prompts를 제거해 LM의 instrinsic reasoning ability에 대한 truthful assessment가 가능해졌다는 점이다. 이전 section에서는 LM이 grade-school math와 commonsense task에 대해 뛰어난 reasoning 능력을 가졌음을 보였다면 이번 section에서는 reasoning task의 난이도를 달리해서 CoT-decoding을 통한 모델의 inherent reasoning ability에 대해 더 종합적으로 소개한다.

 

  • symbolic reasoning task: 기호(symbol)와 규칙을 사용하여 논리적 추론을 수행하는 작업
    • Coin Flip task - 2, 3, 4 rounds: n번째 round만큼 동전을 던지고 그 결과와 이에 대한 질문을 모델에게 준 다음 답을 요구
    • Web of lies: 3-5개의 truth/lie statements를 모델에게 주어준 다음 이에 대해 질문하는 것
    • Multi-step arithmetic wl various depth level d and length l: 여러 스텝을 거쳐야 하는 수학적 문제
    • Sports Understanding and Object Counting
    • 난이도가 더 어려워질수록 더 synthetic한 문제

 

  • The presence of correct CoT paths depends on the task difficulty levels and correlates with task prominence in the pre-training distribution: 비록 CoT-decoding이 reasoning task를 수행하는 데 도움을 주긴 했지만 이는 task diffuclty level에 큰 영향을 받았다. task가 간단할수록 CoT-path를 더 쉽게 찾을 수 있었다. 모델의 top-k paths를 직접 확인해보니 1-2개의 round에서는 correct CoT path를 generate했지만 3 round부터는 correct CoT path를 잘 generate하지 못했다. Figure 5의 결과를 봐도 그렇다. 이는 모델이 pre-trained할 때 보다 간단한 문제들이 사용되었고 LM이 trained distribution에 큰 영향을 받기 때문에 복잡한 synthetic task에 대해선 잘 해결하지 못하는 모습을 보인다. 
  • CoT-decoding unveils models' instrinsic vulnerabilities in reasoning: 본 논문은 LM이 여전히 어려워하는 분야를 밝혀내느 데에도 기여했다. 예를 들어 Coin-Flip이나 Web-of-Lies task에 대해 모델은 CoT path를 generate할 수 있기는 하지만 이는 테스크의 복잡도가 올라갈수록 길을 잃는 듯했다. 수학 테스크를 수행할 때에는 correct 수학적 해결 절차보다는 그냥 left-to-right하게 연산을 진행했다. 
  • 기존의 CoT prompts는 'teaching' role을 하여 모델로 하여금 그냥 그 prompt를 흉내내도록 했다. Sport Understanding task에 대해 prompt는 모델로 하여금 인위적인 전략을 따르도록 하여 모델의 능력을 제한한다는 것을 밝혔다. 

Combining CoT-decoding with CoT-Prompting

CoT-decoding과 CoT-Prompting을 결합하면 더 뛰어난 reasoning gain을 얻을 수 있었다. self-consistency w/ prompts와 w/o prompts는 strong performance를 보였고, CoT-decoding의 aggregation method도 비슷한 cost로 뛰어난 성능을 보엿다.

Related Work


  • Chain-of-Thought Reasoning in LLMs: few-shot prompting은 task-specific하고 이는 generability를 제한하고 높은 퀄리티의 prompting은 수고스럽다. 또 어떤 prompt를 쓰냐에 따라 모델의 성능이 달라져서 inconsistent performance를 보인다. 더욱이, 이런 prompting은 모델의 posterior distribution을 알기 어렵도록 바꿔놓는다. 그러나 본 논문의 CoT-decoding은 explicit prompting 없이도 모델이 various task에 걸쳐 CoT path를 생성할 수 있다는 것을 밝혔다.  
  • Instruction-tuning to elicit CoTs in LMs: supervision이 효과적이기는 하다만....resource intensive해~~
  • Decoding algorithms for LMs: 기존의 decoding methods (e.g., greedy, temperatrue sampling, top-k sampling, nucleus sampling 등)은 accuracy보다는 diversity에 초점을 맞췄고 이는 reasoning task의 성능을 높이기에 적절하지 않다. small model의 특징은 제거하고 large model의 특징은 강조하는 Contrastive decoding은 reasoning task에서 좋은 성능을 보이지만 추가로 모델이 필요하다는 단점이 있다. 

Conclusion


여기에서 말하는 CoT-decoding 성능 좋아요~~ 는 first decoding step에서 top-10 tokens를 시작으로 하는 10개의 path를 초기화하고, 이 다음부터 (second decoding step부터는) greedy decoding을 사용함으로써 10개의 path를 완성하고 각각의 confidence score (delta)를 계산해준다. 그 다음, answer = a에 대해 가장 높은 summation of delta를 갖는 a (<- 이 answer가 CoT path로부터 도출된 answer!!!! 왜냐면 CoT path들이 가장 높은 confidence score delta를 가질 거니까)를 final answer로 사용한다. 이 final answer와 GT를 비교함으로써 acc 계산 가능~~~ 높은 acc가 나오므로 아~~ CoT path로부터 도출된 값이 correct answer구나~~를 다시 확인할 수 있다.