This post will be a high-level overview of the main ideas behind the paper, Generative Modeling via Drifting.

Setup

Generative modeling can be formulated as learning a function $f$ that maps a known prior distribution $p_{\text{prior}}$ to a target distribution $p_{\text{data}}$ (usually the data distribution itself).

Let's assume that the prior distribution is a standard Gaussian (standard in generative modeling) and the target distribution is the data distribution. Since training a neural network via gradient descent is inherently an iterative process, it's reasonable to assume that during training we have access to a function $f_i$ that maps $p_{\epsilon}$ to $q_i$, our approximation to $p_{\text{data}}$ at iteration $i$. So, at each iteration, there is a "drift" between a sample at the $i$-th iteration and the corresponding sample at the $(i+1)$-th iteration. We can write:

$$x_{i+1} = x_i + \Delta_i$$

where $\Delta_i = f_{i+1}(\epsilon) - f_i(\epsilon)$.

If $\Delta_i = 0$, we have reached an equilibrium point, and there is no point in continuing training.

Let's make this more formal. The authors introduce the notion of a drifting field $V_{p, q}$ as a way to model $\Delta_i$, so

$$x_{i+1} = x_i + V_{p, q}(x_i)$$

where $x_i = f_i(\epsilon)$ with $\epsilon \sim p_{\epsilon}$, so that $x_i \sim q_i$. When $p = q$, we want the drifting field $V_{p, q}$ to be $0$. This notion of equilibrium motivates an update rule:

$$f_{\theta_{i+1}}(\epsilon) = f_{\theta_i}(\epsilon) + V_{p, q}(f_{\theta_i}(\epsilon))$$

Here $\theta_i$ are the parameters of the model at iteration $i$. The loss function is the mean-squared error between $x_{i+1}$ and $x_i$:

$$\mathcal{L} = \mathbb{E}_{\epsilon}\left[ \left\lVert f_{\theta}(\epsilon) - \text{stopgrad}\left(f_{\theta}(\epsilon) + V_{p, q}(f_{\theta}(\epsilon))\right) \right\rVert^2 \right]$$

We use a stopgrad since we want to freeze the target and move our current predicted samples towards it.
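As a concrete sketch of this objective in JAX (hypothetical names: `f_apply` stands in for the generator and `drift_field` for $V_{p,q}$, neither is from the paper's code):

```python
import jax
import jax.numpy as jnp


def drifting_loss(params, eps, f_apply, drift_field):
    """Mean-squared error between f_theta(eps) and a frozen drifted target.

    f_apply(params, eps) is the generator; drift_field(x) evaluates V_{p,q}(x).
    Both are placeholders for whatever model and field are in use.
    """
    x = f_apply(params, eps)
    # stop_gradient freezes the target x + V(x); gradients flow only
    # through the first occurrence of x.
    target = jax.lax.stop_gradient(x + drift_field(x))
    return jnp.mean(jnp.sum((x - target) ** 2, axis=-1))
```

Because the target is frozen, the gradient of this loss pushes $f_{\theta}(\epsilon)$ along the drift field rather than pulling the target back.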

Designing the Drifting Field

Now, how should we design $V$? Recall that we want $V_{p, q} = 0$ whenever $p = q$. In the paper, they give a sufficient (but not necessary) condition: make $V$ anti-symmetric. To see this, note that $p = q \implies V_{p, q} = -V_{q, p} = -V_{p, q} = 0$.

The paper considers drifting fields of the form:

$$V_{p, q}(x) = \mathbb{E}_{y^+ \sim p}\,\mathbb{E}_{y^- \sim q}\left[\mathcal{K}(x, y^+, y^-)\right]$$

where $\mathcal{K}$ is a function describing interactions between points sampled from the target distribution and points sampled from our current estimate of it. $\mathcal{K}$ just needs to make this expectation $0$ when $p = q$.

The authors define fields

$$V_p^+(x) = \frac{1}{Z_p}\mathbb{E}_p[k(x, y^+)(y^+ - x)]$$

$$V_q^-(x) = \frac{1}{Z_q}\mathbb{E}_q[k(x, y^-)(y^- - x)]$$

and decompose $V_{p, q}(x)$ into the difference $V_p^+(x) - V_q^-(x)$. Here $Z_p$ and $Z_q$ are the normalization constants $\mathbb{E}_p[k(x, y^+)]$ and $\mathbb{E}_q[k(x, y^-)]$, respectively. We can rewrite $V_{p, q}(x)$ as follows:

$$
\begin{aligned}
V_{p, q}(x) &= V_p^+(x) - V_q^-(x) \\
&= \frac{1}{Z_p}\mathbb{E}_p[k(x, y^+)(y^+ - x)] - \frac{1}{Z_q}\mathbb{E}_q[k(x, y^-)(y^- - x)] \\
&= \frac{1}{Z_p Z_q}\left[ Z_q\,\mathbb{E}_p[k(x, y^+)(y^+ - x)] - Z_p\,\mathbb{E}_q[k(x, y^-)(y^- - x)] \right] \\
&= \frac{1}{Z_p Z_q}\left[ \mathbb{E}_q[k(x, y^-)]\,\mathbb{E}_p[k(x, y^+)(y^+ - x)] - \mathbb{E}_p[k(x, y^+)]\,\mathbb{E}_q[k(x, y^-)(y^- - x)] \right] \\
&= \frac{1}{Z_p Z_q}\left[ \mathbb{E}_{p, q}[k(x, y^+)k(x, y^-)(y^+ - x)] - \mathbb{E}_{p, q}[k(x, y^+)k(x, y^-)(y^- - x)] \right] \\
&= \frac{1}{Z_p Z_q}\,\mathbb{E}_{p, q}[k(x, y^+)k(x, y^-)(y^+ - y^-)]
\end{aligned}
$$

The reason we can combine the terms $\mathbb{E}_p[f(y^+)]$ and $\mathbb{E}_q[g(y^-)]$ into a single expression $\mathbb{E}_{p, q}[f(y^+)g(y^-)]$ is that $y^+ \sim p$ and $y^- \sim q$ are drawn independently, so the product of expectations equals an expectation over the product measure $p \times q$.
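A quick numerical sanity check of this step, using two small discrete distributions with uniform weights (the arrays and functions below are purely illustrative):

```python
import jax.numpy as jnp

# supports of two discrete distributions p and q (uniform weights)
yp = jnp.array([0.0, 1.0, 2.0])
yn = jnp.array([0.5, 1.5])

f = lambda y: y ** 2
g = lambda y: jnp.exp(-y)

# product of marginal expectations E_p[f] * E_q[g]
lhs = jnp.mean(f(yp)) * jnp.mean(g(yn))
# single expectation over the product measure p x q (average over all pairs)
rhs = jnp.mean(f(yp)[:, None] * g(yn)[None, :])
# lhs == rhs because y+ and y- are drawn independently
```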

From this form, it's immediately clear that $V_{p, q}$ is anti-symmetric (swapping $p$ and $q$ swaps the roles of $y^+$ and $y^-$ and flips the sign), so the training objective discussed earlier is well-defined. The kernel $k$ they used was

$$k(x, y) = \exp\left(-\frac{\lVert x - y \rVert}{\tau}\right)$$

Implementation Details

We give a concrete implementation of the drifting field above:

```python
import jax.numpy as jnp


def cdist(x_nd, y_md):
    """Pairwise Euclidean distances, [N, M]."""
    return jnp.linalg.norm(x_nd[:, None, :] - y_md[None, :, :], axis=-1)


def compute_drift_field(x_bd, ypos_bd, yneg_bd, temp: float = 0.05, eps: float = 1e-12):
    """
    Computes the drift field V_pq(f(eps)).

    :param x_bd: query points, [N, D]
    :param ypos_bd: samples from p, the data distribution. [N_pos, D]
    :param yneg_bd: samples from q, the current model distribution. [N_neg, D]
    :param temp: kernel temperature tau
    :param eps: numerical floor for the normalizer

    Note that the batch dimensions of the data and generated predictions
    do not have to be the same, but x_bd is assumed to be the same set of
    points as yneg_bd (so self-interactions are masked out below).

    Returns a [N, D] matrix representing the drift field.
    """
    targets = jnp.concatenate([yneg_bd, ypos_bd], axis=0)
    N_neg = x_bd.shape[0]

    dist = cdist(x_bd, targets)
    # since x_bd is the same as yneg_bd, mask self-distances with a large value
    dist = dist.at[:, :N_neg].add(jnp.eye(N_neg) * 1e6)
    kernel = jnp.exp(-dist / temp)

    # normalize over both the target dimension and the batch dimension
    normalizer = jnp.sum(kernel, axis=-1, keepdims=True) * jnp.sum(kernel, axis=-2, keepdims=True)
    normalizer = jnp.sqrt(jnp.clip(normalizer, eps))
    normalized_kernel = kernel / normalizer

    K_neg, K_pos = jnp.split(normalized_kernel, [N_neg], axis=1)

    # V_p^+ minus V_q^-; the (y - x) terms cancel in the difference
    pos_coeff = K_pos * jnp.sum(K_neg, axis=-1, keepdims=True)
    V_pos = pos_coeff @ ypos_bd
    neg_coeff = K_neg * jnp.sum(K_pos, axis=-1, keepdims=True)
    V_neg = neg_coeff @ yneg_bd

    return V_pos - V_neg
```

To see how this can be written in the form $\frac{1}{Z_p Z_q}\mathbb{E}_{p, q}[k(x, y^+)k(x, y^-)(y^+ - y^-)]$, observe that $V_p^+(x) = \sum_{i}\alpha_i^+\left(\sum_j\alpha_j^-\right)y_i^+$ and $V_q^-(x) = \sum_{j}\alpha_j^-\left(\sum_i\alpha_i^+\right)y_j^-$.

Combining the two, we get that

$$V_{p, q}(x) = \sum_{i, j}\alpha_i^+\alpha_j^-\,(y_i^+ - y_j^-)$$

which is exactly the form that we wanted. Here it's clear that $\alpha_i^+$ and $\alpha_j^-$ represent $k(x, y_i^+)$ and $k(x, y_j^-)$ respectively, up to normalization.
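We can verify numerically that the factored computation matches the pairwise double sum, for a single query point and small made-up point sets:

```python
import jax.numpy as jnp

x = jnp.array([0.3, -0.2])
y_pos = jnp.array([[1.0, 0.5], [0.2, -1.0], [0.7, 0.3]])  # samples from p
y_neg = jnp.array([[-0.5, 0.1], [0.4, 0.9]])              # samples from q

k = lambda a, b: jnp.exp(-jnp.linalg.norm(a - b, axis=-1))  # tau = 1 for simplicity
a_pos = k(x, y_pos)  # alpha_i^+, shape [3]
a_neg = k(x, y_neg)  # alpha_j^-, shape [2]

# pairwise form: sum_{i,j} alpha_i^+ alpha_j^- (y_i^+ - y_j^-)
pairwise = jnp.einsum('i,j,ijd->d', a_pos, a_neg,
                      y_pos[:, None, :] - y_neg[None, :, :])
# factored form, as used in the implementation above
factored = a_neg.sum() * (a_pos @ y_pos) - a_pos.sum() * (a_neg @ y_neg)
```

The two agree exactly, which is why the implementation never materializes the full $N_{\text{pos}} \times N_{\text{neg}}$ pairwise difference tensor.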

What about the normalization terms $Z_p$ and $Z_q$? When we compute the kernel, we bake in the normalization by summing across the concatenated target dimension $y = [y^-; y^+]$ and dividing by that sum. We also normalize over the batch dimension, as it improved training dynamics.
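Putting the pieces together, here is a minimal sketch of a full training step on 2D toy data. This is illustrative, not the paper's exact recipe: `simple_drift_field` drops the batch normalization, uses `temp=1.0`, and reuses the generated batch as negatives, and the linear `generator` stands in for a real neural network:

```python
import jax
import jax.numpy as jnp


def simple_drift_field(x_nd, ypos_md, temp=1.0, eps=1e-12):
    """Stripped-down drift field: per-point normalization only."""
    def kern(a, b):
        return jnp.exp(-jnp.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1) / temp)

    k_pos = kern(x_nd, ypos_md)                                 # [N, M]
    k_neg = kern(x_nd, x_nd) * (1.0 - jnp.eye(x_nd.shape[0]))   # mask self
    zp = jnp.clip(k_pos.sum(-1, keepdims=True), eps)
    zq = jnp.clip(k_neg.sum(-1, keepdims=True), eps)
    # V_p^+ - V_q^-; the (... - x) terms cancel in the difference
    return (k_pos @ ypos_md) / zp - (k_neg @ x_nd) / zq


def generator(theta, eps_nd):
    # toy linear generator with a bias; purely illustrative
    W, b = theta
    return eps_nd @ W + b


def train_step(theta, eps_nd, ypos_md, lr=0.05):
    def loss_fn(th):
        x = generator(th, eps_nd)
        target = jax.lax.stop_gradient(x + simple_drift_field(x, ypos_md))
        return jnp.mean(jnp.sum((x - target) ** 2, axis=-1))

    loss, grads = jax.value_and_grad(loss_fn)(theta)
    theta = jax.tree_util.tree_map(lambda p, g: p - lr * g, theta, grads)
    return theta, loss
```

Iterating `train_step` drives the generated cloud toward the data: the attraction term pulls samples toward nearby data points while the repulsion term spreads them apart.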

Experiments on Toy Distributions

A JAX implementation of the above can be found here:

I was able to train the model on the toy distributions (chessboard, spiral). After 2.5k iterations, the drifting model produced some pretty neat results.

Training on Chessboard Data
Training on Spiral Data

Extending this to Actual Large-Scale Image Datasets

Much like other work such as Diffusion Transformers, this method can operate on encoded features instead of raw pixels. Let $\phi$ be a feature encoder. We can then define a new loss:

$$\mathcal{L}_{\phi} = \mathbb{E}_{\epsilon}\left[ \left\lVert \phi(x) - \text{stopgrad}\left(\phi(x) + V_{p, q}(\phi(x))\right) \right\rVert^2 \right]$$

where $x = f_{\theta}(\epsilon)$. The update rule is changed similarly.
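A minimal sketch of this feature-space variant, assuming a frozen encoder (here `phi`, `f_apply`, and `drift_field` are all hypothetical placeholders, not the paper's components):

```python
import jax
import jax.numpy as jnp


def feature_drifting_loss(params, eps, f_apply, phi, drift_field):
    """Drifting loss computed in the encoder's feature space.

    phi is assumed frozen; f_apply is the generator and drift_field
    evaluates V_{p,q} on encoded features.
    """
    feats = phi(f_apply(params, eps))
    # freeze the drifted target in feature space
    target = jax.lax.stop_gradient(feats + drift_field(feats))
    return jnp.mean(jnp.sum((feats - target) ** 2, axis=-1))
```

Note that the positive samples $y^+$ fed to the drift field must also be encoded with the same $\phi$ so that both distributions live in the same feature space.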