과거(2020.06)에 나온 논문을 읽어보고 간단히 정리했습니다.
캐글 프로젝트를 하면서 이 모델에 대해 공부를 한 번 하고 싶어서 빠르게 읽고 간단히 정리한 내용입니다!
(버전 3가 올해에 나와 있어서 그것도 얼른 공부를 해야 될 것 같네요)
혹시 부족하거나 잘못된 내용이 있다면 댓글 부탁드립니다 🙇♂️
[Microsoft Research]
disentangled attention mechanism과 enhanced mask decoder라는 새로운 기법을 적용.
기존 BERT 및 RoBERTa 모델의 단점을 개선한 새로운 architecture, DeBERTa를 제시.
- 배경
당시(2020년도)에는 self-attention을 기반으로 한 여러 모델들이 쏟아져 나오고 있었고,
특히 BERT를 중심으로 Pre-trained Language Model (PLM)이 다양하게 연구되고 있었습니다.
RoBERTa, GPT, ELECTRA, XLNet 등 다양한 아키텍쳐를 지닌 모델들이 잘 알려져 있습니다.
본 논문에서 제시한 아키텍쳐를 지닌 DeBERTa 모델은 기존에 NLU(Natural Language Understanding) 분야에서 가장 잘 나가던 BERT와 RoBERTa를 개선한 것이라고 합니다.
따라서 DeBERTA의 아키텍쳐 중 어떤 부분이 기존 모델들과 차이를 보이는지 간단히 살펴보고자 합니다.
- 특징
BERT는 각 단어에 대한 임베딩(word embedding)을 위치 임베딩(position embedding)과 더함으로써 각 단어의 문장 내 상대적 위치를 반영하도록 했습니다.
이와 달리 본 논문에서 제시된 DeBERTa는 content(내용)와 position(위치)을 encoding한 두 개의 벡터를 입력으로 받습니다.(Disentangled attention)
이는 attention weight가 content 뿐만 아니라 relative position에도 영향을 받는다는 점에서 착안한 것입니다.
그러나 Disentangled attentoin 방식을 사용한다고 하더라도 MLM(Masked Language Modeling)에서 단어의 absolute position이 반영되는 것은 아닙니다.
(여기에서는 absolute position이 relative position과는 또다른 중요도를 가질 수 있다고 봅니다)
따라서 masked words를 decode하여 원래 무슨 word인지 맞히는 MLM 과정에서, softmax layer 직전에 absolute word position embedding을 결합하는 방식을 취했습니다.(Enhanced Mask Decoder)
(이 내용들이 수식적으로 표현되는 내용은 본 포스팅에서 전부 생략했습니다. 따라서 추가 정보가 필요하신 분들은 논문/repo를 직접 확인하시거나, 수식을 잘 설명한 다른 블로그 포스팅을 참고하시길 바랍니다.)
마지막으로 여기서 사용된 virtual adversarial training algorithm인 Scale-invariant-Fine-Tuning(SiFT)에 대해 언급하고 있습니다.이는 모델이 adversarial example에 대해 강건함을 가질 수 있도록 input에 약간의 perturbation을 적용하는 방식입니다.
NLP 태스크에서는 원래 단어 sequence 대신 word embedding에 perturbation을 적용하는데,
본 논문에서는 이를 normalized word embedding에 적용하는 방식을 새로이 제안하고 있습니다.
보다 구체적으로는, DeBERTa를 NLP의 downstream task에 fine-tuning할 때, word embedding vector를 -> stochastic vector로 normalize하고, normalized된 embedding vector에 perturbation을 적용하는 방식을 뜻합니다.
- 개인적 감상
이제는 슬슬 논문에서 언급하는 아키텍쳐의 차이들이 잘 이해되는 때가 된 것 같습니다.
이 논문에서는 attention 연산에 입력을 기존과 다르게 취하고 있음과 absolution position embedding을 반영한 enhanced mask decoder를 적용한 것이 가장 중요한 키포인트입니다.
하지만 구체적인 수식과 알고리즘을 이해하는 데에는 시간이 좀 더 걸리는 것이 아쉽습니다.
이 내용은 추가로 공부하더라도 업로드 할 계획은 없지만 그러한 내용도 빠르게 잘 이해하면 좋겠다는 생각이 드네요.