NLP

Transformer에서 어떻게 다음 토큰을 예측할까?

MINAIR 2025. 2. 24. 11:59

LLM의 기반이 되는 transformer는 주어진 토큰 셋에 대해 다음 토큰의 확률을 최대화하도록 학습된다. transformer는 길이가 n인 텍스트에 대해 아래 loss function을 최소화하도록 gradient descent를 이용해 파라미터를 업데이트한다. $$Loss = -\sum_{i=1}^{n} \log P(w_i|w_1, w_2,...,w_{i-1})$$

 

음~ 그렇구나 할 수 있지만 나는 조금 더 구체적으로 알아보기로 했다. 번역 태스크 (한국어 -> 영어) 를 예로 들어보자. 

  • 한국어: 고양이가 매트 위에 앉아있어.
  • 영어: The cat is sitting on the mat. 

 

학습 (Training)


먼저 훈련 (training) 시 encoder와 decoder의 입력은 각각 무엇이며, decoder의 마지막 레이어에서 어떻게 다음 토큰을 예측하는 지 알아보자. 

  • encoder의 입력: 주어진 텍스트 전부 다. 예를 들어 translation 태스크에서는 번역 전 텍스트 전부, classification 태스크에서는 분류할 텍스트 전부, QA 태크스에서는 query 텍스트 전부다. 주어진 텍스트를 고정된 크기의 벡터로 변환한 뒤 decoder로 전달해 decoder가 맥락을 파악할 수 있도록 해준다. 우리 예시에서는 "고양이가 매트 위에 앉아있어"를 토큰 임베딩과 위치 임베딩을 거친 후 encoder의 입력으로 넣어준다. 
  • decoder의 입력: 이전까지 모델이 예측한 토큰 셋. 예를 들어 보자. t=1에서 <start_token>이 decoder의 입력으로 들어가면 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 이때의 확률 분포가 바로 <start_token> 다음에 나올 토큰에 대한 예측을 나타낸다. 이 확률 분포와 원핫 벡터 형식의 ground-truth "The"를 cross entropy 식을 이용해 loss function을 계산한다. (확률 분포와 원핫 벡터의 ground-truth는 모두 vocab_size의 차원을 갖는다) 이 loss function을 이용해 gradient descent로 파라미터를 업데이트한다. loss function이 맨 처음에 본 것과 다르지만 의미는 같다. 위의 loss function은 어떤 토큰 셋이 주어졌을 때, 그 뒤에 와야 할 정답 토큰의 확률을 최대화하도록 한다는 intution이고, 이를 cross-entropy 식으로 구한 것이 바로 아래 los function이다. $$loss = -\sum_{d=1}^{vocab size} y_d \log \hat{y_d} $$ t=2에서는 또 그 다음 토큰을 예측해야 한다. 이때 모델이 어떤 토큰을 예측했든지 간에 ground-truth 토큰인 "The"을 사용한다. t=1에서 decoder가 <start_token>을 입력으로 받아 다음 토큰으로 "a", "an"을 예측했든지 간에 훈련 때에는 teacher-forcing해야 하기 때문에, ground-truth "The"를 사용한다. 다시 t=2로 돌아가서, t=2의 decoder의 입력은 이전까지 모델이 예측한 토큰 셋, 즉 <start_token> + The를 넣어준다. 그럼 decoder의 마지막 레이어에서는 전체 vocab size의 크기를 갖는 확률 분포를 출력한다. 가장 최신 토큰, 즉 예측할 토큰의 바로 이전 토큰 "The" 행과 ground-truth 벡터 사이의 cross entropy를 구해 loss를 구한다. t=2에서 ground-truth는 "cat"이 되겠다. loss를 이용해 또 역전파 해서 모델의 파라미터를 업데이트한다. 그럼 t=3에서의 decoder의 입력은 <start_token> + The + cat이다. 이 과정을 <eos> 토큰이 출력될 때까지 반복한다.

 

추론 (Inference)


훈련과 추론의 가장 중요한 차이점은, 추론은 teacher-forcing을 사용하지 않는다는 것이다. 

  • encoder의 입력: 훈련 때와 마찬가지로 주어진 텍스트 전부 다. 맥락을 파악하는 벡터를 만들어 decoder에게 전달해준다. 
  • decoder의 입력: 훈련 때와 마찬가지로, 이전에 생성된 토큰 셋. 그렇다면 훈련과 무엇이 다른 지 자세히 보자. t=1에서 <start_token>이 decoder의 입력으로 들어가면 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 이 확률 분포가 바로 <start_token> 다음에 나올 토큰에 대한 예측을 나타낸다. 추론이기 때문에 ground-truth를 모른다고 가정하기에, 바로 확률 분포에 argmax 함수를 취해 가장 큰 확률로 예측된 토큰을 다음 토큰으로 예측한다. 즉, ground-truth를 고려하지 않고 모델이 예측한 토큰을 바로 사용한다. 그럼 t=2에서 decoder의 입력은 <start_token> + "A"가 되고 (예를 든 것), 그럼 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 가장 최신 토큰, 즉 예측할 토큰의 바로 이전 토큰 "A" 행에 argmax 함수를 취해 가장 큰 확률로 예측된 토큰을 다음 토큰으로 예측한다. 가장 큰 확률로 예측된 토큰이 "dog"이라면, t=3의 decoder의 입력은 <start_token> + A + dog이다. 마찬가지로 <eos> 토큰이 출력될 때까지 반복한다. 

 

훈련과 추론에서 encoder와 decoder의 입력이 무엇인지, 어떤 흐름으로 다음 토큰이 예측됐는지를 알아봤다. 다시 맨 처음에 제시한 식을 보자. $$Loss = -\sum_{i=1}^{n} \log P(w_i|w_1, w_2,...,w_{i-1})$$ i=1, 즉 t=1, 입력이 <start_token>일 때 다음 토큰이 정답 토큰 "The"일 확률을 최대화하도록 파라미터 업데이트. i = 2, 즉 t=2, 입력이 <start_token> + The 일 때 다음 토큰이 정답 토큰 "cat"일 확률을 최대화하도록 파라미터 업데이트. i=3, 즉 t=3, 입력이 <start_token> + The + cat일 때 다음 토큰이 정답 토큰 "is"일 확률을 최대화하도록 파라미터 업데이트.