層の概念と深層学習¶
はじめに:なぜ「深さ」が重要なのか¶
コンパイラの最適化パスを思い出してください。単一のパスでソースコードから最適化されたマシンコードを生成することは可能でしょうか?理論的には可能ですが、実際には複数のパス(字句解析→構文解析→意味解析→最適化→コード生成)に分けることで、各段階が扱いやすい抽象度で問題を解決できます。
深層学習の「深さ」も同じ原理です。複雑な問題を層ごとに段階的に解決することで、各層は比較的単純な変換を学習すればよくなります。この章では、なぜ深層学習が強力なのか、そしてTransformerがどのように深さを活用しているのかを探ります。
8.1 深層学習の本質:表現学習の階層¶
浅い vs 深いネットワーク¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict, Optional, Callable
import math
from torch.utils.data import DataLoader, TensorDataset
from matplotlib.patches import Rectangle, FancyBboxPatch
from matplotlib.patches import Circle
from matplotlib.lines import Line2D
class DepthVsWidthComparison:
"""深さと幅のトレードオフを実証"""
def __init__(self):
self.input_dim = 2
self.output_dim = 1
self.hidden_dim = 64
def create_shallow_wide_network(self, width: int) -> nn.Module:
"""浅く広いネットワーク"""
return nn.Sequential(
nn.Linear(self.input_dim, width),
nn.ReLU(),
nn.Linear(width, self.output_dim)
)
def create_deep_narrow_network(self, depth: int, width: int) -> nn.Module:
"""深く狭いネットワーク"""
layers = []
# 入力層
layers.extend([
nn.Linear(self.input_dim, width),
nn.ReLU()
])
# 隠れ層
for _ in range(depth - 2):
layers.extend([
nn.Linear(width, width),
nn.ReLU()
])
# 出力層
layers.append(nn.Linear(width, self.output_dim))
return nn.Sequential(*layers)
def compare_expressiveness(self):
"""表現力の比較"""
print("=== 深さ vs 幅:表現力の比較 ===\n")
# パラメータ数を揃えた比較
shallow_width = 512 # 2層、幅512
deep_width = 64 # 8層、幅64
deep_depth = 8
shallow_net = self.create_shallow_wide_network(shallow_width)
deep_net = self.create_deep_narrow_network(deep_depth, deep_width)
# パラメータ数を計算
shallow_params = sum(p.numel() for p in shallow_net.parameters())
deep_params = sum(p.numel() for p in deep_net.parameters())
print(f"浅いネットワーク(2層、幅{shallow_width}): {shallow_params:,} パラメータ")
print(f"深いネットワーク({deep_depth}層、幅{deep_width}): {deep_params:,} パラメータ")
# 複雑な関数の近似を可視化
self._visualize_function_approximation(shallow_net, deep_net)
def _visualize_function_approximation(self, shallow_net, deep_net):
"""関数近似能力の可視化"""
# 複雑なターゲット関数
def target_function(x, y):
return np.sin(5*x) * np.cos(5*y) + 0.5*np.sin(10*x*y)
# グリッドデータ
x = np.linspace(-1, 1, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)
Z_target = target_function(X, Y)
# 訓練データ
n_samples = 1000
x_train = torch.rand(n_samples, 2) * 2 - 1 # [-1, 1]の範囲
y_train = target_function(x_train[:, 0], x_train[:, 1])
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
# 簡易的な学習(デモ用)
for net, name in [(shallow_net, "Shallow"), (deep_net, "Deep")]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
pred = net(x_train)
loss = F.mse_loss(pred, y_train)
loss.backward()
optimizer.step()
# 予測を可視化
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# ターゲット関数
im = axes[0].contourf(X, Y, Z_target, levels=20, cmap='RdBu')
axes[0].set_title('Target Function')
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
plt.colorbar(im, ax=axes[0])
# 各ネットワークの予測
for idx, (net, name) in enumerate([(shallow_net, "Shallow Network"),
(deep_net, "Deep Network")]):
with torch.no_grad():
grid_points = torch.tensor(np.stack([X.ravel(), Y.ravel()], axis=1),
dtype=torch.float32)
Z_pred = net(grid_points).numpy().reshape(X.shape)
im = axes[idx+1].contourf(X, Y, Z_pred, levels=20, cmap='RdBu')
axes[idx+1].set_title(f'{name} Approximation')
axes[idx+1].set_xlabel('x')
axes[idx+1].set_ylabel('y')
plt.colorbar(im, ax=axes[idx+1])
plt.tight_layout()
plt.show()
階層的特徴学習¶
class HierarchicalFeatureLearning:
"""階層的な特徴学習の可視化"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def demonstrate_feature_hierarchy(self):
"""特徴の階層性を実証"""
print("=== 階層的特徴学習 ===\n")
# 画像認識での例
print("画像認識での階層:")
hierarchy = [
("Layer 1", "エッジ検出", ["縦線", "横線", "斜め線"]),
("Layer 2", "形状検出", ["角", "曲線", "円"]),
("Layer 3", "部品検出", ["目", "鼻", "口"]),
("Layer 4", "物体認識", ["顔", "車", "建物"])
]
for layer, description, examples in hierarchy:
print(f"{layer}: {description}")
print(f" 例: {', '.join(examples)}")
# 自然言語処理での例
print("\n\n自然言語処理での階層:")
nlp_hierarchy = [
("Layer 1", "文字/サブワード", ["the", "##ing", "##ed"]),
("Layer 2", "単語/句", ["running", "in the park"]),
("Layer 3", "文法構造", ["主語-動詞-目的語", "修飾関係"]),
("Layer 4", "意味/文脈", ["感情", "意図", "含意"])
]
for layer, description, examples in nlp_hierarchy:
print(f"{layer}: {description}")
print(f" 例: {', '.join(examples)}")
# 可視化
self._visualize_feature_hierarchy()
def _visualize_feature_hierarchy(self):
"""特徴階層の可視化"""
fig, ax = plt.subplots(figsize=(12, 8))
# 層の設定
layers = [
{"name": "Input", "width": 8, "features": ["Raw", "Pixels"]},
{"name": "Layer 1", "width": 6, "features": ["Edges", "Colors"]},
{"name": "Layer 2", "width": 5, "features": ["Shapes", "Textures"]},
{"name": "Layer 3", "width": 4, "features": ["Parts", "Objects"]},
{"name": "Layer 4", "width": 3, "features": ["Concepts"]},
{"name": "Output", "width": 2, "features": ["Classes"]}
]
# 各層を描画
y_positions = np.linspace(0, 1, len(layers))
for i, (layer, y) in enumerate(zip(layers, y_positions)):
# 層の矩形
rect = FancyBboxPatch(
(0.1, y - 0.05), layer["width"] * 0.1, 0.08,
boxstyle="round,pad=0.01",
facecolor=plt.cm.viridis(i / len(layers)),
edgecolor='black',
alpha=0.7
)
ax.add_patch(rect)
# 層の名前
ax.text(-0.05, y, layer["name"],
verticalalignment='center',
horizontalalignment='right',
fontsize=12, fontweight='bold')
# 特徴の例
feature_text = ", ".join(layer["features"])
ax.text(0.1 + layer["width"] * 0.05, y, feature_text,
verticalalignment='center',
horizontalalignment='center',
fontsize=10, color='white')
# 層間の接続
if i < len(layers) - 1:
# 矢印で接続
ax.arrow(0.1 + layer["width"] * 0.1 / 2, y + 0.04,
0, y_positions[i+1] - y - 0.08,
head_width=0.02, head_length=0.02,
fc='gray', ec='gray', alpha=0.5)
ax.set_xlim(-0.2, 1.0)
ax.set_ylim(-0.1, 1.1)
ax.set_title('深層学習における階層的特徴学習', fontsize=16)
ax.axis('off')
# 説明文
ax.text(0.5, -0.05, '下位層:単純な特徴 → 上位層:複雑な概念',
horizontalalignment='center', fontsize=12, style='italic')
plt.tight_layout()
plt.show()
8.2 残差接続:深いネットワークを可能にする技術¶
勾配消失問題と残差接続¶
class ResidualConnections:
"""残差接続の理解と実装"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def demonstrate_gradient_vanishing(self):
"""勾配消失問題の実証"""
print("=== 勾配消失問題 ===\n")
# 深いネットワーク(残差接続なし)
class DeepNetworkWithoutResidual(nn.Module):
def __init__(self, depth: int):
super().__init__()
self.layers = nn.ModuleList([
nn.Sequential(
nn.Linear(10, 10),
nn.ReLU()
) for _ in range(depth)
])
def forward(self, x):
activations = [x]
for layer in self.layers:
x = layer(x)
activations.append(x)
return x, activations
# 勾配の流れを測定
depths = [5, 10, 20, 50]
gradient_norms = {d: [] for d in depths}
for depth in depths:
net = DeepNetworkWithoutResidual(depth)
x = torch.randn(32, 10, requires_grad=True)
y, activations = net(x)
loss = y.sum()
loss.backward()
# 各層の勾配ノルムを記録
for i, layer in enumerate(net.layers):
grad_norm = layer[0].weight.grad.norm().item()
gradient_norms[depth].append(grad_norm)
# 可視化
self._plot_gradient_flow(gradient_norms)
def _plot_gradient_flow(self, gradient_norms):
"""勾配の流れを可視化"""
plt.figure(figsize=(10, 6))
for depth, norms in gradient_norms.items():
plt.semilogy(range(len(norms)), norms,
marker='o', label=f'Depth={depth}')
plt.xlabel('Layer Index')
plt.ylabel('Gradient Norm (log scale)')
plt.title('勾配消失:深い層ほど勾配が小さくなる')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
def explain_residual_connection(self):
"""残差接続の仕組みを説明"""
print("\n=== 残差接続の仕組み ===\n")
print("通常の層:")
print(" y = F(x)")
print(" 問題:Fが恒等写像を学習するのは困難")
print("\n残差接続:")
print(" y = F(x) + x")
print(" 利点:F(x) = 0を学習すれば恒等写像になる")
# 図解
self._visualize_residual_connection()
def _visualize_residual_connection(self):
"""残差接続の図解"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# 通常の接続
ax1.set_title('通常の接続', fontsize=14)
ax1.text(0.5, 0.8, 'x', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightblue'), fontsize=12)
ax1.text(0.5, 0.5, 'F(x)', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightgreen'), fontsize=12)
ax1.text(0.5, 0.2, 'y = F(x)', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightcoral'), fontsize=12)
# 矢印
ax1.arrow(0.5, 0.75, 0, -0.15, head_width=0.03, head_length=0.02, fc='black')
ax1.arrow(0.5, 0.45, 0, -0.15, head_width=0.03, head_length=0.02, fc='black')
ax1.set_xlim(0, 1)
ax1.set_ylim(0, 1)
ax1.axis('off')
# 残差接続
ax2.set_title('残差接続', fontsize=14)
ax2.text(0.5, 0.8, 'x', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightblue'), fontsize=12)
ax2.text(0.5, 0.5, 'F(x)', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightgreen'), fontsize=12)
ax2.text(0.5, 0.2, 'y = F(x) + x', ha='center', va='center',
bbox=dict(boxstyle='round', facecolor='lightcoral'), fontsize=12)
# 矢印
ax2.arrow(0.5, 0.75, 0, -0.15, head_width=0.03, head_length=0.02, fc='black')
ax2.arrow(0.5, 0.45, 0, -0.15, head_width=0.03, head_length=0.02, fc='black')
# スキップ接続
ax2.arrow(0.35, 0.8, -0.1, -0.5, head_width=0.03, head_length=0.02,
fc='red', ec='red', linestyle='--', linewidth=2)
ax2.text(0.2, 0.55, 'Skip\nConnection', ha='center', va='center',
color='red', fontsize=10)
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.axis('off')
plt.tight_layout()
plt.show()
def implement_residual_block(self):
"""残差ブロックの実装"""
class ResidualBlock(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
# 初期化:最初は恒等写像に近い状態
nn.init.xavier_uniform_(self.linear1.weight, gain=0.1)
nn.init.xavier_uniform_(self.linear2.weight, gain=0.1)
nn.init.zeros_(self.linear1.bias)
nn.init.zeros_(self.linear2.bias)
def forward(self, x):
# 残差接続
residual = x
# 変換
out = self.linear1(x)
out = self.activation(out)
out = self.dropout(out)
out = self.linear2(out)
out = self.dropout(out)
# 残差を加算
out = out + residual
return out
# 動作確認
print("\n=== 残差ブロックの実装 ===")
block = ResidualBlock(d_model=256, d_ff=1024)
x = torch.randn(32, 10, 256)
with torch.no_grad():
y = block(x)
# 初期状態では入力に近い出力
difference = (y - x).abs().mean().item()
print(f"入力と出力の差(初期状態): {difference:.6f}")
print("→ 初期状態では恒等写像に近い")
8.3 層正規化:学習の安定化¶
なぜ正規化が必要か¶
class LayerNormalization:
"""層正規化の理解と実装"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def demonstrate_internal_covariate_shift(self):
"""内部共変量シフトの実証"""
print("=== 内部共変量シフト ===\n")
print("深いネットワークでは、各層の入力分布が学習中に変化")
print("→ 下位層の更新が上位層の入力分布を変える")
print("→ 学習が不安定になる\n")
# シンプルな実験
class UnstableNetwork(nn.Module):
def __init__(self, use_norm: bool = False):
super().__init__()
self.use_norm = use_norm
self.layers = nn.ModuleList()
for i in range(10):
self.layers.append(nn.Linear(64, 64))
if use_norm:
self.layers.append(nn.LayerNorm(64))
self.layers.append(nn.ReLU())
def forward(self, x):
activations_stats = []
for layer in self.layers:
x = layer(x)
# 活性化の統計を記録
if isinstance(layer, nn.ReLU):
mean = x.mean().item()
std = x.std().item()
activations_stats.append((mean, std))
return x, activations_stats
# 比較
net_without_norm = UnstableNetwork(use_norm=False)
net_with_norm = UnstableNetwork(use_norm=True)
# ランダム入力
x = torch.randn(32, 64)
# 各ネットワークを通す
_, stats_without = net_without_norm(x)
_, stats_with = net_with_norm(x)
# 可視化
self._plot_activation_statistics(stats_without, stats_with)
def _plot_activation_statistics(self, stats_without, stats_with):
"""活性化統計の可視化"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
layers = range(len(stats_without))
# 平均
means_without = [s[0] for s in stats_without]
means_with = [s[0] for s in stats_with]
ax1.plot(layers, means_without, 'r-o', label='Without LayerNorm')
ax1.plot(layers, means_with, 'b-o', label='With LayerNorm')
ax1.set_xlabel('Layer')
ax1.set_ylabel('Mean Activation')
ax1.set_title('活性化の平均')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 標準偏差
stds_without = [s[1] for s in stats_without]
stds_with = [s[1] for s in stats_with]
ax2.plot(layers, stds_without, 'r-o', label='Without LayerNorm')
ax2.plot(layers, stds_with, 'b-o', label='With LayerNorm')
ax2.set_xlabel('Layer')
ax2.set_ylabel('Std Activation')
ax2.set_title('活性化の標準偏差')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
def explain_layer_norm_vs_batch_norm(self):
"""LayerNorm vs BatchNormの説明"""
print("\n=== LayerNorm vs BatchNorm ===\n")
# 違いを可視化
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# データの形状を示す
batch_size, seq_len, d_model = 4, 6, 8
# BatchNorm
ax1.set_title('Batch Normalization', fontsize=14)
for i in range(batch_size):
for j in range(seq_len):
rect = Rectangle((j, i), 1, 1,
facecolor=plt.cm.Blues((i+j)%2 * 0.3 + 0.3),
edgecolor='black', linewidth=0.5)
ax1.add_patch(rect)
# 正規化の方向を示す
ax1.arrow(seq_len/2, -0.5, 0, batch_size + 0.5,
head_width=0.3, head_length=0.2,
fc='red', ec='red', linewidth=2)
ax1.text(seq_len/2 + 0.5, batch_size/2, 'Normalize\nacross batch',
ha='left', va='center', color='red', fontsize=12)
ax1.set_xlim(0, seq_len)
ax1.set_ylim(-1, batch_size)
ax1.set_xlabel('Sequence Position')
ax1.set_ylabel('Batch')
ax1.invert_yaxis()
# LayerNorm
ax2.set_title('Layer Normalization', fontsize=14)
for i in range(batch_size):
for j in range(seq_len):
rect = Rectangle((j, i), 1, 1,
facecolor=plt.cm.Greens((i+j)%2 * 0.3 + 0.3),
edgecolor='black', linewidth=0.5)
ax2.add_patch(rect)
# 正規化の方向を示す
for i in range(batch_size):
ax2.arrow(-0.5, i + 0.5, seq_len + 0.5, 0,
head_width=0.2, head_length=0.2,
fc='blue', ec='blue', linewidth=1, alpha=0.7)
ax2.text(seq_len + 0.5, batch_size/2, 'Normalize\nacross features',
ha='left', va='center', color='blue', fontsize=12)
ax2.set_xlim(-1, seq_len + 1)
ax2.set_ylim(-1, batch_size)
ax2.set_xlabel('Sequence Position')
ax2.set_ylabel('Batch')
ax2.invert_yaxis()
plt.tight_layout()
plt.show()
print("BatchNorm: バッチ次元で正規化(同じ位置の異なるサンプル)")
print(" - 利点:バッチ統計を使った強力な正規化")
print(" - 欠点:推論時にバッチ統計が必要、可変長系列で問題")
print("\nLayerNorm: 特徴次元で正規化(同じサンプルの全特徴)")
print(" - 利点:バッチサイズに依存しない、系列長に柔軟")
print(" - 欠点:バッチ情報を活用できない")
print(" - Transformerに最適!")
def implement_layer_norm(self):
"""LayerNormの実装"""
class CustomLayerNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.d_model = d_model
self.eps = eps
# 学習可能なパラメータ
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
def forward(self, x):
# x: [batch_size, seq_len, d_model]
# 特徴次元で平均と分散を計算
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# 正規化
x_normalized = (x - mean) / torch.sqrt(var + self.eps)
# スケールとシフト
out = self.gamma * x_normalized + self.beta
return out
# 動作確認
print("\n=== LayerNormの実装 ===")
layer_norm = CustomLayerNorm(d_model=256)
x = torch.randn(32, 10, 256) * 5 + 2 # 平均2、標準偏差5
with torch.no_grad():
y = layer_norm(x)
print(f"入力: mean={x.mean():.2f}, std={x.std():.2f}")
print(f"出力: mean={y.mean():.2f}, std={y.std():.2f}")
# 各サンプルの統計
sample_mean = y[0].mean(dim=-1)
sample_std = y[0].std(dim=-1)
print(f"\nサンプル0の各位置での統計:")
print(f" 平均: {sample_mean[:5].tolist()}")
print(f" 標準偏差: {sample_std[:5].tolist()}")
8.4 Transformerブロック:すべての要素の統合¶
標準的なTransformerブロック¶
class TransformerBlock:
"""Transformerブロックの完全な実装"""
def __init__(self,
d_model: int = 512,
n_heads: int = 8,
d_ff: int = 2048,
dropout: float = 0.1):
self.d_model = d_model
self.n_heads = n_heads
self.d_ff = d_ff
self.dropout = dropout
def create_transformer_block(self) -> nn.Module:
"""標準的なTransformerブロックを作成"""
class StandardTransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout):
super().__init__()
# Multi-Head Attention
self.self_attention = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
# Feed-Forward Network
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
# Layer Normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 1. Self-Attention with Residual
attn_output, _ = self.self_attention(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_output))
# 2. Feed-Forward with Residual
ff_output = self.feed_forward(x)
x = self.norm2(x + ff_output)
return x
return StandardTransformerBlock(
self.d_model, self.n_heads, self.d_ff, self.dropout
)
def visualize_transformer_block(self):
"""Transformerブロックの構造を可視化"""
fig, ax = plt.subplots(figsize=(10, 12))
# コンポーネントの位置
components = [
{"name": "Input", "y": 0, "color": "lightblue"},
{"name": "Multi-Head\nAttention", "y": 0.15, "color": "lightgreen"},
{"name": "Add & Norm", "y": 0.25, "color": "lightyellow"},
{"name": "Feed Forward", "y": 0.40, "color": "lightcoral"},
{"name": "Add & Norm", "y": 0.50, "color": "lightyellow"},
{"name": "Output", "y": 0.65, "color": "lightblue"}
]
# 各コンポーネントを描画
box_width = 0.3
box_height = 0.08
for comp in components:
# ボックス
rect = FancyBboxPatch(
(0.5 - box_width/2, comp["y"] - box_height/2),
box_width, box_height,
boxstyle="round,pad=0.02",
facecolor=comp["color"],
edgecolor='black',
linewidth=2
)
ax.add_patch(rect)
# テキスト
ax.text(0.5, comp["y"], comp["name"],
ha='center', va='center',
fontsize=12, fontweight='bold')
# 接続線
connections = [
(0, 1, "straight"), # Input → Attention
(1, 2, "straight"), # Attention → Add&Norm
(2, 3, "straight"), # Add&Norm → FF
(3, 4, "straight"), # FF → Add&Norm
(4, 5, "straight"), # Add&Norm → Output
(0, 2, "residual"), # Input → Add&Norm (residual)
(2, 4, "residual") # Add&Norm → Add&Norm (residual)
]
for start_idx, end_idx, conn_type in connections:
start_y = components[start_idx]["y"] + box_height/2
end_y = components[end_idx]["y"] - box_height/2
if conn_type == "straight":
ax.arrow(0.5, start_y, 0, end_y - start_y - 0.01,
head_width=0.02, head_length=0.01,
fc='black', ec='black')
else: # residual
# 曲線で残差接続を表現
x_offset = 0.15 if start_idx == 0 else -0.15
ax.annotate('', xy=(0.5 + x_offset, end_y),
xytext=(0.5 + x_offset, start_y),
arrowprops=dict(arrowstyle='->',
connectionstyle="arc3,rad=0",
color='red', lw=2))
# 残差接続のラベル
ax.text(0.7, 0.125, 'Residual', color='red', fontsize=10, rotation=90)
ax.text(0.3, 0.325, 'Residual', color='red', fontsize=10, rotation=90)
ax.set_xlim(0, 1)
ax.set_ylim(-0.1, 0.75)
ax.set_title('Transformer Block Architecture', fontsize=16)
ax.axis('off')
plt.tight_layout()
plt.show()
def explain_component_roles(self):
"""各コンポーネントの役割を説明"""
print("=== Transformerブロックの各要素 ===\n")
components = {
"Multi-Head Attention": {
"役割": "文脈情報の統合",
"詳細": "各位置が他のすべての位置の情報を参照",
"なぜ必要": "長距離依存関係の捕捉"
},
"Feed-Forward Network": {
"役割": "位置ごとの非線形変換",
"詳細": "各位置で独立に適用される2層MLP",
"なぜ必要": "表現力の向上"
},
"Residual Connection": {
"役割": "勾配の流れを改善",
"詳細": "入力を出力に直接加算",
"なぜ必要": "深いネットワークの学習を安定化"
},
"Layer Normalization": {
"役割": "活性化の正規化",
"詳細": "各位置で特徴を正規化",
"なぜ必要": "学習の安定化と高速化"
}
}
for name, info in components.items():
print(f"{name}:")
print(f" 役割: {info['役割']}")
print(f" 詳細: {info['詳細']}")
print(f" なぜ必要: {info['なぜ必要']}")
print()
Pre-LN vs Post-LN¶
class NormalizationVariants:
"""正規化の配置バリエーション"""
def compare_pre_ln_post_ln(self):
"""Pre-LN vs Post-LNの比較"""
print("=== Pre-LN vs Post-LN ===\n")
class PostLNBlock(nn.Module):
"""オリジナルのPost-LN構成"""
def __init__(self, d_model):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, 8, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# Attention → Add → Norm
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# FF → Add → Norm
ff_out = self.feed_forward(x)
x = self.norm2(x + ff_out)
return x
class PreLNBlock(nn.Module):
"""改良されたPre-LN構成"""
def __init__(self, d_model):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attention = nn.MultiheadAttention(d_model, 8, batch_first=True)
self.norm2 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model)
)
def forward(self, x):
# Norm → Attention → Add
attn_out, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))
x = x + attn_out
# Norm → FF → Add
ff_out = self.feed_forward(self.norm2(x))
x = x + ff_out
return x
# 比較表
comparison = {
"Post-LN": {
"利点": ["オリジナル論文の構成", "理論的に解析しやすい"],
"欠点": ["深いモデルで不安定", "Warmupが必須"],
"式": "LN(x + Sublayer(x))"
},
"Pre-LN": {
"利点": ["学習が安定", "Warmup不要", "より深いモデルが可能"],
"欠点": ["最終層に追加のLNが必要"],
"式": "x + Sublayer(LN(x))"
}
}
for variant, props in comparison.items():
print(f"{variant}:")
print(f" 式: {props['式']}")
print(f" 利点: {', '.join(props['利点'])}")
print(f" 欠点: {', '.join(props['欠点'])}")
print()
# 学習の安定性を可視化
self._visualize_training_stability()
def _visualize_training_stability(self):
"""学習の安定性を可視化"""
# 仮想的な学習曲線
epochs = np.arange(100)
# Post-LN: 不安定な初期
post_ln_loss = np.exp(-epochs / 20) * (1 + 0.3 * np.sin(epochs / 5)) + 0.1
post_ln_loss[:10] = post_ln_loss[:10] * (1 + np.random.randn(10) * 0.2)
# Pre-LN: 安定
pre_ln_loss = np.exp(-epochs / 20) * (1 + 0.1 * np.sin(epochs / 5)) + 0.1
plt.figure(figsize=(10, 6))
plt.plot(epochs, post_ln_loss, 'r-', label='Post-LN', linewidth=2)
plt.plot(epochs, pre_ln_loss, 'b-', label='Pre-LN', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('学習の安定性: Pre-LN vs Post-LN')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0, 2)
# Warmup期間を示す
plt.axvspan(0, 10, alpha=0.2, color='gray')
plt.text(5, 1.8, 'Warmup期間', ha='center', fontsize=12)
plt.tight_layout()
plt.show()
8.5 深さのスケーリング法則¶
モデルサイズと性能の関係¶
class ScalingLaws:
"""スケーリング法則の理解"""
def demonstrate_scaling_laws(self):
"""スケーリング法則の実証"""
print("=== スケーリング法則 ===\n")
print("Kaplan et al. (2020)の発見:")
print("Loss ∝ N^(-α)")
print(" N: パラメータ数")
print(" α ≈ 0.076")
# モデルサイズと性能の関係
model_sizes = np.logspace(6, 11, 50) # 1M to 100B parameters
# 理論的な損失曲線
alpha = 0.076
loss = 10 * model_sizes ** (-alpha)
# Chinchilla最適化
chinchilla_optimal_data = model_sizes * 20 # 20トークン/パラメータ
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
# 損失 vs モデルサイズ
ax1.loglog(model_sizes, loss, 'b-', linewidth=2)
# 実際のモデルをプロット
real_models = {
'GPT-2': (1.5e9, 3.5),
'GPT-3': (175e9, 2.5),
'PaLM': (540e9, 2.2),
'GPT-4': (1e12, 1.8) # 推定
}
for name, (size, loss_val) in real_models.items():
ax1.scatter(size, loss_val, s=100, label=name)
ax1.annotate(name, (size, loss_val), xytext=(5, 5),
textcoords='offset points')
ax1.set_xlabel('Model Parameters')
ax1.set_ylabel('Loss')
ax1.set_title('スケーリング法則: Loss vs Model Size')
ax1.grid(True, alpha=0.3)
ax1.legend()
# データ量の最適配分
ax2.loglog(model_sizes, chinchilla_optimal_data, 'g-', linewidth=2)
ax2.fill_between(model_sizes, chinchilla_optimal_data * 0.5,
chinchilla_optimal_data * 2, alpha=0.3, color='green')
ax2.set_xlabel('Model Parameters')
ax2.set_ylabel('Training Tokens')
ax2.set_title('Chinchilla最適化: データ量の配分')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n重要な洞察:")
print("1. モデルサイズを10倍にすると、損失は約2倍改善")
print("2. データ量も比例して増やす必要がある(Chinchilla)")
print("3. 計算量は model_size × data_size に比例")
まとめ:深層学習の力を引き出す¶
この章で学んだ深層学習の重要な要素:
- 深さの価値:
- 階層的な特徴学習
- 指数的な表現力
-
各層が扱いやすい変換を学習
-
安定化技術:
- 残差接続:勾配の高速道路
- 層正規化:内部共変量シフトの抑制
-
適切な初期化:学習の出発点
-
Transformerブロック:
- すべての要素の調和
- Pre-LN構成による安定性
-
スケーラビリティ
-
スケーリング:
- より大きなモデル = より良い性能
- ただし、データと計算も比例して必要
これらの技術により、Transformerは100層以上の深さでも安定して学習でき、驚異的な性能を発揮します。次章では、いよいよこれらの要素を組み合わせて、完全なTransformerアーキテクチャを詳しく見ていきます。
演習問題¶
-
実装課題:残差接続ありとなしで20層のネットワークを訓練し、勾配の流れを比較してください。
-
分析課題:Pre-LNとPost-LNの構成で、層を重ねたときの活性化の分散がどう変化するか調べてください。
-
設計課題:新しい正規化手法(例:RMSNorm)を実装し、LayerNormと比較してください。
-
理論課題:なぜ残差接続が恒等写像の学習を容易にするか、数学的に説明してください。
次章「第3部:Transformerアーキテクチャ詳解」へ続く。