Skip to content

Music Transformer STEP2 - Model Training

把 Encodec 编码出来的 多 codebook 离散音频 token,当成“语言序列”来做自回归建模

Introduction:

1、构建训练样本:自回归格式

2、构建多 codebook 自回归 Transformer

3、构建 Dataset,用于 DataLoader

4、训练

5、保存 checkpoint


0、环境配置

!pip install -q torch torchaudio encodec torchcodec
import torch
import torchaudio
import encodec
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

1、构建训练样本(自回归格式)

选一个 10s 的音频,先看一下对于它 构建训练样本 是否能得到正确的样本

下面主要做了:转置这个音频的矩阵 + 赋值 x 和 y(y 始终比 x 往后偏移一个时间步)来构建训练样本

最终实现 “根据历史 token 序列,逐帧生成下一个帧的 token”。

# 测试

import torch
import torch.nn as nn

# 选一首歌、一个 chunk 作为最小实验
data = torch.load("maestro_tokens_3kbps/music_0000.pt")

# 取第一个 10s chunk
codes = data["tokens"][0]      # Tensor[K, T]【4 本 codebook(4 行),750 个时间步(750 个纵列)】
print("Original codes shape:", codes.shape)

# 把 音频编码 矩阵 转置
frames = codes.transpose(0, 1).contiguous()
print("Frames shape:", frames.shape)

# 自回归 shift
x = frames[:-1]   # [T-1, K]
y = frames[1:]    # [T-1, K]

print("x shape:", x.shape)
print("y shape:", y.shape)

# 看一下前 3 帧
print("x[:3]:\n", x[:3])
print("y[:3]:\n", y[:3])

输出:

Original codes shape: torch.Size([4, 750])
Frames shape: torch.Size([750, 4])
x shape: torch.Size([749, 4])
y shape: torch.Size([749, 4])
x[:3]:
 tensor([[ 133,  363,  989, 1018],
        [ 228,   87,  863,  793],
        [ 133,  859, 1007,   74]])
y[:3]:
 tensor([[ 228,   87,  863,  793],
        [ 133,  859, 1007,   74],
        [ 999,  646,  675,  866]])
以上输出验证了 我们构建自回归训练样本 的正确性。

可以看到 x(输入) 的第2、3行 分别相同于 y(预测) 的第1、2行

即 输入x1 预测y1;输入x2 预测y2 。

2、构建多 codebook 自回归 Transformer

一、定义部分:

1、定义 4 本 codebook 分别的 embedding - 为每个 codebook 单独建立一个 embedding 表 - 不同 codebook 表示不同频带/量化层,它们不是同一个“词表语义空间”,所以不能共享 embedding。

2、定义位置编码 embedding:告诉 transformer,输出的每个音频位置是不一样的,是有顺序的

3、定义 Transformer 块 - 定义标准的 nn.TransformerEncoderLayer - 堆叠 6 层(默认)

4、定义输出头:其实就是一个 分类头


二、forward 部分:

1、输入维度拆解:读到 B = 8;T = 750;K = 4

2、每个 codebook 单独 embedding

3、堆叠并展平:拉平成一个长序列

4、加位置编码:加一个维度 None,然后广播到各个 batch

5、构造自回归 mask:生成 上三角 = -inf,第 i 个音频 token 不能看未来的 音频 token。

6、Transformer 前向:每个 音频 token 都融合了过去所有音频 token 的信息

7、输出分类:进行打分 logits

8、还原形状:这样就可以对每个时间步、每个 codebook 做 交叉熵 loss

class MusicTransformer(nn.Module):
    def __init__(
        self, vocab_size=1024, 
        num_codebooks=4, 
        embed_dim=256, 
        max_seq_len=4096,
        num_layers=6, 
        num_heads=8, 
        dropout=0.2
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.embed_dim = embed_dim

        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, embed_dim)
            for _ in range(num_codebooks)
        ])

        self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, 
            dim_feedforward=embed_dim * 4, dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)

        self.output_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):

        B, T, K = x.shape

        h_list = []
        for k in range(K):
            h_list.append(self.embeddings[k](x[:, :, k]))

        h = torch.stack(h_list, dim=2).reshape(B, T * K, -1)

        total_len = T * K
        pos_ids = torch.arange(total_len, device=x.device)
        h = h + self.pos_embedding(pos_ids)[None, :, :]

        mask = torch.triu(
            torch.full((total_len, total_len), float('-inf'), device=x.device),
            diagonal=1
        )

        h = self.transformer(h, mask)

        logits = self.output_head(h)

        logits = logits.reshape(B, T, K, -1)

        return logits

output_heads 在这里,这是一个分类头

得到的 logits_k 矩阵 的真实形状是:[B, T, vocab_size] [批次大小, 时间步, 词表里词数量] [一批次几段音频, 750时间步, 词表里共1024段音频]

logits_k[b, t, i] 表示:打分(score)

  • 在第 b 个样本、第 t 个时间步、第 k 个 codebook 上,

  • 在给定所有可见音频上下文的条件下,模型“倾向于生成 token i 的概率评分”


真实 logits 矩阵:750行(每行表示 1 个时间步),1024列(每列表示 codebook 中一个音频词 token)

  • 第1行表示:在第1个时间步,“生成codebook中1024个音频分别是什么概率?”的评分值。
  • 第2行 以此类推...
  • 第750行 以此类推...

注意:这 750 个时间步 实际是 10s。也就是:1s 是 75 个时间步。

# 测试

model = MusicTransformer(
    vocab_size=1024,
    num_codebooks=x.shape[1],   # x的列数:4列(4本codebook)(749行,因为10s有750个音再-1)
    embed_dim=256,
    num_layers=6,
    num_heads=8
)

x_small = x[:512]   # 只取512长度
dummy_x = x_small.unsqueeze(0)
logits = model(dummy_x)

print(f"Number of codebooks: {len(logits)}") # logits打分矩阵的行数:4行,分别对应4本codebook里的所有音
print("Logits shape for codebook 0:", logits[0].shape)
输出:
Number of codebooks: 1
Logits shape for codebook 0: torch.Size([512, 4, 1024])


import os
import torch
import random

TOKEN_DIR = "maestro_tokens_3kbps"

song_files = sorted([
    os.path.join(TOKEN_DIR, f)
    for f in os.listdir(TOKEN_DIR)
    if f.endswith(".pt")
])

print(f"Found {len(song_files)} songs.")
输出:
Found 65 songs.


# 看看我们存的.pt文件是什么 是dict还是tensor

data = torch.load("maestro_tokens_3kbps/music_0000.pt")
print(type(data))
print(data.keys())
输出:
<class 'dict'>
dict_keys(['tokens', 'num_chunks', 'bandwidth', 'sample_rate', 'chunk_size', 'music_id'])


data = torch.load("maestro_tokens_3kbps/music_0000.pt")

print(type(data["tokens"]))
print(len(data["tokens"]))
print(type(data["tokens"][0]))
print(data["tokens"][0].shape)

输出:

<class 'list'>
96
<class 'torch.Tensor'>
torch.Size([4, 750])

3、构建 Dataset,用于 DataLoader

这个模块整体需要做的是:

  • 拼接所有 chunk
  • 得到整首歌 token 序列
  • 切成长度 block_size 的小片段
  • 构造 (x, y) 训练样本

1、读取 token 文件:保证顺序一致

2、处理每一首 music - 取出 chunk 列表:每个是一个 10 秒 chunk - 转置

3、滑动窗口切样本 - 循环遍历音频的每一个采样点:构造 x、构造 y、加入样本


补充:为什么用 stride?

重叠率:87.5% - 好处:数据量暴增,每个时间点被多次训练 - 坏处:数据冗余大,训练更慢

import os
import torch
from torch.utils.data import Dataset

class MusicTokenDataset(Dataset):
    def __init__(
        self,
        token_dir,
        block_size=2048,
        stride=256
    ):
        self.block_size = block_size
        self.stride = stride
        self.samples = []

        files = sorted([
            os.path.join(token_dir, f)
            for f in os.listdir(token_dir)
            if f.endswith(".pt")
        ])

        for f in files:
            data = torch.load(f)

            chunks = data["tokens"]         # list of [4, 750]
            tokens = torch.cat(chunks, dim=1)  # [4, T]
            tokens = tokens.permute(1, 0)      # [T, 4]

            T = tokens.size(0)

            # 加 stride
            for i in range(0, T - block_size - 1, stride):
                x = tokens[i:i+block_size]
                y = tokens[i+1:i+block_size+1]
                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

4、训练

整体流程概述:

1、初始化构建 Dataset、DataLoader、模型、损失函数、优化器、学习率调度器

  • 标准分类损失:用于预测:token id ∈ [0,1023]
  • 优化器:AdamW = Adam + 权重衰减。
  • 前 5% 步数做 warmup:防止训练初期不稳定;避免梯度爆炸

补充:为什么不用 Adam?

  • Transformer 训练通常配 AdamW
  • weight_decay=1e-2 是常见正则强度。

补充:为什么 warmup?

  • 学习率曲线:线性升高 → 线性下降
  • 防止训练初期不稳定,避免梯度爆炸,这是 GPT 标准训练策略。

2、训练循环:开始 10 个 epoch。 - 进入训练模式:开启 dropout - 遍历 batch - 前向传播 - 加权多 codebook 损失 - 反向传播:计算梯度;梯度裁剪:防止梯度爆炸;更新参数、学习率 - 记录 loss:用于算 epoch 平均 - 打印日志:每 50 步打印一次;单独计算第一个 codebook loss。 - 每个 epoch 结束:输出平均损失


补充:

加权多 codebook 损失

  • weights = [1.0, 0.5, 0.25, 0.1]
  • loss = 0

为什么这样设定:第一个 codebook 权重最高,越往后权重越低?

因为 Encodec 的第一个 codebook 通常包含最重要的低频信息

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

dataset = MusicTokenDataset(
    "maestro_tokens_3kbps",
    block_size=1024,
    stride=256
)

loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

print("Dataset size:", len(dataset))
print("Batches per epoch:", len(loader))

model = MusicTransformer(
    vocab_size=1024,
    num_codebooks=4,
    embed_dim=512,
    max_seq_len=8192,
    num_layers=6,
    num_heads=8,
    dropout=0.2,
).cuda()

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-2
)

epochs = 10

total_steps = len(loader) * epochs

warmup_steps = int(total_steps * 0.05)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for step, (x, y) in enumerate(loader):
        x, y = x.cuda(), y.cuda()

        logits = model(x) # [B, T, K, 1024]

        weights = [1.0, 0.5, 0.25, 0.1]
        loss = 0

        for k in range(4):
            l_k = criterion(
                logits[:, :, k, :].reshape(-1, 1024), 
                y[:, :, k].reshape(-1)
            )
            loss += weights[k] * l_k

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

        if step % 50 == 0:
            with torch.no_grad():
                cb1_loss = criterion(logits[:, :, 0, :].reshape(-1, 1024), y[:, :, 0].reshape(-1))

            print(
                f"Epoch {epoch} | Step {step}/{len(loader)} | "
                f"Total Loss: {loss.item():.4f} | CB1 Loss: {cb1_loss.item():.4f}"
            )

    print(f"Epoch {epoch}: Average loss = {total_loss/len(loader)}")

输出:

Dataset size: 12124
Batches per epoch: 1516
Epoch 0 | Step 0/1516 | Total Loss: 13.1681 | CB1 Loss: 7.1360
Epoch 0 | Step 50/1516 | Total Loss: 12.4348 | CB1 Loss: 6.5062
Epoch 0 | Step 100/1516 | Total Loss: 12.1048 | CB1 Loss: 6.1696
Epoch 0 | Step 150/1516 | Total Loss: 12.1125 | CB1 Loss: 6.1280
Epoch 0 | Step 200/1516 | Total Loss: 11.9614 | CB1 Loss: 5.9766
Epoch 0 | Step 250/1516 | Total Loss: 11.4711 | CB1 Loss: 5.6517
Epoch 0 | Step 300/1516 | Total Loss: 10.9681 | CB1 Loss: 5.2020
Epoch 0 | Step 350/1516 | Total Loss: 10.7473 | CB1 Loss: 5.0230
Epoch 0 | Step 400/1516 | Total Loss: 10.2655 | CB1 Loss: 4.6984
Epoch 0 | Step 450/1516 | Total Loss: 10.0230 | CB1 Loss: 4.5502
Epoch 0 | Step 500/1516 | Total Loss: 10.0525 | CB1 Loss: 4.6381
Epoch 0 | Step 550/1516 | Total Loss: 9.9627 | CB1 Loss: 4.5851
Epoch 0 | Step 600/1516 | Total Loss: 9.8013 | CB1 Loss: 4.4313
Epoch 0 | Step 650/1516 | Total Loss: 9.7768 | CB1 Loss: 4.4808
Epoch 0 | Step 700/1516 | Total Loss: 9.2992 | CB1 Loss: 4.1366
Epoch 0 | Step 750/1516 | Total Loss: 9.5975 | CB1 Loss: 4.3616
Epoch 0 | Step 800/1516 | Total Loss: 9.7100 | CB1 Loss: 4.4841
Epoch 0 | Step 850/1516 | Total Loss: 9.5965 | CB1 Loss: 4.4378
Epoch 0 | Step 900/1516 | Total Loss: 9.4299 | CB1 Loss: 4.3070
Epoch 0 | Step 950/1516 | Total Loss: 9.5516 | CB1 Loss: 4.4275
Epoch 0 | Step 1000/1516 | Total Loss: 9.0184 | CB1 Loss: 4.0058
Epoch 0 | Step 1050/1516 | Total Loss: 8.5885 | CB1 Loss: 3.7543
Epoch 0 | Step 1100/1516 | Total Loss: 8.9285 | CB1 Loss: 4.0639
Epoch 0 | Step 1150/1516 | Total Loss: 9.2187 | CB1 Loss: 4.2103
Epoch 0 | Step 1200/1516 | Total Loss: 9.0627 | CB1 Loss: 4.0787
Epoch 0 | Step 1250/1516 | Total Loss: 9.1363 | CB1 Loss: 4.1802
Epoch 0 | Step 1300/1516 | Total Loss: 9.0649 | CB1 Loss: 4.1065
Epoch 0 | Step 1350/1516 | Total Loss: 8.8088 | CB1 Loss: 3.8661
Epoch 0 | Step 1400/1516 | Total Loss: 8.6759 | CB1 Loss: 3.8462
Epoch 0 | Step 1450/1516 | Total Loss: 8.5416 | CB1 Loss: 3.7931
Epoch 0 | Step 1500/1516 | Total Loss: 8.7619 | CB1 Loss: 3.8879
Epoch 0: Average loss = 9.959296647351147
Epoch 1 | Step 0/1516 | Total Loss: 8.8611 | CB1 Loss: 4.0412
...
Epoch 1: Average loss = 8.619560586431096
...
Epoch 3 | Step 0/1516 | Total Loss: 8.2190 | CB1 Loss: 3.5659
...
Epoch 3: Average loss = 8.13771899960601
Epoch 4 | Step 0/1516 | Total Loss: 8.1016 | CB1 Loss: 3.4350
Epoch 4 | Step 50/1516 | Total Loss: 7.6791 | CB1 Loss: 3.2969
Epoch 4 | Step 100/1516 | Total Loss: 7.5307 | CB1 Loss: 3.1364
...

训练日志分析:

我们先看 每个 epoch 的平均 loss

Epoch 0 → 1 下降很明显(-1.3),之后下降变缓,到 epoch 3 已经接近平稳

这是符合正常的 Transformer 收敛曲线。

接下来 CB1 Loss 分析:

CB1 是最重要的频带(低频主旋律结构)。

它从 7.1 下降到 3.3,这是一个非常健康的下降。说明模型确实学到音乐结构。

最后,Total Loss 是加权的:

1.0 * CB1 + 0.5 * CB2 + 0.25 * CB3 + 0.1 * CB4

这会导致模型 更关注低频,高频可能学得慢。这对生成音乐来说是合理的。

5、保存 checkpoint

os.makedirs("maestro_checkpoints", exist_ok=True)

torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
}, f"maestro_checkpoints/checkpoint_epoch{epoch}.pt")