Attention Is All You Need
Attention Is All You Need
- Link : https://arxiv.org/abs/1706.03762*
๐ก Attention์ ๋ฑ์ฅ ๋ฐฐ๊ฒฝ
๊ธฐ์กด์๋ ๋ชจ๋ธ์ด ์์ฐ์ด๋ฅผ ์ดํดํ๊ธฐ ์ํด Seq2Seq ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์ต๋๋ค. Seq2Seq ๋ชจ๋ธ์ RNN์์ many-to-many์ ํด๋นํ๋ ๋ชจ๋ธ์ ๋๋ค. ๊ทธ ์ค ์ ๋ ฅ ๋ฌธ์ฅ์ ์ฝ์ด์ค๋ ๋ถ๋ถ์ โ์ธ์ฝ๋(Encoder)โ, ์ถ๋ ฅ ๋ฌธ์ฅ์ ์์ฑํ๋ ๋ถ๋ถ์ โ๋์ฝ๋(Decoder)โ๋ผ๊ณ ํฉ๋๋ค.
๋ชจ๋ธ์ด ๋ฌธ์ฅ์ ์ฝ์ด์ฌ ๋, ์ธ์ฝ๋์์๋ ๋ฌธ์ฅ์ ๋งจ ์ ๋จ์ด๋ถํฐ ์์ฐจ์ ์ผ๋ก ์ฝ์ด์ ๋ง์ง๋ง hidden state ๋ฒกํฐ์ ๋ชจ๋ ์ธ์ฝ๋ฉ๋ ์ ๋ณด๋ฅผ ์ฐ๊ฒจ ๋ฃ์ต๋๋ค. ์ด๋ก ์ธํด, ์์ ๋์จ ๋จ์ด์ ๋ํ ์ ๋ณด๋ ์ ์ฐจ ์ฌ๋ผ์ง๊ณ , ์ ๋ ฅ ๋ฌธ์ฅ์ ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ฉด Vanishing Gradient์ ๊ฐ์ Long-Term problem์ด ๋ฐ์ํฉ๋๋ค.
์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Attention์ด๋ผ๋ ๊ฐ๋ ์ด ๋ฑ์ฅํ์์ต๋๋ค.
๐ก Attention
์ฐ๋ฆฌ๋ ๋ฌธ์ฅ์ ์ดํดํ ๋, ๋ฌธ์ฅ ๋ด์ ๋ชจ๋ ๋จ์ด๋ฅผ ์ง์คํด์ ๋ณด์ง ์์ต๋๋ค. โAttention Is All You Need.โ๋ผ๋ ๋ฌธ์ฅ์ด ์์ผ๋ฉด ์ฐ๋ฆฌ๋ โAttentionโ์ด๋ผ๋ ๋จ์ด๋ฅผ โIsโ๋ผ๋ ๋จ์ด๋ณด๋ค ๋์ฑ ์ง์คํด์ ๋ณด๋ ๊ฒ์ฒ๋ผ ๋ง์ด์ฃ .
๋ค์ ๋งํด, ์์ธก ๋จ์ด(Output)์ ์ถ๋ ฅํ๊ธฐ ์ํด, ์ ๋ ฅ ๋ฌธ์ฅ ๋ด์์ Output๊ณผ ๊ด๋ จ์ด ๋์ ๋จ์ด, ์ฆ ์ค์๋๊ฐ ๋์ ๋จ์ด์๋ง ์ง์ค(Attention)ํ์๋ ์ปจ์ ์ด ๋ฐ๋ก Attention์ ๋๋ค.
Attention์ด๋, ๋์ฝ๋๊ฐ ๊ฐ ํ์ ์คํ ์์ ์์ธก ๋จ์ด๋ฅผ ์์ฑํ ๋ ์ธ์ฝ๋์ ๋ช ๋ฒ์งธ ํ์ ์คํ ์ ๋ ์ง์ค(Attention)ํด์ผ ํ๋ ์ง๋ฅผ ์ ์(Score)ํํ๋ก ๋ํ๋ด๋ ๊ฒ์ ๋๋ค.
๋์ฝ๋์ ๊ฐ ํ์ ์คํ ๋ง๋ค ์ธ์ฝ๋์ hidden state ๋ฒกํฐ์์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ์ฌ, ์ธ์ฝ๋์ ๋ช ๋ฒ์งธ hidden state ๋ฒกํฐ๊ฐ ๋ ํ์ํ์ง(์ค์ํ์ง)๋ฅผ ๊ณ ๋ คํ ์ ์๊ฒ ๋ฉ๋๋ค.
๐ก Seq2Seq with Attention
๊ธฐ์กด Seq2Seq ๋ชจ๋ธ์ Attention ๊ธฐ๋ฒ์ ์ ์ฉํ๋ฉด ์ด๋ป๊ฒ ๋ ๊น์?
๊ธฐ์กด RNN ๊ธฐ๋ฐ Seq2Seq ๊ตฌ์กฐ์ ๊ฒฝ์ฐ, โ์ด์ ํ์ ์คํ ์ hidden state ๋ฒกํฐ(Output)โ์ โํ์ฌ ํ์ ์คํ ์ ๋์ฝ๋ ์ ๋ ฅ๊ฐโ์ ํตํด โ๋์ฝ๋์ hidden stateโ๋ฅผ ๊ตฌํ์ต๋๋ค.
Attention ๊ตฌ์กฐ๊ฐ ์ถ๊ฐ๋ Seq2Seq์์๋, ํ์ฌ ํ์ ์คํ ์ ๋์ฝ๋ hidden state์ ๊ฐ๊ฐ์ ์ธ์ฝ๋ ํ์ ์คํ ์ hidden state ๋ฒกํฐ๋ค์ ๋ด์ (Dot-product)ํ์ฌ Attention score๋ฅผ ๊ตฌํฉ๋๋ค.
๊ตฌํด์ง Attention score๋ฅผ ์ธ์ฝ๋ hidden state ๋ฒกํฐ๋ค์ ๊ฐ์ค์น(weight)๋ก ์ฌ์ฉํ์ฌ, ๊ฐ์ค ํ๊ท ํ์ฌ Attention ๋ฒกํฐ, ์ฆ ํ๋์ output ๋ฒกํฐ๋ฅผ ๊ตฌํด์ค ์ ์์ต๋๋ค. ์ด๋ ๊ฒ ๊ตฌํด์ง Attention ๋ฒกํฐ์ ๋์ฝ๋์ ๋ง์ง๋ง ํ์ ์คํ ์ hidden state ๋ฒกํฐ๋ฅผ concatํ์ฌ ๋ง์ง๋ง output layer์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ค๋๋ค.
์ถ์ฒ: ์ํค๋ ์ค ๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์ ๋ฌธ
๐ก Transformer
Transformer๋ ์ค๋ก์ง Attention Mechanism์๋ง ์์กดํ simple network ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋๋ค.
Attention mechanism์ ์ ๋ ฅ ๋ฌธ์ฅ(input sequence)๊ณผ ์ถ๋ ฅ ๋ฌธ์ฅ(output sequence)์ ๊ฑฐ๋ฆฌ์ ์๊ด์์ด ์์กด์ฑ(dependency)๋ฅผ ๋ชจ๋ธ๋งํ ์ ์์ง๋ง, ์์ ์ธ๊ธํ ๋ชจ๋ธ์ฒ๋ผ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ recurrent network์ ํจ๊ป ์ฌ์ฉ๋๊ณ ์์์ต๋๋ค.
Recurrent ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ํ ์ ์๊ณ , ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ์ ๋ฐ๋ผ ๋ฌธ์ฅ(sequence)์ ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ฉด ๊ทธ์ ๋ฐ๋ฅธ ํ์ต์ ๋ฌธ์ ๊ฐ ์๊น๋๋ค.
๋ฐ๋ผ์, Transformer๋ recurrence์์ด ์ค๋ก์ง attention mechanism์๋ง ์์กดํ์ฌ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ์ ์ญ ์์กด์ฑ(Global Dependency)๋ฅผ ๋ชจ๋ธ๋ง ํ ์ ์๋ ๋ชจ๋ธ์ ๋๋ค.
๋ ์ด์ RNN์ด๋ CNN ๋ชจ๋์ ํ์ํ์ง ์๊ณ , Attention Mechanism๋ง ์์ผ๋ฉด ๋๊ธฐ์ ๋ ผ๋ฌธ์ ์ ๋ชฉ ๋ํ Attention Is All You Need๋ผ๋ ์ ์ ํ์ธํ ์ ์์ต๋๋ค.
๐ก Query, Key, Value๋ก Attention Vector ๊ตฌํ๊ธฐ
์ฐ๋ฆฌ๊ฐ ์ ์ฌ๋๋ฅผ ๊ตฌํ๊ณ ์ ํ๋ vector๋ฅผ Query, ๊ทธ Query์์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ ๋ค๋ฅธ ๋ฒกํฐ๋ค์ Key๋ผ ํฉ๋๋ค.
์ด๋ ๊ฒ ๊ตฌํ ์ ์ฌ๋ ์ ์์ Softmax๋ฅผ ์ ์ฉํ์ฌ attention score๋ฅผ ๊ตฌํ๊ณ , ๊ฐ ๋ฒกํฐ๋ค์ ๊ฐ์ธ Value์(์ฆ, key์ ๊ฐ์์ value์ ๊ฐ์๋ ๋์ผ) ๊ฐ์ค ํ๊ท ํ์ฌ Attention ๋ฒกํฐ๋ฅผ ๊ตฌํฉ๋๋ค. ์ด Attention ๋ฒกํฐ๊ฐ Query์ hidden state๊ฐ์ผ๋ก ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค.
Query, Key, Value๋ ์ ๋ ฅ ๋จ์ด์ Embedding ๊ฐ(X)์ ๊ฐ๊ฐ์ ๊ฐ์ค์น ํ๋ ฌ์ ๊ณฑํ์ฌ ๊ตฌํ ์ ์์ต๋๋ค.
-
$XW^Q=Q,\ XW^K=K,\ XW^V=V$
์ถ์ฒ: ์ํค๋ ์ค ๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์ ๋ฌธ
๋จ์ผ query q์ ๋ํด, key๋ค์ ํ๋ ฌ์ธ K์ value๋ค์ ํ๋ ฌ์ธ V๊ฐ ์์ ๋ Attention ๋ฒกํฐ๋ฅผ ๊ตฌํ๋ ๊ณผ์ ์ ์์์ผ๋ก ํํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
$A(q,K,V)= \sum_i \frac{\exp(q\cdot k_i)}{\sum_j\exp(q\cdot k_j)}v_i$
์ฌ์ค ์ด๋ ๊ฒ ๋ฒกํฐ ๋จ์๋ก ๊ฐ ๋จ์ด์ ๋ํ Attention ๋ฒกํฐ๋ฅผ ๊ตฌํ๋ ๋์ , ํ๋ ฌ ๋จ์๋ก ๊ณ์ฐํ์ฌ Attention ํ๋ ฌ์ ๊ตฌํ ์ ์์ต๋๋ค.
๊ธฐ์กด์๋ ๊ฐ ๋จ์ด์ ๋ํ ๋ฒกํฐ ๊ณ์ฐ์ ํตํด Attention score๋ฅผ ๊ตฌํ๋ค๋ฉด, ํ๋ ฌ์ ์ด์ฉํ์ฌ ๋ฌธ์ฅ ๋ด ๋ชจ๋ ๋จ์ด์ ๋ํ ํ๋ ฌ ๊ณ์ฐ์ ํตํด Attention ํ๋ ฌ์ ๊ตฌํ ์ ์์ต๋๋ค.
์ถ์ฒ: ์ํค๋ ์ค ๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์ ๋ฌธ
์ถ์ฒ: ์ํค๋ ์ค ๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์ ๋ฌธ
$A(Q,K,V) = softmax(QK^T)V$
์ด๋ฌํ ํ๋ ฌ ๊ณ์ฐ ๋ฐฉ์์ ํตํด(๋ ผ๋ฌธ์์๋ โhighly optimized matrix multiplication codeโ๋ผ๊ณ ํํํจ) ๊ธฐ์กด RNN ๊ณ์ด์ ๋ชจ๋ธ๋ณด๋ค ์๋์ ๊ณต๊ฐ ์ธก๋ฉด์์ ์ด์ ์ ๊ฐ๊ฒ ๋ฉ๋๋ค.
๐ก Scaled Dot-Product Attention
Attention score๋ฅผ ๊ณ์ฐ ์ query์ key์ dimension์ ๋ฐ๋ผ ๋ด์ ์ ๋ถ์ฐ๊ฐ์ด ์ข์ง์ฐ์ง ๋ ์ ์๊ณ ์ด์ ๋ฐ๋ผ Softmax์ ๋ถํฌ์ ์ํฅ์ ์ค ์ ์์ต๋๋ค. ์ด๋ฅผ ๋ณด์ ํด์ฃผ๊ธฐ ์ํด ํ์คํธ์ฐจ๋ก ๋๋ ์ฃผ๋ ๊ณผ์ ์ ํตํด ๋ถ์ฐ์ 1๋ก ์ ์งํ ์ ์์ต๋๋ค. Q์ K๊ฐ ํ๊ท ์ด 0์ด๊ณ ๋ถ์ฐ์ด 1์ธ vector๋ก ์ด๋ฃจ์ด์ ธ ์๋ค๋ฉด, ํต๊ณ์ ์ผ๋ก ๊ณ์ฐ ํ์๋ $Q\cdot K$์ ๋ถ์ฐ๊ฐ๊ณผ $d_k$์ ๊ฐ์ด ๋์ผํฉ๋๋ค. ๋ฐ๋ผ์, Q์ K์ Dot-Product ๊ฐ์ key์ dimension์ธ $d_k$๋ก ๋๋ (Scaled) ์ต์ข ์ ์ผ๋ก Attention ๋ฒกํฐ๊ฐ์ ๊ตฌํ ์ ์์ต๋๋ค.
๐ก Multi-Head Attention
Multi-Head Attention์ ํ์ฉํ๋ฉด ๋์์ ์ฌ๋ฌ ๋ฒ์ ์ Attention์ ์งํํ ์ ์์ต๋๋ค. ํ ํค๋(head)๋ ํ ์ข ๋ฅ์ ๊ฐ์ค์น ํ๋ ฌ($W_i^Q, W_i^K, W_i^V$)์ ํตํด Q, K, V๋ฅผ ๊ตฌํ๊ณ ์ต์ข ์ ์ผ๋ก Attention ๋ฒกํฐ๋ฅผ ๊ณ์ฐํฉ๋๋ค. ๋ง์ฝ ์ด ํค๋๊ฐ ์ฌ๋ฌ ๊ฐ ์๋ค๋ฉด? ์ฐ๋ฆฌ๋ ์ฌ๋ฌ ์ข ๋ฅ์ ๊ฐ์ค์น ํ๋ ฌ($head_0=Attention(QW_0^Q,KW_0^K,VW_0^V), \ head_1=Attention(QW_1^Q,KW_1^K,VW_1^V) โฆ$)์ ์ด์ฉํ์ฌ ์ฌ๋ฌ ๋ฒ์ ์ Attention ๋ฒกํฐ๋ฅผ ๊ตฌํ ์ ์์ต๋๋ค. ์ด๋ ๊ฒ ๊ฐ ํค๋๋ณ๋ก ์ป์ด์ง Attention ๋ฒกํฐ๋ฅผ concatํ์ฌ ์ ์ฒด ๊ฒฐ๊ณผ ๋ฒกํฐ๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ด๋ฅผ ํตํด, ๋ฌธ์ฅ์ ์ฌ๋ฌ ๊ด์ ์์ ๋ฐ๋ผ๋ณผ ์ ์๊ฒ ๋ฉ๋๋ค.
๐ก 3๊ฐ์ง Attention
Transformer์์๋ 3๊ฐ์ง์ Multi-Head Attention์ ์ฌ์ฉํฉ๋๋ค.
-
Encoder Self-Attention: Self-Attention์ ์๊ธฐ ์์ ์๊ฒ attention์ ์ํํ๋ค๋ ์๋ฏธ๋ก, ์ฝ๊ฒ ๋งํด ์ธ์ฝ๋๋ก ๋ค์ด์จ ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ชจ๋ ๋ฒกํฐ๋ค์ ๋ํด ๊ฐ๊ฐ Attention ๋ฒกํฐ๋ฅผ ๊ตฌํ๋ค๋ ๋ป์ ๋๋ค. ๋ฐ๋ผ์, ์ด๋ Q, K, V๋ ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ชจ๋ ๋ฒกํฐ๋ค์ด ํด๋น๋ฉ๋๋ค. ์ด๋ ๊ฒ self-attention์ ํตํด ์ ๋ ฅ ๋ฌธ์ฅ ๋ด์ ๋จ์ด๋ค๋ผ๋ฆฌ์ ์ ์ฌ๋๋ฅผ ๊ตฌํ ์ ์์ต๋๋ค.
-
Masked Decoder Self-Attention: RNN์ ๊ตฌ์กฐ์ ์ผ๋ก ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ ๋, ์ด์ ๊น์ง ์ ๋ ฅ๋ ๋จ์ด๋ค๋ง์ ์ฐธ๊ณ ํ ์ ์์์ต๋๋ค. ํ์ง๋ง Transformer๋ ๋ฌธ์ฅ ํ๋ ฌ์ ์ ๋ ฅ์ผ๋ก ๋ฐ๊ธฐ์ ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ ๋, ๊ทธ ๋ค์ ๋์ค๋ ๋จ์ด๋ค๊น์ง๋ ์ฐธ๊ณ ํ ์ ์์ต๋๋ค. ์ด๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด, Masking์ ํ์ฌ Attention score matrix์์ ์๊ธฐ ์์ (๋์ฝ๋๋ก ๋ค์ด์ค๋ embedding)๊ณผ ๊ทธ ์ด์ ์ ๋์จ ๋จ์ด๋ค๋ง ์ฐธ๊ณ ํ ์ ์๊ฒ ํฉ๋๋ค.
-
Encoder-Decoder Attention: ์ด๋ฒ์๋ Self-Attention์ด ์๋, ๋์ฝ๋์์์ Query์ ๋ํด ์ธ์ฝ๋์ ๋ง์ง๋ง ์ธต์์ ๋Key์ Value๋ฅผ ์ด์ฉํด Attention์ ์งํํฉ๋๋ค.
๐ก Residual Connection & Layer Normalization (Add&Norm)
block(๋ชจ๋)์ ๋ณด๋ฉด 3๊ฐ์ ๋ฒกํฐ๊ฐ ๊ฐ๊ฐ query, key, value๋ก ๋ค์ด๊ฐ์ Multi-Head Attention์ ํตํด ๊ณ์ฐ๋๊ณ , ์ด๋ ๊ฒ Attention์ ๊ฑฐ์น ์๋ฒ ๋ฉ ๋ฒกํฐ์ ์๋์ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ฅผ ๋ํด์ฃผ๋ ๊ฒ(Add)์ residual connection์ด๋ผ ๋ถ๋ฆ ๋๋ค. Residual connection์ ํตํด ๋ ์ด์ด๊ฐ ๊น์ด์ง์๋ก gradient๊ฐ ์ ์ ์ปค์ง๊ฑฐ๋ ์์์ง๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์ต๋๋ค.
๊ทธ ํ Layer Normalization์ ํตํด ๊ฐ ์ ๋ ฅ๊ฐ๋ค์ Feature๋ค์ ๋ํ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ตฌํด batch์ ์๋ ๊ฐ ์ ๋ ฅ๊ฐ๋ค์ ์ ๊ทํ ํด์ค๋๋ค.
๐ก Positional Encoding
RNN์ ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ์์ฐ์ค๋ฝ๊ฒ ๋จ์ด๋ค์ ์์์ ๋ํ ์ ๋ณด๋ฅผ ํ์ตํ์ง๋ง, Attention ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ์์ ์ ๋ณด๋ฅผ ํ์ตํ์ง ์์ต๋๋ค. ๋ฐ๋ผ์ ๊ฐ ๋จ์ด์ ํฌ์ง์ ๋ง๋ค ๊ทธ ํฌ์ง์ ์ ๋ํ๋ด๋ ์ ๋ณด๋ฅผ ์ถ๊ฐํด์ฃผ๋ ๊ฒ์ Positional Encoding์ด๋ผ ํฉ๋๋ค. Positional encoding์ ๋ค์๊ณผ ๊ฐ์ ์ฌ์ธํ ํจ์๋ฅผ ์ด์ฉํ๊ณ , ์๋ฒ ๋ฉ ๋ฒกํฐ ๋ด์ ์ฐจ์ ์ธ๋ฑ์ค์ ๋ฐ๋ผ sinํจ์์ cosํจ์๋ฅผ ์ด์ฉํ์ฌ ๊ณ์ฐํฉ๋๋ค.
๐ก Transformer ๋ชจ๋ธ์ ์ฅ์
์ด๋ฌํ Transformer ๋ชจ๋ธ์
- Parallelization: ๋ณ๋ ฌํ๋ฅผ ํตํด ๊ณ์ฐ ํจ์จ์ฑ์ ๋์ด๊ณ ํ์ต์๋๋ฅผ ๊ฐ์ ํ ์ ์์๊ณ
- Long-range Dependencies: ๋ชจ๋ ์์น์์ ๋ค๋ฅธ ์์น๊น์ง์ ๊ด๊ณ(Long-Term Dependency)๋ฅผ ์ฝ๊ฒ ํ์ตํ ์ ์์์ผ๋ฉฐ
- Interpretable: ๋ํ Attention score๋ฅผ ์๊ฐํํ์ฌ ๊ฐ ์์๋ค ๊ฐ์ ๊ด๊ณ๋ฅผ ์๊ฐํํ์ฌ ๋ณผ ์ ์์ต๋๋ค.
Leave a comment