Music Transformer STEP3 - Music Generation
在 STEP2 中我们完成了 Music Transformer 的训练,得到了能够建模多 codebook 离散音频 token 的自回归模型。
STEP3 的核心目标是利用训练好的模型进行音乐生成,从一段初始音频 token 出发,让模型逐帧 autoregressively 续写新的音频 token,最后通过 Encodec 解码为可播放的波形音频,并保存为文件。
0、环境配置
首先配置运行所需依赖,安装 PyTorch、音频处理库 torchaudio、音频编解码器 Encodec 以及 torchcodec,同时设置环境缓存路径,检查 GPU 可用性。
!pip install -q torch torchaudio encodec torchcodec
import torch
import torchaudio
import encodec
import os
os.environ["TORCH_HOME"] = "/mnt/workspace/torch_cache"
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
执行后会输出当前 PyTorch 版本与 CUDA 是否可用,确保后续生成与解码能在 GPU 上加速运行。
1、加载训练好的 Music Transformer 模型
这一步的核心是恢复训练完成的模型结构与权重,为后续生成做准备:
-
严格按照 STEP2 训练时的模型参数构建 MusicTransformer,保证结构一致
-
加载 STEP2 保存的 checkpoint 模型权重
-
将模型切换为 eval 模式,关闭 dropout、batchnorm 等训练特有的层行为,避免影响生成稳定性
import torch
import torch.nn as nn
from music_transformer import MusicTransformer
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MusicTransformer(
vocab_size=1024,
num_codebooks=4,
embed_dim=512,
max_seq_len=8192,
num_layers=6,
num_heads=8,
dropout=0.2,
).to(device)
checkpoint = torch.load("maestro_checkpoints/checkpoint_epoch4.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
执行后会打印完整的模型结构,包含 4 个 codebook 独立嵌入层、位置编码、6 层 TransformerEncoder 以及输出分类头,与训练时结构完全一致,证明模型加载成功。
输出:
MusicTransformer(
(embeddings): ModuleList(
(0-3): 4 x Embedding(1024, 512)
)
(pos_embedding): Embedding(8192, 512)
(transformer): TransformerEncoder(
(layers): ModuleList(
(0-5): 6 x TransformerEncoderLayer(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
)
(linear1): Linear(in_features=512, out_features=2048, out_features=512, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
(linear2): Linear(in_features=2048, out_features=512, bias=True)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout1): Dropout(p=0.2, inplace=False)
(dropout2): Dropout(p=0.2, inplace=False)
)
)
)
(output_head): Linear(in_features=512, out_features=1024, bias=True)
)
2、定义自回归生成函数
实现核心的逐帧生成逻辑,遵循自回归生成规则:基于历史所有 token,预测下一帧的 4 个 codebook token。
函数关键逻辑说明:
-
关闭梯度计算,减少显存占用、加速推理
-
每次只截取模型支持的最大序列长度,防止溢出
-
对每个 codebook 独立预测下一 token:通过模型输出 logits → 温度系数缩放 → softmax 转概率 → 多项式采样
-
把新生成的一帧拼接到序列末尾,循环直至达到目标生成长度
-
temperature:控制生成随机性,值越大生成越多样,值越小越保守
import torch.nn.functional as F
@torch.no_grad()
def generate(model, start_tokens, max_new_tokens=200, temperature=1.0):
model.eval()
x = start_tokens.to(device)
for _ in range(max_new_tokens):
# 截断到模型最大支持长度,避免序列过长
x_cond = x[:, -model.max_seq_len:]
# 前向推理得到所有位置的预测 logits
logits = model(x_cond) # [B, T, K, vocab]
next_tokens = []
# 对 4 个 codebook 分别预测下一 token
for k in range(model.num_codebooks):
# 取最后一帧的 logits,用温度调节概率分布
last_logits = logits[:, -1, k, :] / temperature
probs = F.softmax(last_logits, dim=-1)
# 按概率采样下一个 token
next_token = torch.multinomial(probs, num_samples=1)
next_tokens.append(next_token)
# 拼接成 [B, 1, K] 格式
next_tokens = torch.stack(next_tokens, dim=-1)
# 拼接到原序列后,实现自回归续写
x = torch.cat([x, next_tokens], dim=1)
return x
3、准备初始提示 token(Prompt)
从训练集数据集中取出一段真实音频 token 作为生成起点(Prompt):
-
加载 STEP2 中使用的 MusicTokenDataset,保证数据格式一致
-
取样本的前 10 帧作为初始序列,让模型基于这一小段音乐继续生成
-
生成完成后,统计每个 codebook 生成 token 的唯一数量,判断生成是否具有多样性
from dataset import MusicTokenDataset
dataset = MusicTokenDataset(
"maestro_tokens_3kbps",
block_size=1024,
stride=256
)
x_sample, _ = dataset[0]
start = x_sample[:10].unsqueeze(0) # 前10帧作为prompt
# 开始生成,续写 400 帧新 token
generated_tokens = generate(
model,
start_tokens=start,
max_new_tokens=400,
temperature=1.1
)
# 统计每个 codebook 生成的不重复 token 数量
for k in range(4):
print("Codebook", k,
generated_tokens[0, :, k].unique().shape[0])
输出代表每个 codebook 生成了多少种不同的 token,数值合理说明模型没有坍缩为单一重复值,生成具有多样性:
Codebook 0 144
Codebook 1 230
Codebook 2 260
Codebook 3 282
4、Encodec 解码:从 token 还原音频波形
模型生成的是离散音频 token,需要通过训练时使用的 Encodec 解码器还原为连续波形音频:
-
加载 24kHz、3kbps 配置的 Encodec 模型,与编码时保持一致
-
调整 token 维度格式,适配 Encodec 解码输入要求
-
执行解码,得到可播放的波形张量
from encodec import EncodecModel
import torchaudio
# 加载与编码时一致的 Encodec 模型
codec = EncodecModel.encodec_model_24khz()
codec.set_target_bandwidth(3.0) # 3kbps
codec = codec.to(device)
codec.eval()
# 调整维度为 Encodec 要求的 [1, K, T]
codes = generated_tokens[0].permute(1, 0).unsqueeze(0).to(device)
# Encodec 解码要求输入格式为 (codes, scale) 列表
encoded_frames = [(codes, None)] # 关键格式适配
with torch.no_grad():
wav = codec.decode(encoded_frames)
5、播放生成的音乐
在 Notebook 环境中直接播放生成的音频波形,快速试听生成效果。
from IPython.display import Audio
audio = wav[0, 0].cpu().numpy()
Audio(audio, rate=24000)
6、保存生成音频为文件
把模型生成的音频波形保存为 .wav 格式文件,方便导出、本地播放与后续处理。
import soundfile as sf
audio = wav[0, 0].cpu().numpy()
sf.write("generated_music.wav", audio, 24000)
执行后会在当前目录生成 generated_music.wav,即 Music Transformer 自回归生成的完整音乐。