"Attention Is All You Need" (Vaswani et al., 2017) を軸に、系列モデリングがどう進化してきたかを数式で追いかけるノート。RNNが死んでLSTMが延命して、結局Transformerが全部持っていった話。
1. 系列モデリングの課題
自然言語は系列データ。文 x=(x1,x2,…,xT) が与えられたとき、モデルは各トークン間の依存関係を捉えないといけない。問題は2つある。
長距離依存性
系列中の離れたトークン同士が意味的に強く結びついている場面は日常的に存在する。
"The trophy didn't fit in the suitcase because it was too big."
"it" が指すのは "trophy" であって "suitcase" じゃない。モデルがこの照応を正しく解決するには、数トークン離れた "trophy" の情報を保持して、"too big" との整合性を評価する必要がある。
"The keys to the cabinet on the second floor of the old building are missing."
主語 "keys" と動詞 "are" の間に修飾句がたくさん挟まっている。途中の "cabinet", "floor", "building"(全部単数)に惑わされず、遠くの "keys"(複数形)を正しく参照しなきゃいけない。
形式的には、入力系列の中でトークン xi と xj(∣i−j∣≫1)の間に依存関係がある場合、モデルはその情報経路を確保しなければならない。
並列化
現代の深層学習はGPU/TPUの大規模並列演算に支えられている。だけどRNN系は本質的に逐次計算で、この恩恵を受けられない。T=1000 トークンの文を処理するとき、RNNは h1→h2→⋯→h1000 と1ステップずつ計算する。GPUに数千コアがあっても、この依存連鎖は並列化できない。
この2つは独立した問題じゃない。逐次処理モデルは長い依存連鎖を辿る間に情報が減衰して、かつその処理自体が遅い。理想的なアーキテクチャは任意の2トークン間を短い経路で結びつけ、かつその計算を並列に実行できるもの。RNN → LSTM → Transformerの進化はこの2つへの回答の歴史。
2. RNNの数理
基本構造
Vanilla RNNは時刻 t における隠れ状態 ht を再帰的に計算する。
ht=tanh(Whhht−1+Wxhxt+bh)
yt=Whyht+by
xt∈Rd が入力、ht∈Rn が隠れ状態、Whh,Wxh,Why が重み行列。シンプルだけど使い物にならない。
勾配消失・爆発問題
損失 L の hk(k≪t)に対する勾配を連鎖律で展開すると、
∂hk∂L=∂ht∂Li=k+1∏t∂hi−1∂hi
各ヤコビアン ∂hi−1∂hi=diag(1−hi2)⋅Whh だから、この積は Whh の最大特異値 σmax に依存して、σmax<1 なら勾配消失(指数的に0へ)、σmax>1 なら勾配爆発(指数的に発散)。
つまりVanilla RNNは長距離依存性を学習できない。理論的に死んでるんだよね。
3. LSTMの数理
ゲート機構
LSTM (Hochreiter & Schmidhuber, 1997) はセル状態 ct と3つのゲートを導入して、勾配消失を緩和した。延命措置としてはかなり優秀。
ft=σ(Wf⋅[ht−1,xt]+bf)(忘却ゲート)
it=σ(Wi⋅[ht−1,xt]+bi)(入力ゲート)
c~t=tanh(Wc⋅[ht−1,xt]+bc)(候補セル値)
ct=ft⊙ct−1+it⊙c~t(セル状態更新)
ot=σ(Wo⋅[ht−1,xt]+bo)(出力ゲート)
ht=ot⊙tanh(ct)(隠れ状態)
σ はシグモイド関数、⊙ はアダマール積。
なぜ勾配消失が緩和されるか
セル状態の勾配経路を見ると ∂ct−1∂ct=ft。忘却ゲート ft≈1 のとき、勾配はほぼそのまま過去に伝播する。RNNのように Whh の累乗が掛かるんじゃなくて、スカラーゲート値の積になるから勾配の流れが安定する。「定数誤差カルーセル」と呼ばれる仕組み。
LSTMの限界
計算は依然として逐次的(ht は ht−1 に依存)で並列化不可。セル状態は固定次元ベクトルで情報のボトルネックになる。超長系列ではゲート積の累積で情報が減衰する。延命はしたけど、根本的な解決じゃない。
4. Bidirectional LSTM
Bi-LSTMは系列を前方向と後方向の2つのLSTMで処理する。
ht=LSTMfwd(xt,ht−1),ht=LSTMbwd(xt,ht+1)
ht=[ht;ht]∈R2n
各時刻の表現が未来の文脈も含むから、NERや品詞タグ付けで単方向LSTMを大きく上回った。だけど情報経路は依然として逐次的な隠れ状態の伝搬に依存しているし、並列化は前方向・後方向の2パスに分かれるだけで各パス内は逐次的。パラメータ数も約2倍。
5. Seq2Seqと Attentionの萌芽
Seq2Seqのボトルネック
Seq2Seq (Sutskever et al., 2014) はEncoder-Decoder構造で、Encoderの最終隠れ状態 c=hTenc を単一の固定長ベクトルとしてDecoderに渡す。入力系列全体が1つのベクトルに圧縮されるから、長い入力ほど情報が失われる。
Bahdanau Attention (2014)
Bahdanau et al. はデコーダの各ステップでエンコーダの全隠れ状態に動的に注目する仕組みを提案した。これが全ての始まり。
et,s=va⊤tanh(Wahtdec+Uahsenc)
αt,s=∑s′exp(et,s′)exp(et,s),ct=s∑αt,shsenc
コンテキストベクトル ct はデコードの時刻 t ごとに異なるベクトルになって、入力系列の必要な部分に焦点を当てることができる。「じゃあ、再帰構造いらなくない? Attentionだけでよくない?」。3年後にVaswaniがそれを証明した。
Vaswani et al. (2017) は再帰構造を完全に排除して、Attentionのみでエンコーダ・デコーダを構成した。
Encoder Block(× N=6)
Decoder Block(× N=6)
Encoder-Decoder結合の全体像
7. Scaled Dot-Product Attention
Query, Key, Value
入力行列 X∈RT×dmodel から3つの射影行列でQ, K, Vを生成する。
Q=XWQ,K=XWK,V=XWV
WQ,WK∈Rdmodel×dk、WV∈Rdmodel×dv。
Attention関数
Attention(Q,K,V)=softmax(dkQK⊤)V
やってることを分解すると、QK⊤ でトークン対ごとの類似度を計算、dk で割ってsoftmaxの飽和を防ぎ、softmaxで確率分布にして、V の重み付き和を取る。
なぜ dk で割るのか
q と k の各成分が独立に平均0・分散1なら、内積 q⊤k=∑i=1dkqiki の分散は dk。softmaxは入力の絶対値が大きいと勾配がほぼ0になるから、分散を1に正規化する。
Var(dkq⊤k)=dkdk=1
地味だけど、これがないとTransformerはまともに学習しない。
8. Multi-Head Attention
単一のAttentionでは dmodel 次元空間の1つの部分空間しか見れない。h 個のヘッドで異なる部分空間を並列に学習する。
MultiHead(Q,K,V)=Concat(head1,…,headh)WO
headi=Attention(QWiQ,KWiK,VWiV)
原論文では h=8、dk=dv=dmodel/h=64。各ヘッドは異なる射影を持つから、構文的関係、意味的関係、距離関係など異なる種類の関係性を独立に学習できる。計算コストは単一の大きなAttentionと同等(各ヘッドの次元を dk=dmodel/h に縮小するから)。
9. Positional Encoding
Self-Attentionは集合演算で、入力の順序に対して置換不変。トークンを並べ替えても出力が同じになる。これだと言語モデルとして致命的だから、位置情報を注入する必要がある。
位置 pos、次元インデックス i に対して、
PE(pos,2i)=sin(100002i/dmodelpos),PE(pos,2i+1)=cos(100002i/dmodelpos)
入力埋め込みに加算する。zpos=Embedding(xpos)+PEpos。
Positional Encoding:次元が上がるほど波長が長くなる
低次元は高周波で振動して局所的な位置差を捉え、高次元は低周波で振動して広域的な位置関係を捉える。波長は 2π から 10000⋅2π までの等比数列。
数学的に面白い性質として、任意の固定オフセット ϕ に対して t に依存しない 2×2 回転行列 Mϕ が存在して、
Mϕ[sin(ωkt)cos(ωkt)]=[sin(ωk(t+ϕ))cos(ωk(t+ϕ))]
これによりモデルは相対位置を線形変換として学習できる。加法定理からの導出なんだけど、これを思いついた人は天才だと思う。
10. Encoderブロック
各Encoderブロックは2つのサブレイヤーからなる。
Self-Attentionサブレイヤーは SA(X)=MultiHead(X,X,X)。Q, K, Vが全部同じ入力 X から生成される。
Position-wise FFNは各位置に独立に適用される2層の全結合層。
FFN(x)=max(0,xW1+b1)W2+b2
内部次元 dff=2048、dmodel=512。ReLU活性化。
各サブレイヤーの出力に残差接続と層正規化を適用する。SubLayer(X)=LayerNorm(X+f(X))。残差接続は勾配の流れを安定させ、LayerNormは学習を加速する。(ResNetの遺産がここにも。)
11. Decoderブロック
Decoderブロックは3つのサブレイヤーを持つ。ここがEncoderとの結合の核心。
Masked Self-Attention
MaskedAttn(Q,K,V)=softmax(dkQK⊤+M)V
マスク行列 Mij は i≥j なら 0、i<j なら −∞。softmax後に未来のトークンへの注意重みが0になって、自己回帰的な生成が可能になる。
Cross-Attention
ここがEncoderとDecoderをくっつける部分。
CrossAttn=MultiHead(Qdec,Kenc,Venc)
Qはデコーダから、K/Vはエンコーダの最終出力から。Decoderの各位置がEncoderの入力系列全体に対してAttentionを計算できる。Bahdanau Attentionの一般化で、加法的スコアの代わりにScaled Dot-Productを使う。
機械翻訳(英→日)だと、"上" を生成するとき、QはDecoder側から来て、K/VはEncoder側の "on" や "sat" に高い注意を向ける。最終出力は語彙サイズの確率分布。P(yt∣y<t,X)=softmax(Wvocab⋅htdec+b)。
情報経路長
| モデル | 最大経路長 |
|---|
| RNN / LSTM / Bi-LSTM | O(T) |
| Transformer | O(1) |
Bi-LSTMでは x1 の情報が xT に届くには T−1 ステップのゲート伝搬が必要で、各ステップでゲートの積が掛かって情報が減衰する。Transformerでは任意の2トークン間がAttentionの1ステップで直接接続される。αij が十分大きければ距離に関係なく情報が伝播する。
並列化
| モデル | 計算量(1層あたり) | 逐次演算数 |
|---|
| RNN / LSTM | O(T⋅n2) | O(T) |
| Self-Attention | O(T2⋅d) | O(1) |
Self-Attentionの QK⊤ は行列積で全トークンを同時に処理できる。逐次演算数 O(1)。GPUの並列性を最大限に活かせる。
文脈の捉え方
Bi-LSTMの各位置の表現は隠れ状態を通じて圧縮・蒸留されたもので、遠くのトークンの情報ほどぼやける。Transformerの各位置の表現は入力系列全体へのAttention重み分布によって構成されるから、どの位置の情報にも直接的かつ選択的にアクセスできる。多層Attentionにより、低層では局所的な構文関係、高層では大域的な意味関係を階層的に学習する。
13. 計算量の比較
T = 系列長、d = 表現次元として。
| 層の種類 | 計算量 | 逐次演算数 | 最大経路長 |
|---|
| Self-Attention | O(T2⋅d) | O(1) | O(1) |
| RNN / LSTM | O(T⋅d2) | O(T) | O(T) |
| CNN (カーネル幅 k) | O(k⋅T⋅d2) | O(logkT) | O(logkT) |
Self-Attentionは T2 の項があるから超長系列では不利。だけど典型的なNLPタスク(T<1024)では T<d であることが多くて、Attentionのほうが速い。T が非常に大きい場合にはLinear AttentionやSparse Attentionなんかの近似手法が研究されている。T2 を T に落とす戦いは今も続いてるけど、それはまた別の話。
参考文献
- Vaswani et al., "Attention Is All You Need" (2017)
- Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate" (2014)
- Hochreiter & Schmidhuber, "Long Short-Term Memory" (1997)
- Sanford, Hsu, Telgarsky, "Transformers & Massively Parallel Computation" (2023-2024)
- Rigollet, "The Mean-Field Dynamics of Transformers" (2026)