
Part B: CEM-MPC and Latent Actor-Critic

Given a world model, how does an agent use it to select actions? This section is the direct prerequisite for P03 and P04. It introduces three planning mechanisms: from the most intuitive random search, to Dreamer's imagination-based training, to TD-MPC's hybrid approach.

MuZero and the Counterfactual Paradigm

One family of methods takes counterfactual reasoning to the extreme: the counterfactual paradigm forgoes pixel prediction entirely and requires accurate predictions only at the abstract level of values and rewards. MuZero (Nature, 2020) decomposes the world model into three functions:

  • Representation function $h_\theta$: compresses past observations $o_{1:t}$ into an internal hidden state $s^0 = h_\theta(o_{1:t})$
  • Dynamics function $g_\theta$: given the previous hidden state and a candidate action, predicts the immediate reward and the next hidden state: $r^k, s^k = g_\theta(s^{k-1}, a^k)$
  • Prediction function $f_\theta$: predicts a policy prior and value from the hidden state: $p^k, v^k = f_\theta(s^k)$

The three functions are trained jointly end-to-end, with the total loss:

$$l_t(\theta) = \sum_{k=0}^{K} \left[ l^r(u_{t+k}, r_t^k) + l^v(z_{t+k}, v_t^k) + l^p(\pi_{t+k}, p_t^k) \right] + c\,\lVert\theta\rVert^2$$

Symbol definitions: $K$ is the number of unroll steps (how many steps are unrolled per training update); $l^r$, $l^v$, $l^p$ are the loss functions for the reward, value, and policy prediction heads respectively; $u_{t+k}$ is the actual reward collected from real interactions (the training target); $r_t^k$ is the reward predicted by the dynamics function; $z_{t+k}$ is the n-step bootstrapped target value (constructed from real rewards plus a value estimate several steps later); $\pi_{t+k}$ is the improved policy produced by MCTS search (the visit-count distribution, used as the training target for the policy head); $c\,\lVert\theta\rVert^2$ is L2 regularization (weight decay, where $c$ is the regularization coefficient, preventing overfitting from excessively large parameters).

The hidden state $s^k$ has no semantic constraints: it does not need to correspond to the true environment state, nor does it need to be able to reconstruct pixels. The only requirement is: "starting from $s^k$, accurately predict rewards, values, and policies." This is the most fundamental design difference between MuZero and PlaNet/Dreamer.
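The three functions and the unrolled loss can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not MuZero's implementation: the linear "networks", the dimensions, and the names `represent`, `dynamics`, `predict`, and `unrolled_loss` are all hypothetical; squared error is used for the reward and value heads and cross-entropy for the policy head purely for illustration; the reward term is skipped at k = 0 because no dynamics step has been taken yet; and no gradients are computed.

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, STATE_DIM, NUM_ACTIONS = 4, 8, 3  # illustrative sizes (assumption)

# Hypothetical linear "networks"; a real system would use deep networks.
params = {
    "h": rng.normal(scale=0.1, size=(OBS_DIM, STATE_DIM)),
    "g_s": rng.normal(scale=0.1, size=(STATE_DIM + NUM_ACTIONS, STATE_DIM)),
    "g_r": rng.normal(scale=0.1, size=(STATE_DIM + NUM_ACTIONS,)),
    "f_p": rng.normal(scale=0.1, size=(STATE_DIM, NUM_ACTIONS)),
    "f_v": rng.normal(scale=0.1, size=(STATE_DIM,)),
}

def represent(obs):
    """h: observation -> initial hidden state s^0."""
    return np.tanh(obs @ params["h"])

def dynamics(state, action):
    """g: (s^{k-1}, a^k) -> (r^k, s^k)."""
    x = np.concatenate([state, np.eye(NUM_ACTIONS)[action]])
    return float(x @ params["g_r"]), np.tanh(x @ params["g_s"])

def predict(state):
    """f: s^k -> (p^k, v^k), with a softmax over action logits."""
    logits = state @ params["f_p"]
    p = np.exp(logits - logits.max())
    return p / p.sum(), float(state @ params["f_v"])

def unrolled_loss(obs, actions, target_rewards, target_values,
                  target_policies, c=1e-4):
    """Sum the per-step head losses over k = 0..K, plus c * ||theta||^2.

    actions and target_rewards have K entries (steps 1..K);
    target_values and target_policies have K + 1 entries (steps 0..K).
    """
    state = represent(obs)
    total = 0.0
    for k in range(len(actions) + 1):
        if k > 0:
            reward, state = dynamics(state, actions[k - 1])
            total += (target_rewards[k - 1] - reward) ** 2      # l^r
        policy, value = predict(state)
        total += (target_values[k] - value) ** 2                # l^v
        total += -np.sum(target_policies[k] * np.log(policy + 1e-12))  # l^p
    l2 = sum(np.sum(w ** 2) for w in params.values())
    return total + c * l2

# Usage with made-up targets for K = 3 unroll steps.
K = 3
obs = rng.normal(size=OBS_DIM)
loss = unrolled_loss(
    obs,
    actions=[0, 2, 1],
    target_rewards=[1.0, 0.0, 0.5],
    target_values=[1.2, 0.4, 0.5, 0.0],
    target_policies=[np.full(NUM_ACTIONS, 1.0 / NUM_ACTIONS)] * (K + 1),
)
```

Note that nothing in the loss refers to observations after the first call to `represent`: the hidden state is only ever judged by the rewards, values, and policies predicted from it.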

MuZero maintains three prediction heads:

| Prediction head | Prediction target | Role |
| --- | --- | --- |
| Reward head | immediate reward $r_t$ | evaluates the quality of the current step |
| Value head | future cumulative value $V(s_t)$ | guides MCTS search direction |
| Policy prior | action probability distribution $\pi(a \mid s_t)$ | reduces the number of branches MCTS needs to explore |

All three heads are trained jointly through the unrolled dynamics function on real interaction data.

Figure: MuZero's implicit world model, a three-module architecture of representation function, dynamics function, and prediction function (Schrittwieser et al., 2020). The representation function h compresses historical observations into a hidden state s; the dynamics function g simulates action transitions in hidden state space and predicts immediate rewards; the prediction function f outputs a policy prior and value estimate from the hidden state, driving MCTS search. The hidden state does not need to correspond to real pixels, only to support accurate reward and value prediction.

As long as these three prediction heads are accurate, the exact form of the latent state $s_t$ does not matter. For the agent, "faithfully reconstructing the world" is not necessarily the optimal objective. Without being given the game rules or a simulator to plan with, MuZero reaches superhuman play in Go, Chess, and Shogi, and state-of-the-art results (at the time of publication) across 57 Atari games.

📖 MCTS (Monte Carlo Tree Search): Starting from the current state, repeatedly perform four steps: (1) Select: traverse down the tree, selecting the node with the highest UCB score (Upper Confidence Bound, $\mathrm{UCB} = Q(s,a) + c\sqrt{\ln N / n_a}$, where $Q$ is the average value of that action, $N$ is the total visit count of the parent node, $n_a$ is the visit count of that action, and $c$ is the exploration coefficient; UCB balances "choosing known good actions" with "exploring less-visited actions"; MuZero itself uses a pUCT variant that additionally weights the exploration term by the policy prior); (2) Expand: try a new action at a leaf node; (3) Simulate/Evaluate: use the neural network to estimate the value of the new node (MuZero uses the value head directly, without rollout); (4) Backpropagate: update the value estimates upward along the path. After hundreds of repetitions, the most-visited action is the one "deemed optimal after sufficient search." MuZero's key extension over AlphaZero: support for single-agent domains (not just two-player games) and intermediate step rewards (Atari), with value targets constructed via n-step bootstrapping rather than terminal win/loss.
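To make the selection rule concrete, here is a small sketch of the plain UCB score from the definition above. The function names, the example statistics, the default `c = 1.4`, and the choice to take the parent count as the sum of child counts are all illustrative assumptions; the prior-weighted variant used in practice is not shown.

```python
import math

def ucb_score(q, parent_visits, action_visits, c=1.4):
    """UCB = Q + c * sqrt(ln N / n_a); unvisited actions are tried first."""
    if action_visits == 0:
        return math.inf
    return q + c * math.sqrt(math.log(parent_visits) / action_visits)

def select_action(stats, c=1.4):
    """stats maps each action to (average value Q, visit count n_a)."""
    # Assumption: the parent's visit count is the sum of its children's.
    parent_visits = sum(n for _, n in stats.values())
    return max(
        stats,
        key=lambda a: ucb_score(stats[a][0], parent_visits, stats[a][1], c),
    )

# "left" has the best average value, but "right" has barely been explored.
stats = {"left": (0.60, 50), "right": (0.55, 5), "stay": (0.20, 45)}
explore_choice = select_action(stats)         # exploration bonus dominates
greedy_choice = select_action(stats, c=0.0)   # pure exploitation
```

With these numbers the exploration bonus makes the rarely visited action the next one to try, while setting `c = 0` falls back to the action with the highest average value.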


Mechanism 1: CEM Shooting-Method MPC

📖 CEM (Cross-Entropy Method): A sampling-based optimization algorithm. The core idea: sample a large number of candidate solutions from a distribution (e.g., Gaussian), evaluate the objective value of each candidate, retain the best fraction (elite samples), refit the distribution using these elite samples (update the mean and variance), and repeat. With each iteration, the sampling distribution narrows and concentrates toward high-quality regions. Here it searches directly over action sequences, each of which is rolled ("shot") forward through the model to be scored, hence the name "shooting method."

📖 MPC (Model Predictive Control): At each time step, use the model to predict H steps into the future (H is called the planning horizon), select the optimal action sequence, execute only the first action, then re-plan at the next step. Even if the model is imperfect, frequent re-planning corrects errors promptly, preventing them from accumulating indefinitely.

In one sentence: randomly sample a batch of action sequences, "imagine" executing them in the model, select the sequence with the highest expected return, execute only the first step, and repeat.

Algorithm steps:

CEM-MPC Planning Loop (executed once per step)

Input: current state s_t, world model f, reward model r, planning horizon H, refinement rounds K, samples per round N, elite count M (M < N), discount factor γ

1. Initialize action distribution: μ ← 0, σ ← 1 (one mean and one standard deviation per time step and action dimension)

2. FOR k = 1 to K (refinement rounds):
   a. Sample N action sequences from Normal(μ, σ²): {a^(i)_{t:t+H-1}}, i = 1..N
   b. FOR each sequence i:
        Roll out imagined trajectory from s^(i)_t = s_t: s^(i)_{t+h+1} = f(s^(i)_{t+h}, a^(i)_{t+h}) for h = 0..H-1
        Compute cumulative reward: R^(i) = Σ_{h=0}^{H-1} γ^h · r(s^(i)_{t+h}, a^(i)_{t+h})
        # γ (gamma) is the discount factor, 0 < γ < 1, causing future rewards to decay exponentially
        # γ=0.99 means a reward 100 steps away is still worth 0.99^100 ≈ 0.37 of its face value
   c. Select the top-M elite sequences (sorted by R^(i) descending)
   d. Refit using the elites: μ ← mean(elites), σ ← std(elites)

3. Execute the first action of the final mean sequence: a_t ← μ[0]

The first round of sampling covers a broad range with low precision, identifying roughly "where the high-return regions are." Subsequent rounds refit the distribution using elite sequences, progressively narrowing the sampling range toward high-return regions.
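The planning loop above can be sketched in NumPy as follows. This is a minimal illustration rather than a reference implementation: `cem_plan`, the toy model functions, and all default hyperparameters are assumptions, and the model functions are assumed to accept a batch of states and actions so that all candidate sequences are rolled out together.

```python
import numpy as np

def cem_plan(state, dynamics, reward, horizon=10, rounds=5, num_samples=200,
             num_elites=20, action_dim=1, gamma=0.99, rng=None):
    """Return the first action of the refined mean action sequence."""
    rng = np.random.default_rng() if rng is None else rng
    mu = np.zeros((horizon, action_dim))
    sigma = np.ones((horizon, action_dim))
    for _ in range(rounds):
        # (a) Sample N action sequences around the current mean.
        noise = rng.standard_normal((num_samples, horizon, action_dim))
        actions = mu + sigma * noise
        # (b) Roll every sequence out in the model, accumulating
        #     discounted reward.
        states = np.repeat(np.asarray(state, dtype=float)[None],
                           num_samples, axis=0)
        returns = np.zeros(num_samples)
        for h in range(horizon):
            returns += gamma ** h * reward(states, actions[:, h])
            states = dynamics(states, actions[:, h])
        # (c) Keep the elite sequences with the highest returns.
        elites = actions[np.argsort(returns)[-num_elites:]]
        # (d) Refit the sampling distribution to the elites.
        mu, sigma = elites.mean(axis=0), elites.std(axis=0)
    return mu[0]

# Hypothetical toy model: a 1-D point that should be driven to the origin.
def toy_dynamics(s, a):
    return s + a

def toy_reward(s, a):
    return -s[:, 0] ** 2 - 0.01 * a[:, 0] ** 2

rng = np.random.default_rng(0)
state = np.array([2.0])
first_action = cem_plan(state, toy_dynamics, toy_reward, rng=rng)

# MPC: re-plan at every step and execute only the first planned action.
for _ in range(20):
    action = cem_plan(state, toy_dynamics, toy_reward, rng=rng)
    state = toy_dynamics(state[None], action[None])[0]
```

In the toy problem the "real" environment and the model are the same function, so re-planning only has to correct for sampling noise; with a learned model the same outer loop is what keeps model errors from compounding.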

Limitation: in high-dimensional continuous action spaces (e.g., a robotic arm controlling 7 joints simultaneously), random search is extremely inefficient. This is the core problem TD-MPC addresses: guiding the search with a Q-function rather than sampling blindly.

Advantages: simple, gradient-free, easy to implement, with no differentiability requirements on the world model.


Mechanism 2: Actor-Critic in Latent Space (Dreamer's Approach)

📖 Actor-Critic architecture: consists of two networks. The Actor (policy network $\pi_\theta(a \mid s)$) handles "decision-making," and the Critic (value network $V_\phi(s)$) handles "evaluation." The baseline provided by the Critic greatly reduces the variance of gradient estimates, making training more stable.

Dreamer's core insight: rather than collecting large amounts of data in the real environment to train a policy, train inside the imagined trajectories of the world model, which is faster, risk-free, and differentiable.

Training procedure:

  1. Imagination rollout: starting from the current latent state $z_t$, sample actions with the Actor and use the RSSM to roll forward H steps
  2. Critic evaluation: compute $V(z_h)$ for each imagined state, constructing training targets with the λ-return
  3. Actor optimization: the Actor maximizes the cumulative value predicted by the Critic via backpropagation through the entire imagined trajectory
  4. World model update: update the RSSM and encoder using real environment data (reconstruction loss + KL)

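Intuition behind the λ-return: pure Monte Carlo waits until the episode ends to obtain a full return, giving high variance; pure TD looks only one step ahead and relies on the Critic for everything after that, giving high bias. The λ-return interpolates between the two: it builds a k-step return from the first k rewards along the rollout (in Dreamer, rewards predicted by the model along the imagined trajectory) plus the Critic's estimate of the state reached after those k steps, then takes a weighted average over all k with weights proportional to $\lambda^{k-1}$. $\lambda \to 1$ trusts the rollout; $\lambda \to 0$ trusts the Critic.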
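The weighted average over k-step returns is usually computed with an equivalent backward recursion, G_h = r_h + γ[(1 − λ)·V(z_{h+1}) + λ·G_{h+1}], bootstrapped from the Critic's value at the final imagined state. The sketch below follows that recursion; the function name, argument layout, and example numbers are assumptions made for illustration.

```python
import numpy as np

def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """Compute lambda-returns for an H-step rollout.

    rewards: length H, reward at each step of the rollout.
    values:  length H + 1, Critic estimates V(z_0..z_H); the last entry
             bootstraps the return beyond the end of the rollout.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    returns = np.zeros_like(rewards)
    next_return = values[-1]
    for h in reversed(range(len(rewards))):
        # Blend the one-step bootstrap with the longer return from h + 1.
        next_return = rewards[h] + gamma * (
            (1.0 - lam) * values[h + 1] + lam * next_return
        )
        returns[h] = next_return
    return returns

rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5, 2.0]
td_targets = lambda_returns(rewards, values, gamma=0.9, lam=0.0)
mc_targets = lambda_returns(rewards, values, gamma=0.9, lam=1.0)
mixed = lambda_returns(rewards, values, gamma=0.9, lam=0.5)
```

With `lam=0.0` each target is the one-step TD target r_h + γ·V(z_{h+1}); with `lam=1.0` it is the full discounted rollout return plus the bootstrapped final value; intermediate settings fall between the two.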

Why differentiability matters: the Actor's gradients flow directly through the differentiable dynamics of the RSSM, which typically yields much lower-variance gradient estimates than score-function (REINFORCE-style) policy gradients estimated from sampled returns alone.
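A one-step toy problem illustrates the difference. Here a known quadratic reward stands in for the differentiable model, and the "policy" is a Gaussian with mean `mu`; both estimators target the same gradient of the expected reward with respect to `mu`, but one differentiates through the reward and the other only uses sampled reward values. The setup, names, and numbers are assumptions for illustration and are not taken from Dreamer.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n = 0.0, 1.0, 10_000

def reward(a):
    # Toy differentiable "model": reward peaks at a = 3.
    return -(a - 3.0) ** 2

eps = rng.standard_normal(n)
actions = mu + sigma * eps  # reparameterized samples a = mu + sigma * eps

# Pathwise estimator: differentiate through the reward,
# d reward / d mu = reward'(a) * da/dmu = -2 (a - 3) * 1.
pathwise = -2.0 * (actions - 3.0)

# Score-function estimator: reward(a) * d log N(a; mu, sigma^2) / d mu.
score_fn = reward(actions) * (actions - mu) / sigma ** 2

# Analytic gradient of E[reward] = -((mu - 3)^2 + sigma^2) w.r.t. mu.
true_grad = -2.0 * (mu - 3.0)

print("true gradient:      ", true_grad)
print("pathwise mean / var:", pathwise.mean(), pathwise.var())
print("score-fn mean / var:", score_fn.mean(), score_fn.var())
```

Both sample means should land near the analytic gradient, but the per-sample variance of the score-function estimator is far larger in this setup, so it needs many more samples to give an equally reliable update direction.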

Model exploitation problem: the policy may discover actions that yield high rewards inside the model but are invalid in the real world, such as high-frequency jitter actions that score highly in the world model but would only damage motors on a real robot. The Dreamer series addresses this by periodically updating the world model with real environment data and limiting the number of imagination rollout steps, but the problem has not been fundamentally solved.

CEM is inefficient in high-dimensional action spaces, and Actor-Critic carries model exploitation risk. TD-MPC combines both approaches to address these two problems, as described in the next section.