Music Transformer STEP2 - Model Training
把 Encodec 编码出来的 多 codebook 离散音频 token,当成“语言序列”来做自回归建模
Introduction:
1、构建训练样本:自回归格式
2、构建多 codebook 自回归 Transformer
3、构建 Dataset,用于 DataLoader
4、训练
5、保存 checkpoint
0、环境配置
!pip install -q torch torchaudio encodec torchcodec
import torch
import torchaudio
import encodec
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
1、构建训练样本(自回归格式)
选一个 10s 的音频,先看一下对于它 构建训练样本 是否能得到正确的样本
下面主要做了:转置这个音频的矩阵 + 赋值 x 和 y(y 始终比 x 往后偏移一个时间步)来构建训练样本
最终实现 “根据历史 token 序列,逐帧生成下一个帧的 token”。
# 测试
import torch
import torch.nn as nn
# 选一首歌、一个 chunk 作为最小实验
data = torch.load("maestro_tokens_3kbps/music_0000.pt")
# 取第一个 10s chunk
codes = data["tokens"][0] # Tensor[K, T]【4 本 codebook(4 行),750 个时间步(750 个纵列)】
print("Original codes shape:", codes.shape)
# 把 音频编码 矩阵 转置
frames = codes.transpose(0, 1).contiguous()
print("Frames shape:", frames.shape)
# 自回归 shift
x = frames[:-1] # [T-1, K]
y = frames[1:] # [T-1, K]
print("x shape:", x.shape)
print("y shape:", y.shape)
# 看一下前 3 帧
print("x[:3]:\n", x[:3])
print("y[:3]:\n", y[:3])
输出:
Original codes shape: torch.Size([4, 750])
Frames shape: torch.Size([750, 4])
x shape: torch.Size([749, 4])
y shape: torch.Size([749, 4])
x[:3]:
tensor([[ 133, 363, 989, 1018],
[ 228, 87, 863, 793],
[ 133, 859, 1007, 74]])
y[:3]:
tensor([[ 228, 87, 863, 793],
[ 133, 859, 1007, 74],
[ 999, 646, 675, 866]])
可以看到 x(输入) 的第2、3行 分别相同于 y(预测) 的第1、2行
即 输入x1 预测y1;输入x2 预测y2 。
2、构建多 codebook 自回归 Transformer
一、定义部分:
1、定义 4 本 codebook 分别的 embedding - 为每个 codebook 单独建立一个 embedding 表 - 不同 codebook 表示不同频带/量化层,它们不是同一个“词表语义空间”,所以不能共享 embedding。
2、定义位置编码 embedding:告诉 transformer,输出的每个音频位置是不一样的,是有顺序的
3、定义 Transformer 块 - 定义标准的 nn.TransformerEncoderLayer - 堆叠 6 层(默认)
4、定义输出头:其实就是一个 分类头
二、forward 部分:
1、输入维度拆解:读到 B = 8;T = 750;K = 4
2、每个 codebook 单独 embedding
3、堆叠并展平:拉平成一个长序列
4、加位置编码:加一个维度 None,然后广播到各个 batch
5、构造自回归 mask:生成 上三角 = -inf,第 i 个音频 token 不能看未来的 音频 token。
6、Transformer 前向:每个 音频 token 都融合了过去所有音频 token 的信息
7、输出分类:进行打分 logits
8、还原形状:这样就可以对每个时间步、每个 codebook 做 交叉熵 loss
class MusicTransformer(nn.Module):
def __init__(
self, vocab_size=1024,
num_codebooks=4,
embed_dim=256,
max_seq_len=4096,
num_layers=6,
num_heads=8,
dropout=0.2
):
super().__init__()
self.num_codebooks = num_codebooks
self.embed_dim = embed_dim
self.embeddings = nn.ModuleList([
nn.Embedding(vocab_size, embed_dim)
for _ in range(num_codebooks)
])
self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
decoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=num_heads,
dim_feedforward=embed_dim * 4, dropout=dropout, batch_first=True
)
self.transformer = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)
self.output_head = nn.Linear(embed_dim, vocab_size)
def forward(self, x):
B, T, K = x.shape
h_list = []
for k in range(K):
h_list.append(self.embeddings[k](x[:, :, k]))
h = torch.stack(h_list, dim=2).reshape(B, T * K, -1)
total_len = T * K
pos_ids = torch.arange(total_len, device=x.device)
h = h + self.pos_embedding(pos_ids)[None, :, :]
mask = torch.triu(
torch.full((total_len, total_len), float('-inf'), device=x.device),
diagonal=1
)
h = self.transformer(h, mask)
logits = self.output_head(h)
logits = logits.reshape(B, T, K, -1)
return logits
output_heads 在这里,这是一个分类头
得到的 logits_k 矩阵 的真实形状是:[B, T, vocab_size] [批次大小, 时间步, 词表里词数量] [一批次几段音频, 750时间步, 词表里共1024段音频]
logits_k[b, t, i] 表示:打分(score)
-
在第 b 个样本、第 t 个时间步、第 k 个 codebook 上,
-
在给定所有可见音频上下文的条件下,模型“倾向于生成 token i 的概率评分”
真实 logits 矩阵:750行(每行表示 1 个时间步),1024列(每列表示 codebook 中一个音频词 token)
- 第1行表示:在第1个时间步,“生成codebook中1024个音频分别是什么概率?”的评分值。
- 第2行 以此类推...
- 第750行 以此类推...
注意:这 750 个时间步 实际是 10s。也就是:1s 是 75 个时间步。
# 测试
model = MusicTransformer(
vocab_size=1024,
num_codebooks=x.shape[1], # x的列数:4列(4本codebook)(749行,因为10s有750个音再-1)
embed_dim=256,
num_layers=6,
num_heads=8
)
x_small = x[:512] # 只取512长度
dummy_x = x_small.unsqueeze(0)
logits = model(dummy_x)
print(f"Number of codebooks: {len(logits)}") # logits打分矩阵的行数:4行,分别对应4本codebook里的所有音
print("Logits shape for codebook 0:", logits[0].shape)
Number of codebooks: 1
Logits shape for codebook 0: torch.Size([512, 4, 1024])
import os
import torch
import random
TOKEN_DIR = "maestro_tokens_3kbps"
song_files = sorted([
os.path.join(TOKEN_DIR, f)
for f in os.listdir(TOKEN_DIR)
if f.endswith(".pt")
])
print(f"Found {len(song_files)} songs.")
Found 65 songs.
# 看看我们存的.pt文件是什么 是dict还是tensor
data = torch.load("maestro_tokens_3kbps/music_0000.pt")
print(type(data))
print(data.keys())
<class 'dict'>
dict_keys(['tokens', 'num_chunks', 'bandwidth', 'sample_rate', 'chunk_size', 'music_id'])
data = torch.load("maestro_tokens_3kbps/music_0000.pt")
print(type(data["tokens"]))
print(len(data["tokens"]))
print(type(data["tokens"][0]))
print(data["tokens"][0].shape)
输出:
<class 'list'>
96
<class 'torch.Tensor'>
torch.Size([4, 750])
3、构建 Dataset,用于 DataLoader
这个模块整体需要做的是:
- 拼接所有 chunk
- 得到整首歌 token 序列
- 切成长度 block_size 的小片段
- 构造 (x, y) 训练样本
1、读取 token 文件:保证顺序一致
2、处理每一首 music - 取出 chunk 列表:每个是一个 10 秒 chunk - 转置
3、滑动窗口切样本 - 循环遍历音频的每一个采样点:构造 x、构造 y、加入样本
补充:为什么用 stride?
重叠率:87.5% - 好处:数据量暴增,每个时间点被多次训练 - 坏处:数据冗余大,训练更慢
import os
import torch
from torch.utils.data import Dataset
class MusicTokenDataset(Dataset):
def __init__(
self,
token_dir,
block_size=2048,
stride=256
):
self.block_size = block_size
self.stride = stride
self.samples = []
files = sorted([
os.path.join(token_dir, f)
for f in os.listdir(token_dir)
if f.endswith(".pt")
])
for f in files:
data = torch.load(f)
chunks = data["tokens"] # list of [4, 750]
tokens = torch.cat(chunks, dim=1) # [4, T]
tokens = tokens.permute(1, 0) # [T, 4]
T = tokens.size(0)
# 加 stride
for i in range(0, T - block_size - 1, stride):
x = tokens[i:i+block_size]
y = tokens[i+1:i+block_size+1]
self.samples.append((x, y))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
4、训练
整体流程概述:
1、初始化构建 Dataset、DataLoader、模型、损失函数、优化器、学习率调度器
- 标准分类损失:用于预测:token id ∈ [0,1023]
- 优化器:AdamW = Adam + 权重衰减。
- 前 5% 步数做 warmup:防止训练初期不稳定;避免梯度爆炸
补充:为什么不用 Adam?
- Transformer 训练通常配 AdamW
- weight_decay=1e-2 是常见正则强度。
补充:为什么 warmup?
- 学习率曲线:线性升高 → 线性下降
- 防止训练初期不稳定,避免梯度爆炸,这是 GPT 标准训练策略。
2、训练循环:开始 10 个 epoch。 - 进入训练模式:开启 dropout - 遍历 batch - 前向传播 - 加权多 codebook 损失 - 反向传播:计算梯度;梯度裁剪:防止梯度爆炸;更新参数、学习率 - 记录 loss:用于算 epoch 平均 - 打印日志:每 50 步打印一次;单独计算第一个 codebook loss。 - 每个 epoch 结束:输出平均损失
补充:
加权多 codebook 损失:
- weights = [1.0, 0.5, 0.25, 0.1]
- loss = 0
为什么这样设定:第一个 codebook 权重最高,越往后权重越低?
因为 Encodec 的第一个 codebook 通常包含最重要的低频信息。
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
dataset = MusicTokenDataset(
"maestro_tokens_3kbps",
block_size=1024,
stride=256
)
loader = DataLoader(
dataset,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True
)
print("Dataset size:", len(dataset))
print("Batches per epoch:", len(loader))
model = MusicTransformer(
vocab_size=1024,
num_codebooks=4,
embed_dim=512,
max_seq_len=8192,
num_layers=6,
num_heads=8,
dropout=0.2,
).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4,
weight_decay=1e-2
)
epochs = 10
total_steps = len(loader) * epochs
warmup_steps = int(total_steps * 0.05)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
for epoch in range(epochs):
model.train()
total_loss = 0
for step, (x, y) in enumerate(loader):
x, y = x.cuda(), y.cuda()
logits = model(x) # [B, T, K, 1024]
weights = [1.0, 0.5, 0.25, 0.1]
loss = 0
for k in range(4):
l_k = criterion(
logits[:, :, k, :].reshape(-1, 1024),
y[:, :, k].reshape(-1)
)
loss += weights[k] * l_k
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
if step % 50 == 0:
with torch.no_grad():
cb1_loss = criterion(logits[:, :, 0, :].reshape(-1, 1024), y[:, :, 0].reshape(-1))
print(
f"Epoch {epoch} | Step {step}/{len(loader)} | "
f"Total Loss: {loss.item():.4f} | CB1 Loss: {cb1_loss.item():.4f}"
)
print(f"Epoch {epoch}: Average loss = {total_loss/len(loader)}")
输出:
Dataset size: 12124
Batches per epoch: 1516
Epoch 0 | Step 0/1516 | Total Loss: 13.1681 | CB1 Loss: 7.1360
Epoch 0 | Step 50/1516 | Total Loss: 12.4348 | CB1 Loss: 6.5062
Epoch 0 | Step 100/1516 | Total Loss: 12.1048 | CB1 Loss: 6.1696
Epoch 0 | Step 150/1516 | Total Loss: 12.1125 | CB1 Loss: 6.1280
Epoch 0 | Step 200/1516 | Total Loss: 11.9614 | CB1 Loss: 5.9766
Epoch 0 | Step 250/1516 | Total Loss: 11.4711 | CB1 Loss: 5.6517
Epoch 0 | Step 300/1516 | Total Loss: 10.9681 | CB1 Loss: 5.2020
Epoch 0 | Step 350/1516 | Total Loss: 10.7473 | CB1 Loss: 5.0230
Epoch 0 | Step 400/1516 | Total Loss: 10.2655 | CB1 Loss: 4.6984
Epoch 0 | Step 450/1516 | Total Loss: 10.0230 | CB1 Loss: 4.5502
Epoch 0 | Step 500/1516 | Total Loss: 10.0525 | CB1 Loss: 4.6381
Epoch 0 | Step 550/1516 | Total Loss: 9.9627 | CB1 Loss: 4.5851
Epoch 0 | Step 600/1516 | Total Loss: 9.8013 | CB1 Loss: 4.4313
Epoch 0 | Step 650/1516 | Total Loss: 9.7768 | CB1 Loss: 4.4808
Epoch 0 | Step 700/1516 | Total Loss: 9.2992 | CB1 Loss: 4.1366
Epoch 0 | Step 750/1516 | Total Loss: 9.5975 | CB1 Loss: 4.3616
Epoch 0 | Step 800/1516 | Total Loss: 9.7100 | CB1 Loss: 4.4841
Epoch 0 | Step 850/1516 | Total Loss: 9.5965 | CB1 Loss: 4.4378
Epoch 0 | Step 900/1516 | Total Loss: 9.4299 | CB1 Loss: 4.3070
Epoch 0 | Step 950/1516 | Total Loss: 9.5516 | CB1 Loss: 4.4275
Epoch 0 | Step 1000/1516 | Total Loss: 9.0184 | CB1 Loss: 4.0058
Epoch 0 | Step 1050/1516 | Total Loss: 8.5885 | CB1 Loss: 3.7543
Epoch 0 | Step 1100/1516 | Total Loss: 8.9285 | CB1 Loss: 4.0639
Epoch 0 | Step 1150/1516 | Total Loss: 9.2187 | CB1 Loss: 4.2103
Epoch 0 | Step 1200/1516 | Total Loss: 9.0627 | CB1 Loss: 4.0787
Epoch 0 | Step 1250/1516 | Total Loss: 9.1363 | CB1 Loss: 4.1802
Epoch 0 | Step 1300/1516 | Total Loss: 9.0649 | CB1 Loss: 4.1065
Epoch 0 | Step 1350/1516 | Total Loss: 8.8088 | CB1 Loss: 3.8661
Epoch 0 | Step 1400/1516 | Total Loss: 8.6759 | CB1 Loss: 3.8462
Epoch 0 | Step 1450/1516 | Total Loss: 8.5416 | CB1 Loss: 3.7931
Epoch 0 | Step 1500/1516 | Total Loss: 8.7619 | CB1 Loss: 3.8879
Epoch 0: Average loss = 9.959296647351147
Epoch 1 | Step 0/1516 | Total Loss: 8.8611 | CB1 Loss: 4.0412
...
Epoch 1: Average loss = 8.619560586431096
...
Epoch 3 | Step 0/1516 | Total Loss: 8.2190 | CB1 Loss: 3.5659
...
Epoch 3: Average loss = 8.13771899960601
Epoch 4 | Step 0/1516 | Total Loss: 8.1016 | CB1 Loss: 3.4350
Epoch 4 | Step 50/1516 | Total Loss: 7.6791 | CB1 Loss: 3.2969
Epoch 4 | Step 100/1516 | Total Loss: 7.5307 | CB1 Loss: 3.1364
...
训练日志分析:
我们先看 每个 epoch 的平均 loss:
Epoch 0 → 1 下降很明显(-1.3),之后下降变缓,到 epoch 3 已经接近平稳
这是符合正常的 Transformer 收敛曲线。
接下来 CB1 Loss 分析:
CB1 是最重要的频带(低频主旋律结构)。
它从 7.1 下降到 3.3,这是一个非常健康的下降。说明模型确实学到音乐结构。
最后,Total Loss 是加权的:
1.0 * CB1 + 0.5 * CB2 + 0.25 * CB3 + 0.1 * CB4
这会导致模型 更关注低频,高频可能学得慢。这对生成音乐来说是合理的。
5、保存 checkpoint
os.makedirs("maestro_checkpoints", exist_ok=True)
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch,
}, f"maestro_checkpoints/checkpoint_epoch{epoch}.pt")