이 포스팅은 Augmenting Language Models with Long-Term Memory 논문에 대한 공부입니다.
이 글은 Augmenting Language Models with Long-Term Memory
논문에 언급되어 있는 이전 연구에 대한 주요 흐름과 이 논문에서 해결하는 문제는 다음과 같다.
논문에서는 두 가지 main task 가 있다.
논문에서 제안한 방식을 이해하기 위해서는 3가지 모듈을 이해해야 한다.
SideNetwork는 LLM의 residual representation과 이전 side network의 이전 레이어의 값을 섞는 방식으로 진행된다.
\[\mathbf{H}^{l}_{\operatorname{Side}} = f_{\Theta^{l}_{\operatorname{Side}}} \mathbf{H}_{\operatorname{Side}}^{l-1} + (\mathbf{H}^{2l}_{\operatorname{LLM}} - \mathbf{H}^{2l-2}_{\operatorname{LLM}})\]Side Network는 LLM의 파라미터를 그대로 가져와서 메모리의 key-value를 섞기 위한 모듈이다. 논문에서는 섞는 과정을 학습하기 위해서 LLM을 weight를 그대로 가져와서 SideNet을 초기화하였다.
SideNet은 외부 메모리의 attention 연산으로부터 value를 가져온다. 기존 LLM의 in-context 정보와 메모리의 정보를 활용하기 위해서 joint attention mechanism 을 활용하였다.
\[\begin{gather} \mathbf{H}^l = \operatorname{sigmoid}(g) \cdot \mathbf{A} + (1-\operatorname{sigmoid}(g)) \cdot \mathbf{M} \\ \mathbf{A} = \operatorname{softmax} (\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}) \mathbf{V}, \mathbf{M} = \operatorname{Concat} \{ \operatorname{softmax} (\frac{\mathbf{Q}_i \tilde{\mathbf{K}}_i^{\top}}{\sqrt{d}}) \tilde{\mathbf{V}}_i \}_{i=1}^{\vert x \vert} \end{gather}\]이 때 $\sim$이 붙은 key와 value는 메모리로부터 얻은 정보이다. $g$ 는 각 head별로 정의되는 학습하는 파라미터이다.
메모리는 $M$ 개수 key-value를 CPU에 저장하는 모듈이다. GPT에서 연산할 때, 발생하는 key, value 표현들을 메모리에 queue 형태로 넣는다. 논문에서 사용한 사이즈는 [8K, 16K, 32K, 65K]이다. 해당 메모리는 context정보를 저장하기 위해서 사용되므로, GPT의 내부 파라미터는 아니다.
외부 메모리는 많은 key를 가지고 있으므로 모든 key에 대해서 연산하는 것은 비효율적이다. 논문에서는 이를 개선하기 위해서 토큰을 chunk로 쪼개 attention을 계산하였고, top-K 개의 값을 가져왔다.
학습 세팅
데이터 배치를 형성하는 방법은 길이가 비슷한 애들을 묶어서 고정된 길이로 만들고, batch index에 다른 길이의 문서를 넣는 방식을 사용하였다.
PG22에 대한 데이터 통계는 다음과 같다.
메모리 사이즈를 고정하고 PG22데이터에 대한 Perplexity는 MemTRM과 TRIME 보다 더 낮은 것을 확인할 수 있다. 책의 정보를 암기하는데 있어서 이전 정보들을 같이 주는 것은 효율적이고, LONGMEM 방식의 외부 메모리 저장은 더 높은 성능을 보인 것이다.
Natural Language Understanding 문제에 대해서도 2000개의 demonstration을 주고 주어진 문제를 풀게 만들었다. GPT에 넣는 형태는 di="Review: xi Sentiment: yi
이다.
메모리 사이즈가 반드시 크다고 좋은 것은 아니다. context의 사이즈가 작다면 적은 수의 메모리를 효율적으로 사용하는 것이 더 좋다. 아래 실험에서는 65K메모리에 대해서 상대적으로 적은 수의 메모리를 쓰는 경우 성능이 얼마나 향상되는지 확인하였다.
주어진 context에 대해서 다음 단어를 예측하는 GPT의 쓰임이 늘어남에 따라서 더 길고 많은 정보들을 처리해야 한다. 메모리 기반으로 GPT의 내부 연산을 처리하는 방법은 더욱 정교하고 효율적인 계산을 위해서 필수적인 구조적 개선이다. NeurIPS 2022에 나온 memorizing transformer (MemTRM)과 NeurIPS 2023에 나온 이 연구처럼 앞으로도 더 많인 메모리 기반 모델링이 연구될 것 같다.