On-Policy Distillation (OPD) for Diffusion: An Introduction
This notebook explains how on-policy distillation (OPD) moves from language models to diffusion / flow-matching models.
We will build the story in a linear order:
- start with the original OPD idea for LLMs, following Thinking Machines Lab's blog post "On-Policy Distillation": dense teacher feedback on the student's own tokens;
- translate the same idea to diffusion: dense teacher feedback on the student's own denoising states;
- build a single-teacher diffusion OPD example: an 8-mode student absorbs a two-mode teacher;
- extend to a multi-teacher example and compare three diffusion OPD designs: DiffusionOPD, Flow-OPD, and DanceOPD;
- discuss how this could combine with reward-based methods -- Tilt Matching and RAM -- to handle multi-reward training.
This notebook uses the same 8-mode toy world as our RAM tutorial and Tilt Matching tutorial.
The goal is not to reproduce image-scale numbers. The goal is to make the mechanics visible: where the student states come from, what the teacher supervises, how many states are queried, and why single-query DanceOPD can be much cheaper than dense-trajectory OPD.
Everything runs on CPU in a few minutes.
Edit History:
- 06-26-2026: initial standalone OPD tutorial
- 06-30-2026: add the router (Section 4.1); broaden Section 5 to Tilt Matching + RAM vs. MARBLE
0. Setup
We only need torch, numpy, matplotlib, tqdm, and IPython.display.
The code avoids GPU-only features and should run in Google Colab.
from __future__ import annotations
import math
from copy import deepcopy
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from IPython.display import Markdown, display
from tqdm import tqdm
torch.manual_seed(0)
np.random.seed(0)
torch.set_default_dtype(torch.float32)
torch.set_num_threads(4)
DEVICE = torch.device("cpu")
SHOW_PROGRESS = False
def display_markdown_table(headers, rows):
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    lines.extend("| " + " | ".join(str(x) for x in row) + " |" for row in rows)
    display(Markdown("\n".join(lines)))

print("torch:", torch.__version__)
torch: 2.11.0
Some plotting helpers. They are deliberately small so the OPD logic later stays readable.
PLOT_XLIM = (-5.4, 5.4)
PLOT_YLIM = (-5.4, 5.4)
def plot_samples(samples, ax=None, *, title=None, color="C0", s=6, alpha=0.45):
    if ax is None:
        _, ax = plt.subplots(figsize=(4.4, 4.4))
    pts = samples.detach().cpu().numpy()
    ax.scatter(pts[:, 0], pts[:, 1], s=s, alpha=alpha, color=color)
    ax.set_xlim(*PLOT_XLIM)
    ax.set_ylim(*PLOT_YLIM)
    ax.set_aspect("equal")
    if title:
        ax.set_title(title)
    return ax


def plot_density(density_fn, ax=None, *, title=None, n_grid=100, cmap="viridis"):
    if ax is None:
        _, ax = plt.subplots(figsize=(4.4, 4.4))
    xs = torch.linspace(PLOT_XLIM[0], PLOT_XLIM[1], n_grid)
    ys = torch.linspace(PLOT_YLIM[0], PLOT_YLIM[1], n_grid)
    X, Y = torch.meshgrid(xs, ys, indexing="ij")
    pts = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    Z = density_fn(pts).reshape(n_grid, n_grid).detach().cpu().numpy()
    ax.imshow(Z.T, origin="lower", extent=(*PLOT_XLIM, *PLOT_YLIM), cmap=cmap)
    ax.set_aspect("equal")
    if title:
        ax.set_title(title)
    return ax
1. Original OPD: dense teacher feedback on the student's own states
Thinking Machines Lab's blog post "On-Policy Distillation" frames post-training methods along two axes:
| method | samples come from | feedback is |
|---|---|---|
| supervised fine-tuning / ordinary distillation | teacher or dataset | dense |
| reinforcement learning | student | sparse |
| on-policy distillation | student | dense |
For LLMs, the student samples a token sequence. At each prefix $x_{1:t}$, a stronger teacher gives a next-token distribution. A simple OPD loss is per-token reverse KL:
$$ \mathrm{KL}\!\left( \pi_\theta(\cdot\mid x_{1:t}) \;\|\; \pi_{\mathrm{teacher}}(\cdot\mid x_{1:t}) \right), \qquad x_{1:t} \sim \pi_\theta. $$
The key phrase is student states. The teacher is not grading its own perfect trajectory. It grades the prefixes that the student actually visits.
This gives two benefits at once:
- like RL, the data are on-policy, so the student learns to recover from its own mistakes;
- like distillation, the feedback is dense, so the student receives a useful learning signal at every step, not just one final scalar reward.
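The per-token reverse KL above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the blog post's implementation: the logits are random stand-ins, the sizes are arbitrary, and `per_token_reverse_kl` is a hypothetical helper name. It computes the full-vocabulary KL at each position of a sequence that is assumed to have been sampled by the student.

```python
import torch
import torch.nn.functional as F


def per_token_reverse_kl(student_logits, teacher_logits):
    """KL(student || teacher) at every position; logits have shape (seq_len, vocab)."""
    log_p_student = F.log_softmax(student_logits, dim=-1)
    log_p_teacher = F.log_softmax(teacher_logits, dim=-1)
    p_student = log_p_student.exp()
    return (p_student * (log_p_student - log_p_teacher)).sum(dim=-1)


torch.manual_seed(0)
seq_len, vocab = 5, 11  # toy sizes, chosen only for illustration

# Stand-ins for the two models' logits on the SAME student-sampled prefixes.
student_logits = torch.randn(seq_len, vocab, requires_grad=True)
teacher_logits = torch.randn(seq_len, vocab)

kl = per_token_reverse_kl(student_logits, teacher_logits.detach())
loss = kl.mean()   # one dense loss term per token, averaged
loss.backward()    # gradients flow only into the student
print(kl)
```

Every position contributes its own loss term, which is what "dense" means here; an RL reward would instead supply a single scalar for the whole sequence.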
1.1 Why this matters for diffusion
Diffusion sampling is also a trajectory. We will use $x_t$ for noisy states throughout the notebook. If the solver uses discrete times $t_N > t_{N-1} > \cdots > t_0$, then the reverse-time trajectory is
$$ x_{t_N} \to x_{t_{N-1}} \to x_{t_{N-2}} \to \cdots \to x_{t_0}, \qquad t_N = T,\quad t_0 = 0. $$
If the student's current sampler visits a noisy state $x_t$, we can ask a teacher diffusion model:
"What velocity would you predict at this same student state?"
That is the diffusion analogue of asking an LLM teacher:
"What next-token distribution would you predict at this same student prefix?"
In Section 4 we will introduce three recent diffusion OPD designs (DiffusionOPD, Flow-OPD, and DanceOPD) and compare them along two axes: how much of the student trajectory they query, and which flow-specific design choices they add. Section 2 first derives the local loss they all share.
2. From token KL to diffusion transition matching
Before writing the local OPD loss, let's name the three arguments of a diffusion velocity model:
- $x_t$ is the current noisy state on the student trajectory;
- $t$ is the noise/time level of that state;
- $c$ is the conditioning information. In an image model this could be a text prompt, edit instruction, or source image. Some models are conditional, but our 2D examples below will intentionally omit $c$ so the student is a single unconditional sampler.
In LLM OPD, the teacher gives a next-token distribution at the student's current prefix. We can either say "minimize the reverse KL" or say "use negative KL as a dense reward." Those are two views of the same local signal:
$$ \text{good local step} \quad\Longleftrightarrow\quad \text{small } \mathrm{KL}\!\left( \pi_\theta(\cdot\mid\text{student prefix}) \;\|\; \pi_{\text{teacher}}(\cdot\mid\text{student prefix}) \right). $$
For diffusion, the "next token distribution" becomes a one-step transition distribution. Given the current noisy state $x_t$, both the student and the teacher define a distribution for the next state $x_{t-\Delta t}$:
$$ p_\theta(x_{t-\Delta t}\mid x_t,c), \qquad p_\phi(x_{t-\Delta t}\mid x_t,c). $$
In the stochastic SDE view used by DiffusionOPD / Flow-OPD, these one-step transitions are Gaussians with the same covariance:
$$ p_\theta(x_{t-\Delta t}\mid x_t,c) = \mathcal{N}\!\left(\mu_\theta(x_t,t,c),\,\sigma_t^2 I\right), \qquad p_\phi(x_{t-\Delta t}\mid x_t,c) = \mathcal{N}\!\left(\mu_\phi(x_t,t,c),\,\sigma_t^2 I\right). $$
For two Gaussians with the same covariance, the KL is exactly a squared distance between means:
$$ \mathrm{KL} \bigl( p_\theta(x_{t-\Delta t}\mid x_t,c) \;\|\; p_\phi(x_{t-\Delta t}\mid x_t,c) \bigr) = \frac{ \|\mu_\theta(x_t,t,c)-\mu_\phi(x_t,t,c)\|^2 }{ 2\sigma_t^2 }. $$
Now connect the transition mean to the velocity. For a simple Euler flow update,
$$ \mu_\theta(x_t,t,c) \approx x_t - \Delta t\,v_\theta(x_t,t,c), \qquad \mu_\phi(x_t,t,c) \approx x_t - \Delta t\,v_\phi(x_t,t,c). $$
The shared $x_t$ term cancels:
$$ \|\mu_\theta-\mu_\phi\|^2 \approx \Delta t^2 \|v_\theta(x_t,t,c)-v_\phi(x_t,t,c)\|^2. $$
Therefore the diffusion analogue of the per-token OPD KL is not an arbitrary MSE. It is the closed-form transition KL:
$$ \mathrm{KL} \bigl( p_\theta(x_{t-\Delta t}\mid x_t,c) \;\|\; p_\phi(x_{t-\Delta t}\mid x_t,c) \bigr) \approx w(t)\, \|v_\theta(x_t,t,c)-v_\phi(x_t,t,c)\|^2, \qquad w(t)=\frac{\Delta t^2}{2\sigma_t^2}. $$
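The identity above can be checked numerically. The sketch below is an illustration under assumed values: `dt`, `sigma`, and the random velocities are placeholders rather than anything from the papers. It builds the two Euler transition Gaussians and compares PyTorch's analytic KL with $w(t)\,\|v_\theta - v_\phi\|^2$.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
dt, sigma = 0.1, 0.3            # assumed step size and transition std
x_t = torch.randn(4, 2)         # a batch of student states
v_student = torch.randn(4, 2)   # stand-in for v_theta(x_t, t)
v_teacher = torch.randn(4, 2)   # stand-in for v_phi(x_t, t)

# Euler transition means: mu = x_t - dt * v
mu_student = x_t - dt * v_student
mu_teacher = x_t - dt * v_teacher

# For isotropic Gaussians the KL is the sum of per-coordinate KLs.
kl = kl_divergence(Normal(mu_student, sigma), Normal(mu_teacher, sigma)).sum(dim=-1)

# Closed form from the text: w * ||v_student - v_teacher||^2, with w = dt^2 / (2 sigma^2)
w = dt**2 / (2 * sigma**2)
closed_form = w * (v_student - v_teacher).pow(2).sum(dim=-1)

print(kl)
print(closed_form)
```

With exact Euler means the two quantities agree up to floating-point error; the "$\approx$" in the text refers only to the Euler approximation of the true transition mean.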
In the deterministic ODE limit, there is no stochastic transition density to take KL over, so DiffusionOPD uses the matching part directly:
$$ \|\mu_\theta(x_t,t,c)-\mu_\phi(x_t,t,c)\|^2 \quad\text{or equivalently}\quad \|v_\theta(x_t,t,c)-v_\phi(x_t,t,c)\|^2 $$
up to scheduler-dependent constants.
So the bridge is:
$$ \text{LLM token KL} \;\longrightarrow\; \text{diffusion transition KL} \;\longrightarrow\; \text{closed-form velocity/transition MSE}. $$
This is the main bridge used by DiffusionOPD and Flow-OPD: closed-form local distillation replaces high-variance policy gradients.
3. Single-teacher OPD on the 8-mode ring
We now move to the same visual world as the RAM and Tilt Matching notebooks: an 8-mode Gaussian ring.
- The student is a flow-matching model trained on all 8 modes.
- A teacher is a separately trained flow-matching model that was only ever trained on two of those modes.
OPD asks the teacher for a dense velocity target on states that the current student visits, then updates the student toward that teacher.
A note on closed-form vs learned: the Gaussian-mixture density has a clean analytic form. We use that closed form only for illustration (to draw the target densities) and for the evaluation metric (to score how well the student matches a target). The teachers used inside the OPD loop are trained models, never the analytic field. This keeps the demo faithful to the real papers, where the teacher is always a model.
This section has one purpose: show the simplest diffusion OPD loop before adding multiple teachers.
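For reference, the flow-matching convention used by the code in this section can be read off `train_flow_model` and `euler_sample`. Data sits at $t=0$, noise at $t=1$, the two are joined by a straight line, and the network regresses onto the line's constant velocity:

$$ x_t = (1-t)\,x_0 + t\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \qquad v^\star(x_t,t) = \mathbb{E}\left[\varepsilon - x_0 \mid x_t\right]. $$

Sampling then integrates from noise to data with Euler steps of size $h = 1/N$:

$$ x_{t-h} = x_t - h\,v_\theta(x_t,t), \qquad t = 1,\; 1-h,\; \ldots,\; h. $$

This matches the sign of the transition mean $\mu_\theta \approx x_t - \Delta t\,v_\theta$ in Section 2.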
RING_MODES = 8
RING_RADIUS = 4.0
RING_ANGLES = 2 * math.pi * torch.arange(RING_MODES, device=DEVICE) / RING_MODES
RING_CENTERS = torch.stack(
[RING_RADIUS * torch.cos(RING_ANGLES), RING_RADIUS * torch.sin(RING_ANGLES)],
dim=-1,
)
COMPONENT_STD = 0.35
# Teacher A knows modes 0 and 2. Teacher B knows modes 1 and 3.
# The two teachers are non-overlapping, and together they cover four ring modes.
TEACHER_MODES = {
0: [0, 2],
1: [1, 3],
}
def teacher_weights_for(route: int | None) -> torch.Tensor | None:
    if route is None:
        return None
    weights = torch.zeros(RING_MODES, device=DEVICE)
    weights[TEACHER_MODES[route]] = 1.0 / len(TEACHER_MODES[route])
    return weights


def sample_mixture(
    centers: torch.Tensor,
    n: int,
    std: float,
    weights: torch.Tensor | None = None,
) -> torch.Tensor:
    if weights is None:
        idx = torch.randint(0, len(centers), (n,), device=centers.device)
    else:
        idx = torch.multinomial(weights, n, replacement=True)
    return centers[idx] + std * torch.randn(n, 2, device=centers.device)


def _log_normal(x: torch.Tensor, mean: torch.Tensor, var) -> torch.Tensor:
    d = x - mean
    if torch.is_tensor(var):
        return -0.5 * (d * d).sum(-1) / var - x.shape[-1] * 0.5 * (
            math.log(2 * math.pi) + torch.log(var)
        )
    return -0.5 * (d * d).sum(-1) / var - x.shape[-1] * 0.5 * math.log(2 * math.pi * var)

def mixture_logprob(
    x: torch.Tensor,
    centers: torch.Tensor,
    std: float,
    weights: torch.Tensor | None = None,
) -> torch.Tensor:
    '''Closed-form mixture log-density.

    IMPORTANT: closed-form expressions are used in this notebook ONLY for
    illustration (drawing target densities) and for the evaluation metric below.
    They are never used as the teacher signal. The teachers in the OPD loop are
    trained flow-matching models, exactly as in the real papers.
    '''
    logs = torch.stack([_log_normal(x, c, std**2) for c in centers], dim=1)
    if weights is None:
        logs = logs - math.log(len(centers))
    else:
        logs = logs + torch.log(weights.clamp_min(1e-12))[None, :]
    return torch.logsumexp(logs, dim=1)


def target_logprob(routes: list[int], x: torch.Tensor) -> torch.Tensor:
    # Log-density of the merged target: the uniform MIXTURE of the teacher
    # distributions, log((1/R) sum_r p_r(x)). (Averaging log-densities instead
    # would reward samples that sit *between* teachers, which is the opposite of
    # what we want.)
    logs = torch.stack([
        mixture_logprob(x, RING_CENTERS, COMPONENT_STD, teacher_weights_for(r))
        for r in routes
    ], dim=0)
    return torch.logsumexp(logs, dim=0) - math.log(len(routes))


def density_for(route: int | None):
    weights = teacher_weights_for(route)

    def f(pts):
        return torch.exp(mixture_logprob(pts, RING_CENTERS, COMPONENT_STD, weights))

    return f

# Closed-form target densities. These plots are for illustration only; the OPD
# teachers below are trained models, not these analytic fields.
fig, axes = plt.subplots(1, 3, figsize=(14, 4.8), constrained_layout=True)
plot_density(density_for(None), axes[0], title="target density (illustration): 8 modes")
plot_density(density_for(0), axes[1], title="teacher A target: modes 0 and 2")
plot_density(density_for(1), axes[2], title="teacher B target: modes 1 and 3")
plt.show()
class VelocityNet(nn.Module):
    def __init__(self, hidden: int = 128, n_freqs: int = 8):
        super().__init__()
        self.register_buffer("freqs", 2 ** torch.arange(n_freqs).float() * math.pi)
        in_dim = 2 + 2 * n_freqs
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 2),
        )

    def time_emb(self, t: torch.Tensor) -> torch.Tensor:
        angles = t[:, None] * self.freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        if t.ndim == 0:
            t = t.expand(x.shape[0])
        return self.net(torch.cat([x, self.time_emb(t)], dim=-1))

def euler_sample(model: nn.Module, n: int, n_steps: int = 50) -> torch.Tensor:
    x = torch.randn(n, 2, device=DEVICE)
    h = 1.0 / n_steps
    with torch.no_grad():
        for i in range(n_steps):
            t = torch.full((n,), 1.0 - i * h, device=DEVICE)
            x = x - h * model(x, t)
    return x


def rollout_states(
    model: nn.Module,
    n: int,
    n_steps: int = 12,
    sde_noise: float = 0.0,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    x = torch.randn(n, 2, device=DEVICE)
    h = 1.0 / n_steps
    states, times = [], []
    with torch.no_grad():
        for i in range(n_steps):
            t = torch.full((n,), 1.0 - i * h, device=DEVICE)
            states.append(x.clone())
            times.append(t)
            x = x - h * model(x, t)
            if sde_noise > 0 and i < n_steps - 1:
                x = x + sde_noise * math.sqrt(h) * torch.randn_like(x)
    return states, times

def train_flow_model(
    weights: torch.Tensor | None = None,
    *,
    steps: int = 5000,
    batch_size: int = 256,
    lr: float = 1e-3,
    desc: str = "pretrain",
) -> VelocityNet:
    '''Train a flow-matching VelocityNet on a (weighted) subset of ring modes.

    weights=None trains on all 8 modes (the base student). A two-mode weight
    vector trains a teacher that only knows those two modes.
    '''
    model = VelocityNet().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    iterator = range(steps)
    if SHOW_PROGRESS:
        iterator = tqdm(iterator, desc=desc)
    for _ in iterator:
        x0 = sample_mixture(RING_CENTERS, batch_size, COMPONENT_STD, weights)
        eps = torch.randn_like(x0)
        t = torch.rand(batch_size, device=DEVICE)
        xt = (1 - t[:, None]) * x0 + t[:, None] * eps
        target = eps - x0
        loss = (model(xt, t) - target).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model

# The base student knows all 8 modes.
base_student = train_flow_model(desc="pretrain base student")
# Two teacher models, each TRAINED on its own two modes. These learned networks
# are the teachers used by OPD below -- not the analytic field.
teacher_models = {
route: train_flow_model(teacher_weights_for(route), steps=4000, desc=f"teacher {route}")
for route in TEACHER_MODES
}
def teacher_velocity(route: int, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return teacher_models[route](x, t)


def base_velocity(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Flow-OPD's manifold anchor is a frozen model; here it is the frozen base student.
    with torch.no_grad():
        return base_student(x, t)

print(f"base / teacher parameters: {sum(p.numel() for p in base_student.parameters()):,} each")
base / teacher parameters: 35,714 each
# Sanity check: what the LEARNED models actually sample.
fig, axes = plt.subplots(1, 3, figsize=(14, 4.8), constrained_layout=True)
plot_samples(euler_sample(base_student, 4000), axes[0], title="learned base student: 8 modes", color="C0")
plot_samples(euler_sample(teacher_models[0], 4000), axes[1], title="learned teacher A: modes 0, 2", color="C1")
plot_samples(euler_sample(teacher_models[1], 4000), axes[2], title="learned teacher B: modes 1, 3", color="C2")
plt.show()
def mode_mass(samples: torch.Tensor) -> tuple[list[float], float]:
    nearest = torch.cdist(samples, RING_CENTERS).argmin(dim=1)
    counts = torch.bincount(nearest, minlength=RING_MODES).float()
    mass = counts / counts.sum().clamp_min(1.0)
    effective_modes = float(1.0 / mass.pow(2).sum().clamp_min(1e-12))
    return [float(v) for v in mass], effective_modes


def evaluate_model(model: nn.Module, routes: list[int], query_units: int = 0) -> dict:
    samples = euler_sample(model, 5000)
    mass, effective_modes = mode_mass(samples)
    teacher_modes = sorted(set(sum([TEACHER_MODES[r] for r in routes], [])))
    coverage = float(torch.tensor(mass)[teacher_modes].sum().item())
    return {
        "samples": samples,
        "avg_target_logp": target_logprob(routes, samples).mean().item(),
        "teacher_mode_coverage": coverage,
        "effective_modes": effective_modes,
        "mass": mass,
        "query_units": query_units,
    }

base_single_eval = evaluate_model(base_student, [0])
3.1 Distill one two-mode teacher into the student
For the single-teacher case, we use the simplest OPD update: roll out the current student, query teacher A (the trained two-mode model) on every trajectory state, and regress the student velocity toward the teacher velocity.
This is the one-teacher version of DiffusionOPD's dense transition matching. The expected behavior is clear: the 8-mode student should move most of its mass toward the two modes teacher A knows.
@dataclass
class OPDConfig:
    steps: int = 1500
    batch_size: int = 128
    rollout_steps: int = 12
    lr: float = 8e-4
    flow_sde_noise: float = 0.05
    flow_anchor: float = 0.02


def semantic_query_index(n_steps: int) -> int:
    s = torch.distributions.Beta(5.0, 2.0).sample().item()
    return min(n_steps - 1, max(1, int(s * (n_steps - 1))))

# The target modes each teacher is responsible for (used by the router below).
TARGET_CENTERS = {r: RING_CENTERS[TEACHER_MODES[r]] for r in TEACHER_MODES}
def route_samples(x: torch.Tensor, routes: list[int]) -> torch.Tensor:
    '''Assign each rolled-out sample to the teacher responsible for its region.

    The student is a single UNCONDITIONAL field, so something has to decide which
    teacher supervises each sample. Real systems read this off the prompt/domain
    (DiffusionOPD, Flow-OPD) or an explicit router (DanceOPD). With no prompt in
    this 2D toy, the router uses the nearest target mode as a stand-in.
    '''
    dists = torch.stack(
        [torch.cdist(x, TARGET_CENTERS[r]).min(dim=1).values for r in routes], dim=1
    )
    return torch.tensor(routes, device=x.device)[dists.argmin(dim=1)]


def routed_teacher_velocity(x: torch.Tensor, t: torch.Tensor, routes: list[int]) -> torch.Tensor:
    '''Per-sample teacher target: each sample is supervised by its routed teacher.'''
    chosen = route_samples(x, routes)
    target = torch.empty_like(x)
    for r in routes:
        mask = chosen == r
        if mask.any():
            target[mask] = teacher_velocity(r, x[mask], t[mask])
    return target


def teacher_target(x, t, routes, supervise):
    if supervise == "average":
        # Naive baseline: supervise every sample with the MEAN of both teacher
        # velocities. Averaging two disagreeing fields points the student between
        # the teachers, pulling modes toward their midpoint (visualized below).
        return torch.stack([teacher_velocity(r, x, t) for r in routes]).mean(dim=0)
    return routed_teacher_velocity(x, t, routes)

def train_opd_variant(
    init_model: nn.Module,
    *,
    routes: list[int],
    method: str,
    supervise: str = "router",
    cfg: OPDConfig = OPDConfig(),
) -> tuple[nn.Module, dict]:
    model = deepcopy(init_model)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    rollouts = 0
    queries = 0
    iterator = range(cfg.steps)
    if SHOW_PROGRESS:
        iterator = tqdm(iterator, desc=method)
    for step in iterator:
        # Equal-rollout comparison: every method does exactly ONE student rollout
        # per step. The rollout is the shared student's own trajectory (on-policy);
        # the router then assigns each visited sample to its responsible teacher.
        # Rollout cost is identical for all three methods; only the number of
        # queried states per rollout differs.
        states, times = rollout_states(
            model,
            cfg.batch_size,
            cfg.rollout_steps,
            sde_noise=cfg.flow_sde_noise if method == "flowopd" else 0.0,
        )
        rollouts += cfg.batch_size
        query_indices = (
            [semantic_query_index(cfg.rollout_steps)]
            if method == "danceopd"
            else list(range(cfg.rollout_steps))
        )
        per_query_losses = []
        for idx in query_indices:
            x = states[idx].detach()
            t = times[idx]
            pred = model(x, t)
            target = teacher_target(x, t, routes, supervise).detach()
            per_sample = (pred - target).pow(2).mean(dim=1)
            if method == "flowopd":
                time_weight = (1.2 - t).clamp(0.2, 1.2)
                # Manifold anchor: a small pull toward the frozen base model.
                anchor = base_velocity(x, t).detach()
                per_sample = (
                    time_weight * per_sample
                    + cfg.flow_anchor * (pred - anchor).pow(2).mean(dim=1)
                )
            per_query_losses.append(per_sample.mean())
            queries += cfg.batch_size
        loss = torch.stack(per_query_losses).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model, {"rollouts": rollouts, "query_units": queries}

torch.manual_seed(123)
single_teacher_student, single_hist = train_opd_variant(
base_student,
routes=[0],
method="diffusionopd",
)
single_eval = evaluate_model(single_teacher_student, [0], single_hist["query_units"])
display_markdown_table(
["model", "student rollouts", "teacher queries", "target logp ↑", "teacher-mode mass ↑", "effective modes"],
[
[
"8-mode student initialization",
"0",
"0",
f"{base_single_eval['avg_target_logp']:.2f}",
f"{base_single_eval['teacher_mode_coverage']:.2f}",
f"{base_single_eval['effective_modes']:.2f}",
],
[
"after single-teacher OPD",
f"{single_hist['rollouts']:,}",
f"{single_hist['query_units']:,}",
f"{single_eval['avg_target_logp']:.2f}",
f"{single_eval['teacher_mode_coverage']:.2f}",
f"{single_eval['effective_modes']:.2f}",
],
],
)
fig, axes = plt.subplots(1, 3, figsize=(14, 4.8), constrained_layout=True)
plot_density(density_for(0), axes[0], title="teacher A target")
plot_samples(base_single_eval["samples"], axes[1], title="before OPD: 8 modes", color="C0")
plot_samples(single_eval["samples"], axes[2], title="after OPD: 2 modes from teacher A", color="C1")
plt.show()
| model | student rollouts | teacher queries | target logp ↑ | teacher-mode mass ↑ | effective modes |
|---|---|---|---|---|---|
| 8-mode student initialization | 0 | 0 | -20.93 | 0.27 | 7.93 |
| after single-teacher OPD | 192,000 | 2,304,000 | -1.38 | 1.00 | 2.01 |
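The "effective modes" column is computed in `mode_mass` as one over the sum of squared mode masses, a quantity usually called the inverse Simpson index. It equals $K$ when mass is spread evenly over $K$ modes. A small standalone sketch (the helper name `effective_modes` and the two mass vectors are illustrative, not taken from the run above):

```python
def effective_modes(mass):
    """Inverse Simpson index: 1 / sum_k m_k^2 for a probability vector m."""
    return 1.0 / sum(m * m for m in mass)


uniform_8 = [1 / 8] * 8                    # all eight ring modes, equally weighted
two_of_8 = [0.5, 0, 0.5, 0, 0, 0, 0, 0]    # only modes 0 and 2, equally weighted

print(effective_modes(uniform_8))  # 8.0
print(effective_modes(two_of_8))   # 2.0
```

Read this way, the table's move from 7.93 to 2.01 says the student went from nearly uniform over the ring to nearly uniform over teacher A's two modes.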
4. Multi-teacher OPD: DiffusionOPD, Flow-OPD, and DanceOPD
Now we add teacher B. Teacher A knows modes 0 and 2; teacher B knows modes 1 and 3. They have no overlapping modes. The goal is a single student that covers the four modes known by the two teachers.
This is the multi-reward integration setting:
- no manual reward selector at inference;
- one student model;
- the teachers are training-only signals, merged into that single student.
4.1 Why a single student needs a router
Our student is one unconditional velocity field, but the two teachers disagree almost everywhere. At a state drifting toward mode 0, teacher A says "head to mode 0" while teacher B says "head to mode 1". If we supervise the student with the average of the two teacher velocities, it learns a field that points between the teachers -- so the modes collapse toward their midpoints and land off the true mode centers (red crosses, left panel below).
The fix is a router: for each rolled-out sample, decide which teacher is
responsible and use only that teacher's velocity. Every region of space is then
supervised by exactly one teacher, so the modes stay sharp and on-center (right
panel). This routing is exactly what DanceOPD's WeightedRouter does; DiffusionOPD
and Flow-OPD get the same effect from the prompt/domain label. In this 2D toy,
with no prompt to read, the router uses the nearest target mode as a stand-in for
that signal (see route_samples above).
# Same method (dense matching), same budget -- only the supervision differs.
torch.manual_seed(123)
naive_student, _ = train_opd_variant(
base_student, routes=[0, 1], method="diffusionopd", supervise="average"
)
torch.manual_seed(123)
routed_student, _ = train_opd_variant(
base_student, routes=[0, 1], method="diffusionopd", supervise="router"
)
naive_eval = evaluate_model(naive_student, [0, 1])
routed_eval = evaluate_model(routed_student, [0, 1])
teacher_modes = sorted(set(TEACHER_MODES[0] + TEACHER_MODES[1]))
centers = RING_CENTERS[teacher_modes].cpu()
fig, axes = plt.subplots(1, 2, figsize=(11, 5.2), constrained_layout=True)
for ax, ev, title in [
    (axes[0], naive_eval, "naive: average both teachers"),
    (axes[1], routed_eval, "router: one teacher per sample"),
]:
    plot_samples(ev["samples"], ax, title=f"{title}\neff modes = {ev['effective_modes']:.2f}", color="C0")
    ax.scatter(centers[:, 0], centers[:, 1], marker="x", c="red", s=160,
               linewidths=2.5, zorder=5, label="true teacher modes")
    ax.legend(loc="lower left", fontsize=8)
plt.show()
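The averaging failure described in Section 4.1 has a simple closed-form caricature. Assume, purely for illustration, that each teacher targets a single point $m$ instead of a Gaussian blob. Under the straight-line convention $x_t = (1-t)\,m + t\,\varepsilon$, the teacher's velocity is $(x_t - m)/t$, so the average of two such velocities is exactly the velocity toward the midpoint of the two targets. The sketch below checks this; `point_target_velocity` is a hypothetical helper, not one of the trained teachers.

```python
import math
import torch


def point_target_velocity(x_t, t, m):
    """Velocity of the straight-line flow x_t = (1 - t) * m + t * eps toward a point m.

    Solving for eps and using v = eps - m gives v = (x_t - m) / t.
    """
    return (x_t - m) / t


radius = 4.0
m_a = torch.tensor([radius, 0.0])  # ring mode 0 (one of teacher A's modes)
m_b = radius * torch.tensor([math.cos(math.pi / 4), math.sin(math.pi / 4)])  # ring mode 1 (teacher B)
midpoint = 0.5 * (m_a + m_b)

x_t = torch.tensor([1.5, 0.5])  # an arbitrary noisy state
t = 0.6

v_avg = 0.5 * (point_target_velocity(x_t, t, m_a) + point_target_velocity(x_t, t, m_b))
v_mid = point_target_velocity(x_t, t, midpoint)

# One Euler step from time t down to 0 under the averaged field.
endpoint = x_t - t * v_avg

print("averaged field equals midpoint field:", torch.allclose(v_avg, v_mid, atol=1e-5))
print("endpoint:", endpoint, "midpoint:", midpoint)
print("midpoint radius:", midpoint.norm().item(), "ring radius:", radius)
```

The midpoint of two ring modes lies strictly inside the ring, which is consistent with the inward pull in the naive panel. The trained teachers are smoother than this point-mass picture, so the real effect is a bias toward the middle rather than an exact collapse.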
4.2 What differs between the three methods?
All three share the recipe above (Section 2 + the router): roll out the current student, route each visited state to its teacher, and regress the student's velocity toward that teacher's velocity. They differ in how many states they query and in a few flow-specific choices.
DiffusionOPD is the most direct lift of LLM OPD. For each teacher, it rolls out the student and matches the teacher's transition (velocity) at every denoising state along that trajectory. With multiple teachers it visits them in turn (one rollout batch per teacher) and accumulates the losses before a step. This is dense, closed-form transition matching -- the velocity MSE we derived.
Flow-OPD keeps the same dense, all-states matching, but adds flow-specific machinery to stabilize multi-reward training: it rolls out with stochastic (SDE) noise for exploration, uses a time-weighted velocity MSE (the $w(t)$ from the KL derivation), starts from a cold-start init, and adds a Manifold Anchor: a small pull toward a frozen base/aesthetic model so the student does not drift off the data manifold.
DanceOPD keeps the same router but changes the query design. Dense trajectory states are highly correlated (same noise, same path), so querying all of them is expensive and partly redundant. Instead, DanceOPD queries a single low-noise, semantic-side state per rollout. It is the cheapest of the three per rollout, at the cost of less supervision.
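To see what "low-noise, semantic-side" means in this notebook, the sketch below mirrors `semantic_query_index` with the standard-library Beta sampler and converts the sampled rollout indices to times via $t = 1 - i/N$, the same grid `rollout_states` uses. The Beta(5, 2) shape is this notebook's toy choice; whether DanceOPD itself uses this distribution is not something the sketch asserts.

```python
import random


def sample_query_index(n_steps, rng):
    """Stdlib mirror of the notebook's semantic_query_index."""
    s = rng.betavariate(5.0, 2.0)  # skewed toward 1, i.e. toward late rollout steps
    return min(n_steps - 1, max(1, int(s * (n_steps - 1))))


rng = random.Random(0)
n_steps = 12
indices = [sample_query_index(n_steps, rng) for _ in range(10_000)]

# Rollout step i is visited at time t = 1 - i / n_steps (t = 1 is pure noise).
times = [1.0 - i / n_steps for i in indices]
mean_t = sum(times) / len(times)

print(f"mean query time: {mean_t:.2f}")
print(f"share of queries with t < 0.5: {sum(t < 0.5 for t in times) / len(times):.2f}")
```

Most queries land in the low-noise half of the trajectory, where the state already resembles a sample and the teachers' velocities differ most by capability.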
To compare them fairly, below we give every method the same rollout budget: exactly one student rollout per step, with the router assigning each sample to its teacher. Only the number of queried states per rollout differs.
| method | student rollouts | queried states per rollout | flow-specific extras |
|---|---|---|---|
| DiffusionOPD-style | one per step | all $N$ trajectory states | dense velocity matching |
| Flow-OPD-style | one per step | all $N$ trajectory states | SDE rollout noise, low-noise time weight, base anchor |
| DanceOPD-style | one per step | 1 low-noise state | low-noise query bias |
The code below uses this table directly. It is still a toy analogue, not an image-scale reproduction.
methods = [
("DiffusionOPD dense", "diffusionopd"),
("Flow-OPD dense + anchor", "flowopd"),
("DanceOPD single query", "danceopd"),
]
multi_evals = {"8-mode student initialization": evaluate_model(base_student, [0, 1])}
multi_hist = {"8-mode student initialization": {"rollouts": 0, "query_units": 0}}
for title, method in methods:
    torch.manual_seed(123)
    model, hist = train_opd_variant(base_student, routes=[0, 1], method=method)
    multi_evals[title] = evaluate_model(model, [0, 1], hist["query_units"])
    multi_hist[title] = hist
rows = []
for name, ev in multi_evals.items():
    hist = multi_hist[name]
    rows.append(
        [
            name,
            f"{hist['rollouts']:,}",
            f"{hist['query_units']:,}",
            f"{ev['avg_target_logp']:.2f}",
            f"{ev['teacher_mode_coverage']:.2f}",
            f"{ev['effective_modes']:.2f}",
        ]
    )
display_markdown_table(
[
"method",
"student rollouts",
"teacher queries ↓",
"avg target logp ↑",
"teacher-mode mass ↑",
"effective modes",
],
rows,
)
| method | student rollouts | teacher queries ↓ | avg target logp ↑ | teacher-mode mass ↑ | effective modes |
|---|---|---|---|---|---|
| 8-mode student initialization | 0 | 0 | -14.82 | 0.52 | 7.94 |
| DiffusionOPD dense | 192,000 | 2,304,000 | -1.80 | 1.00 | 3.77 |
| Flow-OPD dense + anchor | 192,000 | 2,304,000 | -1.82 | 1.00 | 3.74 |
| DanceOPD single query | 192,000 | 192,000 | -1.75 | 1.00 | 3.86 |
fig, axes = plt.subplots(1, 4, figsize=(17, 4.8), constrained_layout=True)
columns = ["8-mode student initialization"] + [title for title, _ in methods]
target_centers = RING_CENTERS[sorted(set(TEACHER_MODES[0] + TEACHER_MODES[1]))].cpu()
for col, name in enumerate(columns):
    ev = multi_evals[name]
    plot_samples(
        ev["samples"],
        axes[col],
        title=f"{name}\ncoverage={ev['teacher_mode_coverage']:.2f}, eff={ev['effective_modes']:.1f}",
        color=f"C{col}",
    )
    axes[col].scatter(target_centers[:, 0], target_centers[:, 1], marker="x",
                      c="red", s=120, linewidths=2.2, zorder=5)
plt.show()
4.3 Takeaway
The single-teacher case showed the basic OPD effect: the student moves from the full 8-mode ring toward the two modes of one teacher.
In the multi-teacher table, all three methods use the same number of student rollouts and all reach full coverage of the four teacher modes. What differs is teacher queries: DiffusionOPD and Flow-OPD query every one of the $N$ trajectory states, while DanceOPD queries a single low-noise state -- about $N\times$ fewer queries.
But rollout cost is the same. This is the key caveat. A teacher query is one extra forward pass on an already-computed state; the expensive part is generating the student rollout in the first place. DanceOPD still needs a full $N$-step rollout to reach its single low-noise query state. So its saving is:
$$ \text{one rollout} + \text{one teacher query} \qquad\text{vs.}\qquad \text{one rollout} + N \text{ teacher queries}. $$
In other words, DanceOPD trims the dense supervision cost, not the rollout cost. If sampling dominates wall-clock (which it usually does at image scale), the real speedup from "one query instead of $N$" is much smaller than the raw query-count ratio suggests.
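A back-of-the-envelope count makes the gap between "query ratio" and "total cost ratio" concrete. The cost model here is a deliberate simplification and entirely an assumption: every network forward pass costs one unit, each queried state needs one teacher pass plus one student training pass, and backward passes, Flow-OPD's anchor pass, and any teacher/student size difference are ignored.

```python
def forward_passes_per_sample(n_steps, n_queries):
    """Crude cost model (an assumption): every forward pass costs one unit.

    rollout: n_steps student passes
    queries: one teacher pass + one student training pass per queried state
    """
    return n_steps + 2 * n_queries


N = 12                                     # rollout_steps in OPDConfig
dense = forward_passes_per_sample(N, N)    # query every trajectory state
single = forward_passes_per_sample(N, 1)   # query one low-noise state

print(dense, single)                       # 36 14
print(f"query ratio: {N}x, total-cost ratio: {dense / single:.2f}x")  # 12x vs 2.57x
```

Under these assumptions a 12x reduction in teacher queries shrinks to roughly a 2.6x reduction in forward passes, because the 12-step rollout is paid either way. The exact number will shift with real model sizes and with the cost of backpropagation.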
A note on DanceOPD's actual claim. Our toy only demonstrates the cost story above: all three methods reach full coverage, so this ring is too easy to separate them on quality. But DanceOPD's headline claim is a quality claim, not just a cheaper-query one. Their ablations argue the single low-noise query is a cleaner supervision signal, not merely a cheaper one: dense trajectory states are correlated (same noise, same path), so they add little independent signal and can bias the update toward a compromise, while high-noise states carry mostly generic denoising rather than capability-specific information. On their image benchmarks they report the single low-noise query beating dense weighted queries. We cannot surface that gap in 2D -- see the DanceOPD paper for the full ablations.
Finally, recall the router from Section 4.1: it is what keeps these modes sharp and on-center. Without it, averaging the two teachers on shared states pulls the modes inward to a compromise -- the failure in the left panel of the Section 4.1 demo.
5. Combining OPD with reward-based methods: Tilt Matching and RAM
OPD is a teacher-field method: it composes velocity fields it is handed. It never reads a raw reward. So if your only supervision is a black-box scalar reward $r(x)$, you first need a way to turn that reward into a teacher field. Two companion tutorials do exactly that:
- Tilt Matching / ITM learns the reward-tilted distribution $p_{\text{target}}(x)\propto p_{\text{ref}}(x)\exp(\beta\, r(x))$ directly.
- RAM keeps a regression loss and nudges the field toward higher-reward endpoints -- no reward gradients, no SDE rollouts.
Either one turns "reward $A$" into a concrete specialist field $A$ -- exactly the input OPD wants.
| stage | method | needs | produces |
|---|---|---|---|
| build a specialist | Tilt Matching or RAM | a black-box reward + a reference sampler | one reward-specialized field |
| compose specialists | DiffusionOPD / Flow-OPD / DanceOPD | several frozen specialist fields | one student that covers them all |
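To build intuition for the reward-tilted target $p_{\text{target}}(x)\propto p_{\text{ref}}(x)\exp(\beta\, r(x))$ mentioned above: given samples from $p_{\text{ref}}$, the tilt is equivalent to reweighting each sample by a softmax of $\beta\, r$. The sketch below shows only that reweighting, with made-up rewards. It is not the Tilt Matching or RAM training procedure, and `tilt_weights` is a hypothetical helper.

```python
import math


def tilt_weights(rewards, beta):
    """Self-normalized weights w_i proportional to exp(beta * r_i) for samples x_i ~ p_ref."""
    shift = max(beta * r for r in rewards)              # subtract the max for numerical stability
    unnorm = [math.exp(beta * r - shift) for r in rewards]
    total = sum(unnorm)
    return [u / total for u in unnorm]


rewards = [0.0, 1.0, 2.0]  # hypothetical rewards of three reference samples
for beta in (0.0, 1.0, 5.0):
    print(beta, [round(w, 3) for w in tilt_weights(rewards, beta)])
```

At $\beta = 0$ the target is the reference itself; as $\beta$ grows, mass concentrates on the highest-reward samples. A specialist trained toward such a target is the kind of teacher field OPD takes as input.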
5.1 Why not just combine the rewards directly?
The direct alternative skips the specialists and optimizes one model on all rewards at once. The simplest version is a fixed linear blend $r=\sum_j w_j r_j$, which forces you to hand-balance reward scales and collapses per-sample signal. MARBLE does it more carefully: it keeps per-reward gradients and combines them with MGDA into a single minimum-norm common-descent direction, so the balance is set automatically each step instead of tuned in advance.
But as the multi-reward sweep in the RAM tutorial shows, this does not fully remove the interference. MGDA/MARBLE returns one point on the Pareto front. When the rewards genuinely conflict (near-disjoint optima), the combined gradient is a compromise: the model gives up roughly half of each reward's peak to cover both. A single model trained on the merged gradient cannot be excellent at conflicting objectives at once -- the averaged direction pulls it to the middle.
This is the same effect we saw in Section 4.1: averaging two disagreeing teachers pulled the student's modes inward to a compromise, off every true center. Combining conflicting reward gradients is the reward-space version of that averaging.
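For two objectives, the minimum-norm combination that MGDA solves has a closed form, which makes the compromise easy to see. The sketch below covers only the two-gradient case with hand-picked vectors; it is an illustration of the min-norm idea, not MARBLE's implementation, and `min_norm_direction` is a hypothetical helper.

```python
def min_norm_direction(g1, g2):
    """Minimum-norm point on the segment between g1 and g2 (two-objective MGDA).

    Minimizes ||gamma * g1 + (1 - gamma) * g2||^2 over gamma in [0, 1].
    """
    diff_sq = sum((a - b) ** 2 for a, b in zip(g1, g2))
    if diff_sq == 0.0:  # identical gradients: every gamma gives the same direction
        return 0.5, list(g1)
    gamma = sum((b - a) * b for a, b in zip(g1, g2)) / diff_sq
    gamma = min(1.0, max(0.0, gamma))
    return gamma, [gamma * a + (1 - gamma) * b for a, b in zip(g1, g2)]


# Compatible objectives: orthogonal gradients share a useful common direction.
print(min_norm_direction([1.0, 0.0], [0.0, 1.0]))

# Conflicting objectives: opposite gradients cancel, leaving no direction at all.
print(min_norm_direction([1.0, 0.0], [-1.0, 0.0]))
```

In the first case the combined direction improves both objectives at half strength; in the second it vanishes entirely. That is the gradient-space counterpart of the averaged teacher field pointing at a midpoint.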
5.2 Specialize, then compose
The router from Section 4.1 is what broke that compromise, and the same idea applies in reward space. Instead of merging rewards into one gradient, specialize then compose:
- Specialize. Train one full-strength specialist per reward with Tilt Matching or RAM (say, one for text rendering, one for aesthetics). Each reaches its own optimum, with no inter-reward compromise.
- Compose. Distill the specialists into one student with OPD, using a router to send each sample to its responsible specialist -- exactly as in Section 4.
Because the router partitions the work, the single student can sit on every specialist's optimum -- a union of capabilities -- rather than MARBLE's single averaged trade-off point. At inference it is still one model, one forward pass.
The honest caveat. This is not a free lunch. Routing only helps when something can tell which capability a sample needs -- the prompt or domain in a real system, the nearest mode in our toy. For rewards that genuinely conflict on the same input, every method (MARBLE or ITM/RAM + OPD) must still trade off. And "specialize then compose" is two stages, so it costs more than MARBLE's single stage. The win is for the common case where multi-reward really means many capabilities, each for its own slice of inputs -- where one deployed model should cover them all, not blend them into one compromise.
6. Recap
The whole story:
- LLM OPD samples from the student and asks a teacher for dense feedback on the student's own prefixes.
- DiffusionOPD lifts this to diffusion trajectories: match teacher transitions along the student's own denoising path.
- Flow-OPD gives a flow-matching multi-teacher pipeline with stochastic exploration, time-weighted field matching, cold start, and manifold anchoring.
- DanceOPD keeps the on-policy field idea but changes the query design: hard-route the sample and use one low-noise semantic-side query.
- Specialize, then compose (Tilt Matching or RAM, then OPD) is a promising route for multi-reward: build full-strength reward specialists, then let OPD's router compose them into one student -- covering every capability instead of settling for a single averaged compromise.
The practical lesson is:
OPD is not just "distill teachers." It is about where and how densely we ask teachers to supervise the states visited by the current student.