이 게시물에서 LittleBird[1] 모델 구조 및 구현하는 방법에 대해서 살펴보도록 하겠습니다.
LittleBird는 카카오엔터프라이즈가 직접 개발한 Sparse Attention Transformer 모델이며, BigBird[2] 의 정확도를 유지하면서 메모리 사용량과 모델의 속도를 개선합니다. 간단하게 말씀 드리자면, LittleBird는 BigBird의 Sliding Window Attention과 LUNA[3] 의 Pack & Unpack Attention을 합치고, ALiBi[4] 기반한 새로운 양방향 위치 정보를 표현하는 방법을 사용하는 모델입니다.
LittleBird 구조는 크게 LUNA, Sliding Window Attention, BiALiBi (양방향 ALiBi) 세 개의 부분으로 나눌 수 있고, LittleBird 공식을 살펴보면서 구현하는 방법을 설명해보도록 하겠습니다.
모델에 관련한 이론적인 부분 또는 학습하는 과 정을 직접 논문을 통해 확인하시길 바랍니다.
LittleBird 레이어
C p = A t t n ( P , X ) P ′ = L a y e r N o r m ( C p + P ) C x = U S W A t t n ( X , C p ) A = L a y e r N o r m ( C x + X ) X ′ = L a y e r N o r m ( F F N ( A ) + A ) \begin{aligned}
C_p = Attn(P,X) \\
P^{\prime} = LayerNorm(C_p + P) \\
C_x = USWAttn(X, C_p) \\
A = LayerNorm(C_x + X) \\
X^{\prime} = LayerNorm(FFN(A) + A)
\end{aligned} C p = A tt n ( P , X ) P ′ = L a yer N or m ( C p + P ) C x = U S W A tt n ( X , C p ) A = L a yer N or m ( C x + X ) X ′ = L a yer N or m ( FFN ( A ) + A )
여기서 X ∈ R l x d X \in \R^{l \text{ x } d} X ∈ R l x d 는 입력 시퀀스인데 l l l 와 d d d 는 각각 시퀀스 길이와 토큰 임베딩의 차원입니다.
P ∈ R s x d P \in \R^{s \text{ x } d} P ∈ R s x d 는 Pack Attention의 projection 행렬인데, s s s 는 축소할 시퀀스 길이입니다.
Attention
A t t n ( X , C ) = σ ( Q ( X ) K ( C ) T d ) V ( C ) Attn(X,C) = \sigma(\frac{Q(X)K(C)^T}{\sqrt{d}})V(C) A tt n ( X , C ) = σ ( d Q ( X ) K ( C ) T ) V ( C )
U S W A t t n ( X , C p ) = σ ( Q ( X ) [ K ( C P ) ; K ( X ) ] T d − [ D p ; D ] T ) ⋅ [ V ( C p ) ; V ( X ) ] USWAttn(X, C_p) = \\
\begin{aligned}
\sigma(\frac{Q(X)[K(C_P);K(X)]^T}{\sqrt{d}} - [D_p;D]^T) \\ \cdot [V(C_p);V(X)]
\end{aligned} U S W A tt n ( X , C p ) = σ ( d Q ( X ) [ K ( C P ) ; K ( X ) ] T − [ D p ; D ] T ) ⋅ [ V ( C p ) ; V ( X )]
여기서 [ A ; B ] [A;B] [ A ; B ] 는 A A A 와 B B B 의 접합이며, U S W A t t n USWAttn U S W A tt n 는 Unpack & Sliding Window Attention을 의미한다.
BiALiBi
D p = ( β + γ 2 b ) J s , l D_p = (
\frac{\beta + \gamma}{2}b)J_{s,l} D p = ( 2 β + γ b ) J s , l
D i , j = { 0 , for i = j α , for i = 0 or j = 0 β ( i − j ) , for i > j γ ( j − i ) , for i < j D_{i,j} = \begin{cases}
0, &\text{for } i=j\\
\alpha, &\text{for } i=0 &\text{or } j=0\\
\beta(i-j), &\text{for } i>j\\
\gamma(j-i), &\text{for } i<j
\end{cases} D i , j = ⎩ ⎨ ⎧ 0 , α , β ( i − j ) , γ ( j − i ) , for i = j for i = 0 for i > j for i < j or j = 0
D p ∈ R s x l D_p \in \R^{s \text{ x } l} D p ∈ R s x l