Skip to content

Music Transformer STEP3 - Music Generation

在 STEP2 中我们完成了 Music Transformer 的训练,得到了能够建模多 codebook 离散音频 token 的自回归模型。

STEP3 的核心目标是利用训练好的模型进行音乐生成,从一段初始音频 token 出发,让模型逐帧 autoregressively 续写新的音频 token,最后通过 Encodec 解码为可播放的波形音频,并保存为文件。


0、环境配置

首先配置运行所需依赖,安装 PyTorch、音频处理库 torchaudio、音频编解码器 Encodec 以及 torchcodec,同时设置环境缓存路径,检查 GPU 可用性。

!pip install -q torch torchaudio encodec torchcodec
import torch
import torchaudio
import encodec
import os

os.environ["TORCH_HOME"] = "/mnt/workspace/torch_cache"

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

执行后会输出当前 PyTorch 版本与 CUDA 是否可用,确保后续生成与解码能在 GPU 上加速运行。


1、加载训练好的 Music Transformer 模型

这一步的核心是恢复训练完成的模型结构与权重,为后续生成做准备:

  1. 严格按照 STEP2 训练时的模型参数构建 MusicTransformer,保证结构一致

  2. 加载 STEP2 保存的 checkpoint 模型权重

  3. 将模型切换为 eval 模式,关闭 dropout、batchnorm 等训练特有的层行为,避免影响生成稳定性

import torch
import torch.nn as nn
from music_transformer import MusicTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model = MusicTransformer(
    vocab_size=1024,
    num_codebooks=4,
    embed_dim=512,
    max_seq_len=8192,
    num_layers=6,
    num_heads=8,
    dropout=0.2,
).to(device)

checkpoint = torch.load("maestro_checkpoints/checkpoint_epoch4.pt", map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

执行后会打印完整的模型结构,包含 4 个 codebook 独立嵌入层、位置编码、6 层 TransformerEncoder 以及输出分类头,与训练时结构完全一致,证明模型加载成功。

输出:

MusicTransformer(
  (embeddings): ModuleList(
    (0-3): 4 x Embedding(1024, 512)
  )
  (pos_embedding): Embedding(8192, 512)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (output_head): Linear(in_features=512, out_features=1024, bias=True)
)


2、定义自回归生成函数

实现核心的逐帧生成逻辑,遵循自回归生成规则:基于历史所有 token,预测下一帧的 4 个 codebook token

函数关键逻辑说明:

  1. 关闭梯度计算,减少显存占用、加速推理

  2. 每次只截取模型支持的最大序列长度,防止溢出

  3. 对每个 codebook 独立预测下一 token:通过模型输出 logits → 温度系数缩放 → softmax 转概率 → 多项式采样

  4. 把新生成的一帧拼接到序列末尾,循环直至达到目标生成长度

  5. temperature:控制生成随机性,值越大生成越多样,值越小越保守

import torch.nn.functional as F

@torch.no_grad()
def generate(model, start_tokens, max_new_tokens=200, temperature=1.0):
    model.eval()
    x = start_tokens.to(device)

    for _ in range(max_new_tokens):

        # 截断到模型最大支持长度,避免序列过长
        x_cond = x[:, -model.max_seq_len:]

        # 前向推理得到所有位置的预测 logits
        logits = model(x_cond)  # [B, T, K, vocab]

        next_tokens = []
        # 对 4 个 codebook 分别预测下一 token
        for k in range(model.num_codebooks):
            # 取最后一帧的 logits,用温度调节概率分布
            last_logits = logits[:, -1, k, :] / temperature
            probs = F.softmax(last_logits, dim=-1)
            # 按概率采样下一个 token
            next_token = torch.multinomial(probs, num_samples=1)
            next_tokens.append(next_token)

        # 拼接成 [B, 1, K] 格式
        next_tokens = torch.stack(next_tokens, dim=-1)

        # 拼接到原序列后,实现自回归续写
        x = torch.cat([x, next_tokens], dim=1)

    return x

3、准备初始提示 token(Prompt)

从训练集数据集中取出一段真实音频 token 作为生成起点(Prompt)

  1. 加载 STEP2 中使用的 MusicTokenDataset,保证数据格式一致

  2. 取样本的前 10 帧作为初始序列,让模型基于这一小段音乐继续生成

  3. 生成完成后,统计每个 codebook 生成 token 的唯一数量,判断生成是否具有多样性

from dataset import MusicTokenDataset

dataset = MusicTokenDataset(
    "maestro_tokens_3kbps",
    block_size=1024,
    stride=256
)

x_sample, _ = dataset[0]
start = x_sample[:10].unsqueeze(0)  # 前10帧作为prompt

# 开始生成,续写 400 帧新 token
generated_tokens = generate(
    model,
    start_tokens=start,
    max_new_tokens=400,
    temperature=1.1
)

# 统计每个 codebook 生成的不重复 token 数量
for k in range(4):
    print("Codebook", k, 
          generated_tokens[0, :, k].unique().shape[0])

输出代表每个 codebook 生成了多少种不同的 token,数值合理说明模型没有坍缩为单一重复值,生成具有多样性:

Codebook 0 144
Codebook 1 230
Codebook 2 260
Codebook 3 282


4、Encodec 解码:从 token 还原音频波形

模型生成的是离散音频 token,需要通过训练时使用的 Encodec 解码器还原为连续波形音频:

  1. 加载 24kHz、3kbps 配置的 Encodec 模型,与编码时保持一致

  2. 调整 token 维度格式,适配 Encodec 解码输入要求

  3. 执行解码,得到可播放的波形张量

from encodec import EncodecModel
import torchaudio

# 加载与编码时一致的 Encodec 模型
codec = EncodecModel.encodec_model_24khz()
codec.set_target_bandwidth(3.0)  # 3kbps
codec = codec.to(device)
codec.eval()

# 调整维度为 Encodec 要求的 [1, K, T]
codes = generated_tokens[0].permute(1, 0).unsqueeze(0).to(device)

# Encodec 解码要求输入格式为 (codes, scale) 列表
encoded_frames = [(codes, None)]  # 关键格式适配

with torch.no_grad():
    wav = codec.decode(encoded_frames)

5、播放生成的音乐

在 Notebook 环境中直接播放生成的音频波形,快速试听生成效果。

from IPython.display import Audio

audio = wav[0, 0].cpu().numpy()

Audio(audio, rate=24000)

6、保存生成音频为文件

把模型生成的音频波形保存为 .wav 格式文件,方便导出、本地播放与后续处理。

import soundfile as sf

audio = wav[0, 0].cpu().numpy()

sf.write("generated_music.wav", audio, 24000)

执行后会在当前目录生成 generated_music.wav,即 Music Transformer 自回归生成的完整音乐。