관심있는 NLP 논문을 읽어보고 간단히 정리했습니다.
혹시 부족하거나 잘못된 내용이 있다면 댓글 부탁드립니다 🙇♂️
[Google Research, Google DeepMind]
- 다른 모델 간의 cross-attention을 통해 새로운 capabilities를 획득하게 하는 기법, CALM - Composition to Augment Language Models
- 기존 LLM은 're-using'하면서 새로운 few additional parameters와 data를 사용
- 다양한 도메인과 환경에 적용 가능하다는 특징(장점)을 보유
1. Introduction
LLM은 여러 태스크 중에서도 이전과 달리 commonsense 또는 factual reasoning 관련 태스크에서도 뛰어난 성능을 보인다는 것이 주목할만한 특징입니다.
그러나 아직까지도 특화된 도메인이나 지식을 모두(사실은 모든 것을 다 잘하는 AGI를 원하는 건가 싶기도 합니다만..) 잘 다루기란 쉽지 않습니다.
이러한 문제를 해결하기 위해 추가 데이터로 기존 모델(본 논문에서는 anchor model이라고 표현합니다)을 학습하는 방법을 취할 수도 있습니다.
하지만 잘 아시다시피 사이즈가 큰 모델을 학습시키는 것은 많은 비용을 초래하는 일이기도 하고, 기학습된 내용을 망각(forgetting)하는 또다른 문제에 직면할 가능성도 높습니다.
본 논문에서는 이와 같은 문제를 여러 모델을 composition 함으로써 해결하고자 합니다.
여기서 composition이라는 것은 기존의 anchor 모델에 능력을 더해줄 augmenting 모델을 합쳐 새로운 능력을 얻어내는 방식을 뜻합니다.
여러 분야 중에서도 language inclusivity와 code generation으로 이러한 능력이 생성될 수 있는지에 대해서 테스트한 결과를 연구 성과로 제시하고 있습니다.
이때 language 같은 경우 low/high resource 언어 간의 비교를 중점적으로 실험을 진행했습니다.
2. Related Works
2.1. Prameter efficient fine-tuning
introduction에서 언급했던 것처럼 LLM을 단순히 fine-tuning하는 것은 너무나도 많은 비용을 초래합니다.
따라서 기존 모델이 학습한 corpus에 포함되어 있지 않은 데이터를 추가 학습할 때는 LoRA와 같은 기법을 적용하여 학습에 사용되는 파라미터수를 줄일 수 있습니다.
2.2. Model Merging
벡터에 평균을 취하는 것과 같은 단순한 merging 기법조차도 모델 성능 향상에 좋은 영향을 준다는 것이 잘 알려져 있습니다.
그러나 이러한 방식은 merging을 적용하는 두 대상이 굉장히 잘 align 되어 있어야 한다는 제약이 존재합니다.
2.3. Model and Task Compositionality
encoder-decoder 아키텍쳐를 따르는 것이 굉장히 좋은 성능을 보인다는 것이 알려지게 되면서 널리 쓰이고 있습니다.
이것의 사용 범위는 현재 multi-modal 분야로까지 확장되었습니다.
2.4. Models as Tools
모델의 input text space를 적극적으로 활용하여 모델을 일종의 tool로 취급하는 연구도 활발히 이뤄지고 있습니다.
그러나 이는 정교한, 그리고 많은 양의 prompt engineering을 필요로 한다는 한계를 지닌 방식입니다.
3. Composition to Augment Language Models (CALM)
anchor model $m_{B}$와 augmenting model $m_{A}$를 합쳐 새로운 모델 $m_{A \bigoplus B}$을 만들어내는 CALM 방식을 소개합니다.
$$m_{A \bigoplus B} = f(m_{A}, m_{B}, \Theta_{C}, D_{C})$$
이때 $\Theta_{C}$는 추가적으로 학습 가능한 파라미터를 뜻하고, $D_{C}$는 composition에 사용되는 학습용 데이터셋을 뜻합니다.
이 방식의 중요한 전제 네 가지는 다음과 같습니다.
1) forward/backward pass의 가중치에 접근 가능해야 합니다.
2) 두 모델의 weight를 변경해서는 안됩니다.
3) 두 모델의 하이퍼 파라미터, 학습 상태 등을 변경해서는 안됩니다.
4) 목표하는 composition domain으로부터 몇 개의 예시를 제공받아 학습을 진행합니다.
3.1. Learning to Compose ($\Theta_{C}$)
composition은 $m_{A}$로부터의 $i^{th}$ layer의 representation과 $m_{B}$로부터의 $j^{th}$ layer의 representation을 cross attention 함으로써 수행됩니다.
이를 위해 $m_{A}$로부터 획득한 representation의 차원을 $m_{B}$의 것으로 변형해줄 linear transformation이 필요하고, 이를 $f_{proj}(.)$으로 표기합니다.
$(H_{Ai})$는 $m_{A}$의 $i^{th}$ representation, 그리고 $(H_{Bj})$는 $m_{B}$의 $j^{th}$ representation을 뜻하며, cross attention을 수식으로 표현한 것은 다음과 같습니다.
특히나 마지막의 $H_{A \bigoplus B_{j}}$는 residual connection을 뜻합니다.
3.2. Composition Training Data ($D_{C}$)
set of training examples $D_{C}$는 $\Theta_{C}$가 두 모델이 타겟 태스크에 대해 적절히 attend 할 수 있도록 돕는 'combined skill'을 뜻합니다.
augmenting model인 $m_{A}$는 task $t_{1}$인 key-value pair에 대해 학습하고, anchor model인 $m_{B}$는 task $t_{2}$인 numeric arithemetic 태스크에 대해 학습합니다.
composition에서는 $D_{C}$를 이용하여 $\Theta_{C}$를 학습하게 됩니다.
참고로, 위 상황은 한 개의 augmenting 모델을 한 개의 anchor 모델과 composition 하는 것을 다루고 있지만, 향후에는 여러 개의 anchor 모델과 compostion하는 것도 가능할 것이라고 본 논문에서는 언급하고 있습니다.
4. Experiments
augmenting 모델($m_{A}$)로는 PaLM2-XXS를, anchor 모델($m_{B}$)로는 PaLM2-XS 또는 PaLM2-S를 사용합니다.
4.1. Key-Value Arithemetic
여기서는 세 개의 데이터셋을 생성합니다.
(1) KV-Substitution ($D_{KV-SUBS}$)
e.g. <K1> + <K2> - <K3>, 10 + 22 - 24
(2) KV-Arithmetic ($D_{KV-MATH}$)
e.g. <K1> + <K2> - <K3>, 8
(3) Numeric-Arithmetic ($D_{NUM-MATH}$)
e.g. 10 + 22 - 24, 8
augmenting 모델을 string을 key로 갖고 unique integer를 value로 갖는 pair에 대해 학습을 하게 되고, anchor 모델은 수학적 계산 능력이 뛰어난 사전 학습 모델을 바로 사용합니다.
각 모델이 자신의 태스크만 수행할 수 있고 추가된 태스크에 대해서는 전혀 이해하지 못하는 것과 달리, composition을 통해 학습된 모델은 각 태스크를 어느 정도 준수히 수행할 뿐만 아니라 새로운 태스크에 대한 일반화 능력도 뛰어나다는 것을 알 수 있습니다.
4.2. Low-Resource Language Inclusivity
low-resource 언어에 대해 사전학습된 augmenting LM $m_{A}$를 사용합니다.
low-resource 언어와 high-resource 언어 간의 번역 성능을 테스트 함으로써 composition이 유용한 방식인지 검증합니다.
이때 사용된 데이터셋은 Next Thousand Languages (NTL)과 GSM8K를 다른 나라 언어로 번역한 데이터입니다.
결과는 다음과 같습니다.
번역의 경우 anchor 모델을 NTL 데이터셋에 대해 fine-tuning 한 것에 준하는 성능이 나타나는 것을 확인할 수 있습니다.
GMS8K 데이터셋을 번역한 케이스에 대해서도 마찬가지입니다.
다만 이때 anchor model을 fine-tuning 하는 것은 catastrophic forgetting 문제를 야기한다는 것 또한 확인 가능합니다.
4.3. Code Understanding and Generation
코드에 대한 이해도를 확인하는 태스크도 세 개로 구성할 수 있습니다.
(1) Code-Completion (CC): zero-shot evaluations on HumanEval benchmark dataset
(2) Text-to-Code (T2C): 3-shot inference on the MBPP dataset
(3) Code-to-Text (C2T): 3-shot evaluations on the CodeXGlue benchmark
augmenting model $m_{A}$에 $D_{code}$를 학습시켜 실험을 수행합니다.
이전 실험 결과와 마찬가지로 anchor 모델을 직접 fine-tuning 한 것과 비교했을 때도 그 이상의 성적을 내는 경우가 만습니다.
특정 언어에 대해서는 fine-tuning이 catastrophic forgetting으로 이어졌다고 이해할 수 있겠습니다.
4.4. Ablations
본 연구에서는 augmenting model $m_{A}$를 초기화하는 방식 비교, iterative decoding, LoRA를 중심으로 비교 실험한 결과를 보여주고 있습니다.
결과적으로 말하자면,
- augmenting model $m_{A}$을 아무렇게나(random하게) 초기화하는 것보다는 기존 vanilla 모델을 사용하는 것이 더 좋다
- augmenting model $m_{A}$을 encoder로 사용하는 것은 그렇게 좋은 전략이 아니다
- LoRA로 학습하는 것보다 들이는 비용은 적고 퍼포먼스는 더 좋다
고 할 수 있겠습니다.
5. Insights
모델 간의 cross attention에 대해서는 상상도 못해봤는데 너무나도 참신한 접근 방식이라는 생각이 들었습니다.
역시 구글에는 없는 연구가 없는 것인가..
물론 공개되지 않은 모델들에 대해서는 적용 불가능한 기법이긴 합니다만, LoRA와 같은 PEFT와 비교했을 때도 더욱 뛰어난 결과가 나타난다는 것은 분명 대단한 일인 것 같습니다.
다만 이해되지 않는 것은 왜 구글의 자체 모델을 사용했음에도 불구하고 scability를 확인하지 않았을까 하는 점입니다.
이게 단순히 국소적인 모델들에 한해서만 적용되는 학습 기법이 아니라 사이즈가 큰 모델에 대해서도 적용 가능한 방법론이면 굉장히 큰 영향을 줄 수 있다는 생각이 들었기 때문입니다.
안그래도 요즘은 특정 도메인이나 태스크에 특화된 모델들이 엄청나게 많은데 이를 적극적으로 활용할 수 있지 않을까 하는 생각이 들었습니다.
출처 : https://arxiv.org/abs/2401.02412
'Paper Review' 카테고리의 다른 글
<RLAIF, Self> Self-Rewarding Language Models (2024.01) (1) | 2024.01.22 |
---|---|
<LLM, RNN> Transformers are Multi-State RNNs (2024.01) (0) | 2024.01.20 |
<LLM> [MoE] Mixtral of Experts (2024.01) (1) | 2024.01.16 |
<NLP> [Transformer] Attention Is All You Need (2017.06) (1) | 2024.01.10 |
<sLLM> TinyLlama: An Open-Source Small Language Model (2024.01) (0) | 2024.01.08 |