관심있는 NLP 논문을 읽어보고 간단히 정리했습니다.
혹시 부족하거나 잘못된 내용이 있다면 댓글 부탁드립니다 🙇♂️
[FAIR, AI at Meta, The Hebrew University of Jerusalem]
- decoder-only transformer가 infinite multi-state RNNs으로 개념화 될 수 있다는 것을 입증
- 나아가 사전학습된 transformers를 finite multi-state RNNs으로 전환
- 이때 사용되는 새로운 compression policy, TOVA를 제시
1. Introduction
transformer의 아키텍쳐가 자연어처리 분야에서 핵심으로 자리잡게 되었지만, 이것과 기존 RNN과의 관계에 대한 연구 및 논의도 끊이지 않고 있습니다.
attention을 보다 효율적으로 수행하게끔 하는 연구도 많고 transformer의 아키텍쳐를 취하지 않는 연구도 많습니다.
본 논문에서는 transformer를 현재의 'state'를 다음 스텝에 전달하는 RNN의 개념으로 설명 가능하다고 주장합니다.
이때 여러 개의 state를 가지고 RNN을 구조를 갖추게 되므로 multi-state RNNs (MSRNN)이라고 칭합니다.
(연산만 가능하다면) 이론상 sequence의 길이가 어떻든지 간에 이전 KV에 모두 attend 가능한 방식을 'Infinite' 하다고 표현할 수 있습니다.
하지만 현실적으로 길이가 긴 텍스트를 처리하는 것은 쉽지 않으므로 attention을 압축해야(attention 연산에 필요한 KV를 취사 선택) 하는데, 이에 대한 표현을 'Finite'하다고 합니다.
2. Background
2.1. RNNs
sequential data를 recurrent 방식으로 처리하는 딥러닝 아키텍쳐를 뜻합니다.
$l$번째 layer의 $t$번째 토큰을 나타내는 $x^{l}_{t}$와 이전 time step의 hidden state를 나타내는 $h^{l}_{t-1}$을 입력으로 받습니다.
그리고 업데이트된 token representation $x^{l+1}_{t}$와 새로운 hidden state $h^{l}_{t}$를 출력으로 반환합니다.
일반적으로는 $x^{l+1}_{t}$와 $h^{l}_{t}$을 같게 설정합니다.
$$x^{l+1}_{t},h^{l}_{t}=f^{l}_{RNN}(x^{l}_{t},h^{l}_{t-1})$$
2.2. Transformers
$X^{l}=(x^{l}_{1},...,x^{l}_{t}) \in \mathbb{R}^{t \times d}$ 를 입력으로 받는 transformer의 각 layer별 출력은 다음과 같이 정의됩니다.
2.3. Transformer decoders
next token prediction을 수행할 수 있도록 attention matrix의 우측 삼각형을 가려줍니다.
이를 통해 auto-regressive decoding이 가능해집니다.
3. Transformers as Multi-State RNNs
3.1. Multi-State RNNs
RNN의 아키텍쳐를 그대로 따르고 있으나, hidden state를 나타내는 것이 vector 대신 행렬이라는 점만 다릅니다.
이 행렬 $H^{l}_{t}$는 고정된 상수 또는 input 길이를 나타내는 $g$만큼의 row를 갖습니다.
$$H^{l}_{t} \in \mathbb{R}^{g(t) \times d}$$
$$x^{t+1}_{t},H^{l}_{t}=f^{l}_{MSRNN}(x^{l}_{t},H^{l}_{t-1})$$
만약 $g=1$이라면 기존 standard (single-state) RNN을 의미하게 됩니다.
3.2. Transformers are Infinite MSRNNs
이때 사용되는 hidden state matrix를 self-attention으로 구합니다.
즉, $H^{l}_{t}=(K^{l}_{t},V^{l}_{t})$가 됩니다.
self attention의 구체적인 수식과, 이를 layer 단위로 최종 적용한 수식을 순서대로 제시하면 다음과 같습니다.
3.3. Converting Pretrained Transformers into Finite MSRNNs
$g(t)=min(t,k)$로 설정해줌으로써 finite MSRNNs를 정의할 수 있습니다.
전체 context를 제한된 memory에 fit하게끔 압축하는 알고리즘은 다음과 같습니다.
1) Window: First In First Out (FIFO) strategy
2) Window + $i$: 첫 $i$개 토큰을 retaining
3) $H_{2}O$: window 범위 바깥의 토큰을 dynamically selects
3.4. Our Proposed Policy: TOVA
이전 토큰에 대한 attention 가중치를 기반으로 top states를 보유(유지)합니다.
매 decoding step마다 score가 가장 낮은 것들은 drop됩니다.
이를 도식화한 것은 다음과 같습니다.
이 방식도 $H_{2}O$와 마찬가지로 head-wise 또는 layer-wise하게 적용 가능한 방식입니다.
4. Experiments
4.1. Setup
성능 평가를 위한 태스크는 크게 세 가지로 구성됩니다.
1) Language modeling: PG-19 test set에 대한 perplexity를 측정합니다.
2) Long range understanding: ZeroSCROLLSS 벤치마크에서 두 개의 테스트셋을 사용
3) Text generation: 각 버전의 모델로부터 100개의 long story를 생성하고 결과를 비교
실험에 사용된 모델은 'LLaMA-2, Mistral, Yi'이며 7B 사이즈의 버전을 사용했습니다.
1) Language modeling 태스크의 경우 vanilla 버전을,
2) Long range understanding 태스크의 경우 LlaMA-2-chat, Mistral-Instruct, neural-chat 버전을,
3) Text generation 태스크의 경우 MythoLogics, LLaMA-2-13B 버전을 사용했다고 합니다.
모든 모델들의 입력 제한은 4,096 토큰입니다.
이를 초과하는 것은 truncated 되었다고 밝혔습니다.
4.2. Results
4.2.1. Langugae Modeling
multi-state size를 exponential sclae로 두고 실험한 결과를 제시합니다.
모델의 종류와 상관 없이 TOVA policy를 적용했을 때의 성능이 압도적으로 좋습니다.
4.2.2. Long Range Understanding
Long range summarization과 Long range QA로 나눠 실험한 결과입니다.
위에 설명했던 바와 같이 instruction-tuned LLM을 모델로 사용했습니다.
4.2.3. Text Generation
생성되는 텍스트의 길이는 짧지 않으므로 사람이 일일이 이를 전부 읽고 평가를 내리기가 쉽지 않습니다.
따라서 모델의 출력 결과를 GPT-4 모델을 사용하여 비교 평가하는 방식을 취했습니다.
아래의 그래프는 TOVA 방식을 적용했을 때의 출력과 Topline을 비교한 것으로, 512 토큰 이상을 처리할 때부터 성능이 급격히 향상되는 것을 알 수 있습니다.
4.3. Analysis
Recency is not all you need
TOVA가 보유 중인 토큰이 무엇인지 확인한 결과 73-36%의 토큰은 최근의 것, 나머지는 그렇지 않은 것으로 확인되었습니다.
이는 현재 시점을 기준으로 당연히 최근의 토큰이 큰 영향을 주고 더 중요한 것은 맞지만, 이것만으로 어떤 문맥 정보를 다 파악하는 것은 불가낭하다는 것을 시사하기도 합니다.
First token matters
multi-state의 사이즈를 다르게 변경해가면서 실험해본 결과, 첫 번째 토큰의 중요성을 확인할 수 있었습니다.
즉, 2-4번째 position에 해당하는 토큰들은 금방 drop되는 것과 달리 첫 번째 토큰은 거의 끝까지 attend 대상으로 남는 것이 확인되었습니다.
5. Insights
솔직히 뭐가 좋고 나쁜 건지 잘 구분이 되지 않습니다 😅
transformer 아키텍쳐를 굳이 RNN 형식으로 바꿈으로써 얻는 장점에 대해 논하지 않은 것으로 보입니다.
그리고 결국 attention 기법에서 attend의 대상을 선별적으로 취함으로써 연산량을 줄이는 것 같은데, 그게 기존의 attention 관련 연구들과도 어떤 차별점이 있는 것인지 잘 모르겠습니다.
출처 : https://arxiv.org/abs/2401.06104