コンテンツにスキップ

動作検証

はじめに:正しさの証明

コンパイラを作った後、最も重要なのはその正しさを検証することです。テストスイートを実行し、既知の入力に対して期待される出力が得られることを確認します。エッジケースを探し、パフォーマンスを測定し、他の実装と比較します。

Transformerの実装でも同じアプローチが必要です。この章では、実装したTransformerが正しく動作することを体系的に検証する方法を学びます。

16.1 単体テストの実装

コンポーネントレベルのテスト

```python import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import unittest from typing import Tuple, Optional, List import math import matplotlib.pyplot as plt import seaborn as sns from torch.testing import assert_close

class TestMultiHeadAttention(unittest.TestCase): """Multi-Head Attentionの単体テスト"""

def setUp(self):
    """テストの初期設定"""
    self.d_model = 512
    self.n_heads = 8
    self.batch_size = 2
    self.seq_len = 10

    # テスト対象のモジュール
    self.attention = nn.MultiheadAttention(
        self.d_model, self.n_heads, batch_first=True
    )

def test_output_shape(self):
    """出力形状のテスト"""
    # 入力データ
    x = torch.randn(self.batch_size, self.seq_len, self.d_model)

    # 順伝播
    output, weights = self.attention(x, x, x)

    # 形状の確認
    self.assertEqual(output.shape, (self.batch_size, self.seq_len, self.d_model))
    self.assertEqual(weights.shape, (self.batch_size, self.seq_len, self.seq_len))

    print("✅ 出力形状テスト: PASS")

def test_attention_mask(self):
    """注意マスクのテスト"""
    x = torch.randn(self.batch_size, self.seq_len, self.d_model)

    # 因果的マスクの作成
    mask = torch.triu(torch.ones(self.seq_len, self.seq_len) * float('-inf'), diagonal=1)

    # マスク付き順伝播
    output, weights = self.attention(x, x, x, attn_mask=mask)

    # 未来の位置への注意が0であることを確認
    for i in range(self.seq_len):
        for j in range(i + 1, self.seq_len):
            self.assertAlmostEqual(
                weights[0, i, j].item(), 0.0, places=5,
                msg=f"Position {i} should not attend to future position {j}"
            )

    print("✅ 注意マスクテスト: PASS")

def test_key_value_different(self):
    """異なるKey/Valueのテスト"""
    # Query, Key, Valueが異なる場合
    q = torch.randn(self.batch_size, self.seq_len, self.d_model)
    k = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)
    v = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)

    output, weights = self.attention(q, k, v)

    # 出力形状の確認
    self.assertEqual(output.shape, (self.batch_size, self.seq_len, self.d_model))
    self.assertEqual(weights.shape, (self.batch_size, self.seq_len, self.seq_len * 2))

    print("✅ 異なるKey/Valueテスト: PASS")

def test_attention_weights_sum(self):
    """注意重みの和が1であることをテスト"""
    x = torch.randn(self.batch_size, self.seq_len, self.d_model)
    _, weights = self.attention(x, x, x)

    # 各行の和が1に近いことを確認
    row_sums = weights.sum(dim=-1)
    expected = torch.ones_like(row_sums)

    assert_close(row_sums, expected, rtol=1e-5, atol=1e-5)

    print("✅ 注意重みの和テスト: PASS")

def test_gradient_flow(self):
    """勾配が正しく流れることをテスト"""
    x = torch.randn(self.batch_size, self.seq_len, self.d_model, requires_grad=True)

    output, _ = self.attention(x, x, x)
    loss = output.mean()
    loss.backward()

    # 勾配が計算されていることを確認
    self.assertIsNotNone(x.grad)
    self.assertFalse(torch.isnan(x.grad).any())
    self.assertFalse(torch.isinf(x.grad).any())

    print("✅ 勾配フローテスト: PASS")

class TestPositionalEncoding(unittest.TestCase): """位置エンコーディングのテスト"""

def setUp(self):
    self.d_model = 512
    self.max_len = 1000

def test_sinusoidal_encoding_properties(self):
    """正弦波位置エンコーディングの性質をテスト"""
    # 位置エンコーディングの生成
    pe = self._create_sinusoidal_encoding(self.max_len, self.d_model)

    # 1. 値の範囲が[-1, 1]であること
    self.assertLessEqual(pe.max().item(), 1.0)
    self.assertGreaterEqual(pe.min().item(), -1.0)

    # 2. 偶数次元がsin、奇数次元がcosであること
    pos = 10  # テスト位置
    for i in range(0, self.d_model, 2):
        div_term = 10000 ** (i / self.d_model)
        expected_sin = math.sin(pos / div_term)
        expected_cos = math.cos(pos / div_term)

        self.assertAlmostEqual(pe[pos, i].item(), expected_sin, places=5)
        if i + 1 < self.d_model:
            self.assertAlmostEqual(pe[pos, i + 1].item(), expected_cos, places=5)

    print("✅ 正弦波エンコーディングテスト: PASS")

def test_relative_position_encoding(self):
    """相対位置の性質をテスト"""
    pe = self._create_sinusoidal_encoding(self.max_len, self.d_model)

    # 固定された相対距離での内積が一定であることを確認
    distance = 5
    products = []

    for pos in range(10, 20):
        dot_product = torch.dot(pe[pos], pe[pos + distance])
        products.append(dot_product.item())

    # 標準偏差が小さいことを確認
    std = np.std(products)
    self.assertLess(std, 0.01, "相対位置の内積は一定であるべき")

    print("✅ 相対位置エンコーディングテスト: PASS")

def _create_sinusoidal_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
    """正弦波位置エンコーディングを作成"""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()

    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
    )

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe

class TestTransformerBlock(unittest.TestCase): """Transformerブロックのテスト"""

def setUp(self):
    self.d_model = 256
    self.n_heads = 8
    self.d_ff = 1024
    self.dropout = 0.1

    # Transformerブロックの作成
    self.encoder_layer = nn.TransformerEncoderLayer(
        d_model=self.d_model,
        nhead=self.n_heads,
        dim_feedforward=self.d_ff,
        dropout=self.dropout,
        batch_first=True
    )

def test_residual_connections(self):
    """残差接続のテスト"""
    batch_size = 2
    seq_len = 10

    # 入力
    x = torch.randn(batch_size, seq_len, self.d_model)

    # ドロップアウトを無効化(テストのため)
    self.encoder_layer.eval()

    # 小さな重みで初期化(残差接続の効果を見やすくする)
    for param in self.encoder_layer.parameters():
        param.data.mul_(0.01)

    output = self.encoder_layer(x)

    # 出力が入力に近いことを確認(残差接続の効果)
    diff = (output - x).abs().mean()
    self.assertLess(diff.item(), 0.5, "残差接続により出力は入力に近いはず")

    print("✅ 残差接続テスト: PASS")

def test_layer_norm_effect(self):
    """層正規化の効果をテスト"""
    batch_size = 2
    seq_len = 10

    # 大きな値を持つ入力
    x = torch.randn(batch_size, seq_len, self.d_model) * 100

    output = self.encoder_layer(x)

    # 出力の各位置での平均と分散を計算
    mean = output.mean(dim=-1)
    var = output.var(dim=-1)

    # 層正規化により、平均が0に近く、分散が1に近いことを確認
    self.assertLess(mean.abs().mean().item(), 0.1)
    self.assertLess((var - 1).abs().mean().item(), 0.5)

    print("✅ 層正規化テスト: PASS")

16.2 統合テスト

class IntegrationTests: """モデル全体の統合テスト"""

def __init__(self, model_class):
    self.model_class = model_class

def test_full_forward_pass(self):
    """完全な順伝播のテスト"""
    print("\n=== 統合テスト: 完全な順伝播 ===")

    # モデルの作成
    vocab_size = 1000
    d_model = 256
    model = self.model_class(vocab_size=vocab_size, d_model=d_model)
    model.eval()

    # テストケース
    test_cases = [
        {"batch_size": 1, "seq_len": 10},
        {"batch_size": 4, "seq_len": 50},
        {"batch_size": 8, "seq_len": 100},
    ]

    for case in test_cases:
        batch_size = case["batch_size"]
        seq_len = case["seq_len"]

        # 入力データ
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

        # 順伝播
        with torch.no_grad():
            output = model(input_ids)

        # 出力形状の確認
        expected_shape = (batch_size, seq_len, vocab_size)
        assert output.shape == expected_shape, \
            f"Expected shape {expected_shape}, got {output.shape}"

        # NaNやInfがないことを確認
        assert not torch.isnan(output).any(), "Output contains NaN"
        assert not torch.isinf(output).any(), "Output contains Inf"

        print(f"✅ バッチサイズ={batch_size}, シーケンス長={seq_len}: PASS")

def test_generation_consistency(self):
    """生成の一貫性テスト"""
    print("\n=== 統合テスト: 生成の一貫性 ===")

    vocab_size = 100
    model = self.model_class(vocab_size=vocab_size, d_model=128)
    model.eval()

    # シード固定
    torch.manual_seed(42)

    # 同じプロンプトから生成
    prompt = torch.tensor([[1, 2, 3]])

    # 複数回生成
    outputs = []
    for _ in range(3):
        torch.manual_seed(42)  # 同じシード
        output = model.generate(prompt, max_new_tokens=10, temperature=1.0)
        outputs.append(output)

    # すべての出力が同じであることを確認
    for i in range(1, len(outputs)):
        assert torch.equal(outputs[0], outputs[i]), \
            f"生成結果が一貫していません: {i}回目"

    print("✅ 生成の一貫性: PASS")

def test_attention_pattern_analysis(self):
    """注意パターンの分析"""
    print("\n=== 統合テスト: 注意パターン分析 ===")

    # 特別なテストケース:繰り返しパターン
    vocab_size = 50
    model = self.model_class(vocab_size=vocab_size, d_model=128, n_layers=2)
    model.eval()

    # 繰り返しのある入力
    # "A B C A B C A B C"のようなパターン
    pattern = [10, 20, 30]
    input_ids = torch.tensor([pattern * 3]).to(torch.long)

    # 注意重みを取得するためのフック
    attention_weights = []

    def hook_fn(module, input, output):
        if isinstance(output, tuple) and len(output) == 2:
            _, attn = output
            if attn is not None:
                attention_weights.append(attn.detach())

    # フックを登録
    hooks = []
    for module in model.modules():
        if isinstance(module, nn.MultiheadAttention):
            hook = module.register_forward_hook(hook_fn)
            hooks.append(hook)

    # 順伝播
    with torch.no_grad():
        _ = model(input_ids)

    # フックを削除
    for hook in hooks:
        hook.remove()

    # 注意パターンの分析
    if attention_weights:
        # 最初の層の注意重みを分析
        attn = attention_weights[0][0].mean(dim=0)  # ヘッドの平均

        # 同じトークンへの注意が高いことを確認
        for i in range(3):
            for j in range(3):
                if i != j:
                    pos1 = i * 3
                    pos2 = j * 3
                    # 同じトークン(位置は違う)への注意
                    similarity = attn[pos1, pos2].item()
                    print(f"  位置{pos1} → 位置{pos2}の注意: {similarity:.3f}")

    print("✅ 注意パターン分析: 完了")

16.3 性能ベンチマーク

class PerformanceBenchmark: """性能ベンチマークテスト"""

def __init__(self, model):
    self.model = model
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(self.device)

def benchmark_inference_speed(self):
    """推論速度のベンチマーク"""
    print("\n=== 性能ベンチマーク: 推論速度 ===")

    # テスト設定
    batch_sizes = [1, 4, 16, 32]
    seq_lengths = [10, 50, 100, 200]
    vocab_size = 1000

    results = {}

    for batch_size in batch_sizes:
        results[batch_size] = {}

        for seq_len in seq_lengths:
            # 入力データ
            input_ids = torch.randint(0, vocab_size, 
                                    (batch_size, seq_len)).to(self.device)

            # ウォームアップ
            for _ in range(10):
                with torch.no_grad():
                    _ = self.model(input_ids)

            # 時間測定
            import time
            torch.cuda.synchronize() if torch.cuda.is_available() else None

            start_time = time.time()
            n_iterations = 100

            for _ in range(n_iterations):
                with torch.no_grad():
                    _ = self.model(input_ids)

            torch.cuda.synchronize() if torch.cuda.is_available() else None
            end_time = time.time()

            # 平均時間(ミリ秒)
            avg_time = (end_time - start_time) / n_iterations * 1000
            throughput = batch_size / (avg_time / 1000)  # samples/sec

            results[batch_size][seq_len] = {
                'time_ms': avg_time,
                'throughput': throughput
            }

            print(f"  Batch={batch_size}, Seq={seq_len}: "
                  f"{avg_time:.2f}ms, {throughput:.1f} samples/sec")

    # 結果の可視化
    self._visualize_benchmark_results(results)

    return results

def benchmark_memory_usage(self):
    """メモリ使用量のベンチマーク"""
    print("\n=== 性能ベンチマーク: メモリ使用量 ===")

    if not torch.cuda.is_available():
        print("  GPUが利用できないため、メモリベンチマークをスキップ")
        return

    seq_lengths = [10, 50, 100, 200, 500]
    batch_size = 4
    vocab_size = 1000

    memory_usage = []

    for seq_len in seq_lengths:
        # メモリをクリア
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # 入力データ
        input_ids = torch.randint(0, vocab_size, 
                                (batch_size, seq_len)).to(self.device)

        # 順伝播
        with torch.no_grad():
            _ = self.model(input_ids)

        # ピークメモリ使用量
        peak_memory = torch.cuda.max_memory_allocated() / 1024**2  # MB
        memory_usage.append(peak_memory)

        print(f"  Seq={seq_len}: {peak_memory:.1f} MB")

    # メモリ使用量の成長率を分析
    self._analyze_memory_scaling(seq_lengths, memory_usage)

def _visualize_benchmark_results(self, results):
    """ベンチマーク結果を可視化"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # 推論時間
    for batch_size, seq_results in results.items():
        seq_lengths = list(seq_results.keys())
        times = [seq_results[seq]['time_ms'] for seq in seq_lengths]
        ax1.plot(seq_lengths, times, marker='o', label=f'Batch={batch_size}')

    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Inference Time (ms)')
    ax1.set_title('Inference Time vs Sequence Length')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')

    # スループット
    for batch_size, seq_results in results.items():
        seq_lengths = list(seq_results.keys())
        throughputs = [seq_results[seq]['throughput'] for seq in seq_lengths]
        ax2.plot(seq_lengths, throughputs, marker='o', label=f'Batch={batch_size}')

    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Throughput (samples/sec)')
    ax2.set_title('Throughput vs Sequence Length')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    plt.tight_layout()
    plt.show()

def _analyze_memory_scaling(self, seq_lengths, memory_usage):
    """メモリスケーリングの分析"""
    # O(n^2)フィッティング
    coeffs = np.polyfit(seq_lengths, memory_usage, 2)
    poly = np.poly1d(coeffs)

    plt.figure(figsize=(8, 6))
    plt.scatter(seq_lengths, memory_usage, label='Actual', s=100)

    # フィット曲線
    x_fit = np.linspace(min(seq_lengths), max(seq_lengths), 100)
    y_fit = poly(x_fit)
    plt.plot(x_fit, y_fit, 'r--', label=f'Quadratic Fit', alpha=0.7)

    plt.xlabel('Sequence Length')
    plt.ylabel('Memory Usage (MB)')
    plt.title('Memory Usage Scaling')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    print(f"\n  メモリ使用量は O(n^{2:.1f}) でスケール")

16.4 比較検証

class ComparativeValidation: """他の実装との比較検証"""

def __init__(self, custom_model, reference_model=None):
    self.custom_model = custom_model
    self.reference_model = reference_model

def compare_with_pytorch_transformer(self):
    """PyTorchの標準Transformerとの比較"""
    print("\n=== 比較検証: PyTorch標準実装との比較 ===")

    d_model = 256
    n_heads = 8
    n_layers = 2

    # PyTorchの標準Transformer
    pytorch_encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            batch_first=True
        ),
        num_layers=n_layers
    )

    # 同じ重みで初期化(可能な限り)
    # ここでは簡略化のため新規の重みを使用

    # テスト入力
    batch_size = 2
    seq_len = 10
    x = torch.randn(batch_size, seq_len, d_model)

    # 両方のモデルで推論
    pytorch_encoder.eval()
    self.custom_model.eval() if self.custom_model else None

    with torch.no_grad():
        pytorch_output = pytorch_encoder(x)

        if self.custom_model:
            # カスタムモデルがエンコーダーのみの場合
            custom_output = self.custom_model.encoder(x) if hasattr(self.custom_model, 'encoder') else None

            if custom_output is not None:
                # 出力の統計を比較
                print(f"  PyTorch出力 - 平均: {pytorch_output.mean():.4f}, "
                      f"標準偏差: {pytorch_output.std():.4f}")
                print(f"  カスタム出力 - 平均: {custom_output.mean():.4f}, "
                      f"標準偏差: {custom_output.std():.4f}")
        else:
            print(f"  PyTorch出力形状: {pytorch_output.shape}")
            print(f"  平均: {pytorch_output.mean():.4f}, "
                  f"標準偏差: {pytorch_output.std():.4f}")

def validate_attention_computation(self):
    """注意計算の検証"""
    print("\n=== 比較検証: 注意計算の正確性 ===")

    # 手動での注意計算
    d_model = 64
    seq_len = 5

    # ランダムなQ, K, V
    torch.manual_seed(42)
    Q = torch.randn(1, seq_len, d_model)
    K = torch.randn(1, seq_len, d_model)
    V = torch.randn(1, seq_len, d_model)

    # 手動計算
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model)
    attention_weights = F.softmax(scores, dim=-1)
    output_manual = torch.matmul(attention_weights, V)

    # nn.MultiheadAttentionでの計算(single head)
    attention = nn.MultiheadAttention(d_model, 1, batch_first=True)

    # 重みを設定(恒等変換)
    with torch.no_grad():
        attention.in_proj_weight.data = torch.eye(3 * d_model)
        attention.in_proj_bias.data = torch.zeros(3 * d_model)
        attention.out_proj.weight.data = torch.eye(d_model)
        attention.out_proj.bias.data = torch.zeros(d_model)

    # 入力を結合
    qkv = torch.cat([Q, K, V], dim=-1)
    x = qkv[:, :, :d_model]  # Qの部分のみ(簡略化)

    output_pytorch, weights_pytorch = attention(Q, K, V)

    print(f"  手動計算とPyTorchの差:")
    print(f"  注意重みの差: {(attention_weights - weights_pytorch).abs().max():.6f}")
    print(f"  出力の差: {(output_manual - output_pytorch).abs().mean():.6f}")

class ValidationSuite: """完全な検証スイート"""

def __init__(self, model_class):
    self.model_class = model_class
    self.test_results = {}

def run_all_tests(self):
    """すべてのテストを実行"""
    print("=" * 70)
    print("Transformer検証スイート")
    print("=" * 70)

    # 1. 単体テスト
    print("\n【1. 単体テスト】")
    self._run_unit_tests()

    # 2. 統合テスト
    print("\n【2. 統合テスト】")
    self._run_integration_tests()

    # 3. 性能テスト
    print("\n【3. 性能テスト】")
    self._run_performance_tests()

    # 4. 比較検証
    print("\n【4. 比較検証】")
    self._run_comparative_tests()

    # 結果サマリー
    self._print_summary()

def _run_unit_tests(self):
    """単体テストの実行"""
    # テストランナーの作成
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    # テストクラスを追加
    suite.addTests(loader.loadTestsFromTestCase(TestMultiHeadAttention))
    suite.addTests(loader.loadTestsFromTestCase(TestPositionalEncoding))
    suite.addTests(loader.loadTestsFromTestCase(TestTransformerBlock))

    # テスト実行
    runner = unittest.TextTestRunner(verbosity=0)
    result = runner.run(suite)

    self.test_results['unit_tests'] = {
        'total': result.testsRun,
        'failures': len(result.failures),
        'errors': len(result.errors)
    }

def _run_integration_tests(self):
    """統合テストの実行"""
    integration = IntegrationTests(self.model_class)

    try:
        integration.test_full_forward_pass()
        integration.test_generation_consistency()
        integration.test_attention_pattern_analysis()
        self.test_results['integration_tests'] = 'PASS'
    except Exception as e:
        self.test_results['integration_tests'] = f'FAIL: {str(e)}'

def _run_performance_tests(self):
    """性能テストの実行"""
    # 小さなモデルでテスト
    model = self.model_class(vocab_size=1000, d_model=128, n_layers=2)
    benchmark = PerformanceBenchmark(model)

    results = benchmark.benchmark_inference_speed()
    benchmark.benchmark_memory_usage()

    self.test_results['performance_tests'] = results

def _run_comparative_tests(self):
    """比較テストの実行"""
    model = self.model_class(vocab_size=1000, d_model=256, n_layers=2)
    comparative = ComparativeValidation(model)

    comparative.compare_with_pytorch_transformer()
    comparative.validate_attention_computation()

    self.test_results['comparative_tests'] = 'COMPLETED'

def _print_summary(self):
    """テスト結果のサマリー"""
    print("\n" + "=" * 70)
    print("テスト結果サマリー")
    print("=" * 70)

    # 単体テスト結果
    unit_results = self.test_results.get('unit_tests', {})
    print(f"\n単体テスト: {unit_results.get('total', 0)}個のテスト")
    print(f"  成功: {unit_results.get('total', 0) - unit_results.get('failures', 0) - unit_results.get('errors', 0)}")
    print(f"  失敗: {unit_results.get('failures', 0)}")
    print(f"  エラー: {unit_results.get('errors', 0)}")

    # 統合テスト結果
    print(f"\n統合テスト: {self.test_results.get('integration_tests', 'N/A')}")

    # 性能テスト結果
    if 'performance_tests' in self.test_results:
        print("\n性能テスト: 完了")
        # 代表的な結果を表示
        perf_results = self.test_results['performance_tests']
        if 4 in perf_results and 100 in perf_results[4]:
            time_ms = perf_results[4][100]['time_ms']
            throughput = perf_results[4][100]['throughput']
            print(f"  代表例 (Batch=4, Seq=100): {time_ms:.2f}ms, {throughput:.1f} samples/sec")

    # 比較テスト結果
    print(f"\n比較テスト: {self.test_results.get('comparative_tests', 'N/A')}")

    # 総合評価
    print("\n" + "=" * 70)
    all_passed = (
        unit_results.get('failures', 1) == 0 and
        unit_results.get('errors', 1) == 0 and
        self.test_results.get('integration_tests') == 'PASS'
    )

    if all_passed:
        print("✅ すべてのテストに合格しました!")
    else:
        print("❌ 一部のテストに失敗しました。")

実際の検証例

def run_validation_example(): """検証の実行例"""

# ダミーのTransformerモデルクラス
class DummyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers=4, n_heads=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                batch_first=True
            ),
            num_layers=n_layers
        )
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = self.output_projection(x)
        return x

    def generate(self, prompt, max_new_tokens, temperature):
        # 簡略化された生成
        current = prompt
        for _ in range(max_new_tokens):
            output = self.forward(current)
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            current = torch.cat([current, next_token], dim=1)
        return current

# 検証スイートの実行
validation = ValidationSuite(DummyTransformer)
validation.run_all_tests()

エラー分析とデバッグのヒント

def debugging_tips(): """デバッグのヒント""" print("\n" + "=" * 70) print("一般的な問題とデバッグのヒント") print("=" * 70 + "\n")

tips = {
    "勾配消失/爆発": [
        "層正規化が正しく適用されているか確認",
        "残差接続が機能しているか確認",
        "学習率が適切か確認",
        "勾配クリッピングを使用"
    ],

    "注意の偏り": [
        "スケーリング係数(1/√d_k)が正しいか確認",
        "マスクが正しく適用されているか確認",
        "初期化方法を確認"
    ],

    "生成品質": [
        "温度パラメータの調整",
        "Top-k/Top-pサンプリングの使用",
        "ビームサーチの実装",
        "繰り返しペナルティの追加"
    ],

    "メモリ不足": [
        "バッチサイズの削減",
        "シーケンス長の制限",
        "勾配累積の使用",
        "Mixed Precision Trainingの使用"
    ]
}

for problem, solutions in tips.items():
    print(f"{problem}:")
    for solution in solutions:
        print(f"  • {solution}")
    print()

if name == "main": # 検証例の実行 run_validation_example()

# デバッグのヒント
debugging_tips()