Transformer에서 어떻게 다음 토큰을 예측할까?
LLM의 기반이 되는 transformer는 주어진 토큰 셋에 대해 다음 토큰의 확률을 최대화하도록 학습된다. transformer는 길이가 n인 텍스트에 대해 아래 loss function을 최소화하도록 gradient descent를 이용해 파라미터를 업데이트한다. $$Loss = -\sum_{i=1}^{n} \log P(w_i|w_1, w_2,...,w_{i-1})$$
음~ 그렇구나 할 수 있지만 나는 조금 더 구체적으로 알아보기로 했다. 번역 태스크 (한국어 -> 영어) 를 예로 들어보자.
- 한국어: 고양이가 매트 위에 앉아있어.
- 영어: The cat is sitting on the mat.
학습 (Training)
먼저 훈련 (training) 시 encoder와 decoder의 입력은 각각 무엇이며, decoder의 마지막 레이어에서 어떻게 다음 토큰을 예측하는 지 알아보자.
- encoder의 입력: 주어진 텍스트 전부 다. 예를 들어 translation 태스크에서는 번역 전 텍스트 전부, classification 태스크에서는 분류할 텍스트 전부, QA 태크스에서는 query 텍스트 전부다. 주어진 텍스트를 고정된 크기의 벡터로 변환한 뒤 decoder로 전달해 decoder가 맥락을 파악할 수 있도록 해준다. 우리 예시에서는 "고양이가 매트 위에 앉아있어"를 토큰 임베딩과 위치 임베딩을 거친 후 encoder의 입력으로 넣어준다.
- decoder의 입력: 이전까지 모델이 예측한 토큰 셋. 예를 들어 보자. t=1에서 <start_token>이 decoder의 입력으로 들어가면 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 이때의 확률 분포가 바로 <start_token> 다음에 나올 토큰에 대한 예측을 나타낸다. 이 확률 분포와 원핫 벡터 형식의 ground-truth "The"를 cross entropy 식을 이용해 loss function을 계산한다. (확률 분포와 원핫 벡터의 ground-truth는 모두 vocab_size의 차원을 갖는다) 이 loss function을 이용해 gradient descent로 파라미터를 업데이트한다. loss function이 맨 처음에 본 것과 다르지만 의미는 같다. 위의 loss function은 어떤 토큰 셋이 주어졌을 때, 그 뒤에 와야 할 정답 토큰의 확률을 최대화하도록 한다는 intution이고, 이를 cross-entropy 식으로 구한 것이 바로 아래 los function이다. $$loss = -\sum_{d=1}^{vocab size} y_d \log \hat{y_d} $$ t=2에서는 또 그 다음 토큰을 예측해야 한다. 이때 모델이 어떤 토큰을 예측했든지 간에 ground-truth 토큰인 "The"을 사용한다. t=1에서 decoder가 <start_token>을 입력으로 받아 다음 토큰으로 "a", "an"을 예측했든지 간에 훈련 때에는 teacher-forcing해야 하기 때문에, ground-truth "The"를 사용한다. 다시 t=2로 돌아가서, t=2의 decoder의 입력은 이전까지 모델이 예측한 토큰 셋, 즉 <start_token> + The를 넣어준다. 그럼 decoder의 마지막 레이어에서는 전체 vocab size의 크기를 갖는 확률 분포를 출력한다. 가장 최신 토큰, 즉 예측할 토큰의 바로 이전 토큰 "The" 행과 ground-truth 벡터 사이의 cross entropy를 구해 loss를 구한다. t=2에서 ground-truth는 "cat"이 되겠다. loss를 이용해 또 역전파 해서 모델의 파라미터를 업데이트한다. 그럼 t=3에서의 decoder의 입력은 <start_token> + The + cat이다. 이 과정을 <eos> 토큰이 출력될 때까지 반복한다.
추론 (Inference)
훈련과 추론의 가장 중요한 차이점은, 추론은 teacher-forcing을 사용하지 않는다는 것이다.
- encoder의 입력: 훈련 때와 마찬가지로 주어진 텍스트 전부 다. 맥락을 파악하는 벡터를 만들어 decoder에게 전달해준다.
- decoder의 입력: 훈련 때와 마찬가지로, 이전에 생성된 토큰 셋. 그렇다면 훈련과 무엇이 다른 지 자세히 보자. t=1에서 <start_token>이 decoder의 입력으로 들어가면 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 이 확률 분포가 바로 <start_token> 다음에 나올 토큰에 대한 예측을 나타낸다. 추론이기 때문에 ground-truth를 모른다고 가정하기에, 바로 확률 분포에 argmax 함수를 취해 가장 큰 확률로 예측된 토큰을 다음 토큰으로 예측한다. 즉, ground-truth를 고려하지 않고 모델이 예측한 토큰을 바로 사용한다. 그럼 t=2에서 decoder의 입력은 <start_token> + "A"가 되고 (예를 든 것), 그럼 decoder의 마지막 레이어에서는 전체 vocab size 차원을 갖는 확률 분포를 출력한다. 가장 최신 토큰, 즉 예측할 토큰의 바로 이전 토큰 "A" 행에 argmax 함수를 취해 가장 큰 확률로 예측된 토큰을 다음 토큰으로 예측한다. 가장 큰 확률로 예측된 토큰이 "dog"이라면, t=3의 decoder의 입력은 <start_token> + A + dog이다. 마찬가지로 <eos> 토큰이 출력될 때까지 반복한다.
훈련과 추론에서 encoder와 decoder의 입력이 무엇인지, 어떤 흐름으로 다음 토큰이 예측됐는지를 알아봤다. 다시 맨 처음에 제시한 식을 보자. $$Loss = -\sum_{i=1}^{n} \log P(w_i|w_1, w_2,...,w_{i-1})$$ i=1, 즉 t=1, 입력이 <start_token>일 때 다음 토큰이 정답 토큰 "The"일 확률을 최대화하도록 파라미터 업데이트. i = 2, 즉 t=2, 입력이 <start_token> + The 일 때 다음 토큰이 정답 토큰 "cat"일 확률을 최대화하도록 파라미터 업데이트. i=3, 즉 t=3, 입력이 <start_token> + The + cat일 때 다음 토큰이 정답 토큰 "is"일 확률을 최대화하도록 파라미터 업데이트.