This appendix gives full proofs for Proposition 1 and Corollary 1.

### A.0 Preliminaries and identities

- **Joint features.** For any pair $(x,c)$, define
$$
\psi(x,c):=x\otimes \phi(c)\in\mathbb{R}^{d_x d_c}.
$$
For each indexed training example $a$ (standing in for $(i,j)$), write $\psi_a:=\psi(x_a,c_a)$.

- **Design/labels/weights.** Stack $N=\sum_i m_i$ training rows:
$$
Z\in\mathbb{R}^{N\times d_x d_c}\ \text{ with rows } Z_a=\psi_a^\top,\qquad
y\in\mathbb{R}^{N},\qquad
W=\mathrm{diag}(w_a)\in\mathbb{R}^{N\times N},\ w_a\ge 0.
$$
Define the (unweighted) Gram matrix $K:=ZZ^\top$ and the weighted Gram
$$
K_W:=W^{1/2} K\, W^{1/2} \;=\; W^{1/2} Z Z^\top W^{1/2}.
$$

- **Vectorization identity.** For any $B\in\mathbb{R}^{d_x\times d_c}$, $x\in\mathbb{R}^{d_x}$, and $z\in\mathbb{R}^{d_c}$,
$$
\langle \mathrm{vec}(B),\,x\otimes z\rangle = x^\top B z.
$$

- **Weighted ridge solution.** For any $X\in\mathbb{R}^{N\times p}$, ridge objective
$$
\min_\beta \ \|W^{1/2}(y-X\beta)\|_2^2+\lambda\|\beta\|_2^2
$$
has unique minimizer $\widehat\beta=(X^\top W X+\lambda I)^{-1}X^\top W y$ and equivalent dual form
$$
\widehat\beta = X^\top W^{1/2}\big(W^{1/2}XX^\top W^{1/2}+\lambda I\big)^{-1}W^{1/2}y.
$$
Predictions for a new feature vector $x_\star$ equal
$$
\widehat f(x_\star)=x_\star^\top \widehat\beta
\;=\;
\underbrace{\big(W^{1/2}X x_\star\big)}_{k_\star}^\top
\big(W^{1/2}XX^\top W^{1/2}+\lambda I\big)^{-1}W^{1/2}y.
$$
This is **kernel ridge regression** (KRR) with kernel $K_W=W^{1/2}XX^\top W^{1/2}$ and query vector $k_\star=W^{1/2}X x_\star$.
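
Both identities above are easy to verify numerically. The following is a minimal numpy sketch; the toy sizes, random data, and the row-major `vec` convention are illustrative assumptions rather than part of the setup:

```python
import numpy as np

rng = np.random.default_rng(0)

# (i) Vectorization identity: <vec(B), x ⊗ z> = xᵀ B z
#     (row-major reshape of B pairs with np.kron(x, z)).
d_x, d_c = 4, 3
B = rng.normal(size=(d_x, d_c))
x, z = rng.normal(size=d_x), rng.normal(size=d_c)
assert np.isclose(B.reshape(-1) @ np.kron(x, z), x @ B @ z)

# (ii) Weighted ridge: primal solution vs. dual (kernel ridge regression) form.
N, p, lam = 20, 5, 0.3
X = rng.normal(size=(N, p))
y = rng.normal(size=N)
w = rng.uniform(0.1, 1.0, size=N)
W, Wh = np.diag(w), np.diag(np.sqrt(w))            # W and W^{1/2}

beta = np.linalg.solve(X.T @ W @ X + lam * np.eye(p), X.T @ W @ y)   # primal

x_star = rng.normal(size=p)
K_W = Wh @ X @ X.T @ Wh                             # weighted Gram matrix
k_star = Wh @ X @ x_star                            # query vector
pred_dual = k_star @ np.linalg.solve(K_W + lam * np.eye(N), Wh @ y)  # KRR form

assert np.isclose(x_star @ beta, pred_dual)
```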


---

### A.1 Proof of Proposition 1(A): explicit varying-coefficients ⇔ weighted KRR on joint features

Assume the linear, squared-loss setting with $y=\langle \theta(c),x\rangle+\varepsilon$ and $\mathbb{E}[\varepsilon]=0$.
Let the varying-coefficients model be $\theta(c)=B\,\phi(c)$ with $B\in\mathbb{R}^{d_x\times d_c}$ and ridge penalty $\lambda\|B\|_F^2$.

**Step 1 (reduce to ridge in joint-feature space).**
Vectorize $B$ as $\beta=\mathrm{vec}(B)\in\mathbb{R}^{d_x d_c}$.
By the identity above,
$$
x_a^\top B\,\phi(c_a)
= \langle \beta,\, x_a\otimes \phi(c_a)\rangle
= \langle \beta,\,\psi_a\rangle.
$$
Thus the weighted objective specialized from (★) is
$$
\min_{\beta\in\mathbb{R}^{d_x d_c}}\ \sum_a w_a\big(y_a-\langle \beta,\psi_a\rangle\big)^2+\lambda\|\beta\|_2^2
\;=\;\min_{\beta}\ \big\|W^{1/2}(y-Z\beta)\big\|_2^2+\lambda\|\beta\|_2^2,
$$
which is exactly weighted ridge with design $X\equiv Z$.
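
A short sketch of this reduction; the featurization `phi` and all sizes are illustrative assumptions:

```python
import numpy as np

def build_Z(xs, cs, phi):
    """Stack joint features psi_a = kron(x_a, phi(c_a)) into Z of shape (N, d_x*d_c)."""
    return np.stack([np.kron(x, phi(c)) for x, c in zip(xs, cs)])

rng = np.random.default_rng(1)
phi = lambda c: np.array([1.0, c])       # toy context featurization, d_c = 2
xs = rng.normal(size=(10, 3))            # raw features, d_x = 3
cs = rng.uniform(size=10)
Z = build_Z(xs, cs, phi)                 # design matrix, shape (10, 6)

# x_aᵀ B φ(c_a) equals <vec(B), ψ_a> with the row-major vec convention:
B = rng.normal(size=(3, 2))
assert np.isclose(xs[0] @ B @ phi(cs[0]), B.reshape(-1) @ Z[0])
```

With `Z` in hand, the objective above is ordinary weighted ridge in $\beta=\mathrm{vec}(B)$.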

**Step 2 (closed form and prediction).**
By the ridge solution,
$$
\widehat\beta=(Z^\top W Z+\lambda I)^{-1} Z^\top W y,
$$
and the prediction at a query $(x,c)$ with joint feature $\psi(x,c)$ is
$$
\widehat y(x,c)=\psi(x,c)^\top \widehat\beta
= \underbrace{\big(W^{1/2} Z\, \psi(x,c)\big)}_{k_{(x,c)}}^\top
\big(W^{1/2} Z Z^\top W^{1/2}+\lambda I\big)^{-1} W^{1/2} y.
$$

**Step 3 (kernel form).**
Since $K:=ZZ^\top$ and $K_W:=W^{1/2} K W^{1/2}$,
$$
\boxed{\ \widehat y(x,c)\;=\; k_{(x,c)}^\top \big(K_W+\lambda I\big)^{-1} W^{1/2}y\ }.
$$
Moreover, the $(a,b)$-th entry of the kernel matrix $K$ is
$$
K_{ab}=\langle \psi_a,\psi_b\rangle
=\big\langle x_a\otimes \phi(c_a),\,x_b\otimes \phi(c_b)\big\rangle
=\langle x_a,x_b\rangle\cdot \langle \phi(c_a),\phi(c_b)\rangle,
$$
so (A) is precisely **KRR on joint features** with sample weights $W$.
This proves part (A). ■
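
As a numerical illustration of part (A), the following sketch (arbitrary toy data; `Phi_c` stands for the stacked $\phi(c_a)$) checks that explicit ridge over $B$ and weighted KRR with the product kernel give the same prediction:

```python
import numpy as np

rng = np.random.default_rng(3)
N, d_x, d_c, lam = 15, 3, 2, 0.5
X = rng.normal(size=(N, d_x))          # rows x_a
Phi_c = rng.normal(size=(N, d_c))      # rows φ(c_a)
y = rng.normal(size=N)
w = rng.uniform(0.2, 1.0, size=N)
Wh = np.diag(np.sqrt(w))

# Explicit side: ridge over beta = vec(B), design Z with rows x_a ⊗ φ(c_a).
Z = np.stack([np.kron(X[a], Phi_c[a]) for a in range(N)])
beta = np.linalg.solve(Z.T @ np.diag(w) @ Z + lam * np.eye(d_x * d_c),
                       Z.T @ (w * y))

# Kernel side: K_ab = <x_a, x_b> · <φ(c_a), φ(c_b)>, i.e. K = (XXᵀ) ⊙ (Φ_c Φ_cᵀ).
K = (X @ X.T) * (Phi_c @ Phi_c.T)
alpha = np.linalg.solve(Wh @ K @ Wh + lam * np.eye(N), Wh @ y)

# Both sides agree at a fresh query (x, c).
x_q, phi_q = rng.normal(size=d_x), rng.normal(size=d_c)
pred_explicit = np.kron(x_q, phi_q) @ beta
k_q = Wh @ ((X @ x_q) * (Phi_c @ phi_q))   # k_{(x,c)} = W^{1/2} Z ψ(x,c)
pred_kernel = k_q @ alpha
assert np.isclose(pred_explicit, pred_kernel)
```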


---

### A.2 Proof of Proposition 1(B): linear ICL ⇒ kernel regression

We analyze a single attention layer operating on the weighted support set $S(c)$, using **linear** maps for queries, keys, and values:
$$
q(x,c)=Q\,\psi(x,c),\qquad k_a = K\,\psi_a,\qquad v_a=V\,\psi_a,
$$
with $Q\in\mathbb{R}^{d_q\times d_\psi}$, $K\in\mathbb{R}^{d_k\times d_\psi}$, $V\in\mathbb{R}^{d_v\times d_\psi}$, $d_\psi=d_x d_c$. Let the **unnormalized** attention score for index $a$ be
$$
s_a(x,c):=w_a\,\langle q(x,c),k_a\rangle \;=\; w_a\,\psi(x,c)^\top Q^\top K\,\psi_a .
$$
Define normalized weights $\alpha_a(x,c):=s_a(x,c)/\sum_b s_b(x,c)$ (or any fixed positive normalization; the form below is pointwise in $\{\alpha_a\}$). The context representation and scalar prediction are
$$
z(x,c)=\sum_a \alpha_a(x,c)\, v_a,\qquad \widehat y(x,c)=u^\top z(x,c).
$$
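
For reference, a minimal sketch of this forward pass; the toy shapes, random parameters, and the simple sum normalization are illustrative assumptions:

```python
import numpy as np

def linear_attention_predict(psi_query, Psi, w, Q, K_mat, V, u):
    """ŷ(x,c) = uᵀ Σ_a α_a(x,c) V ψ_a with scores s_a = w_a <Q ψ(x,c), K ψ_a>."""
    scores = w * ((Psi @ K_mat.T) @ (Q @ psi_query))   # requires d_q == d_k
    alpha = scores / scores.sum()                      # simple sum normalization
    z = alpha @ (Psi @ V.T)                            # Σ_a α_a V ψ_a
    return u @ z

rng = np.random.default_rng(2)
N, d_psi, d_qk, d_v = 8, 6, 4, 5
Q, K_mat = rng.normal(size=(d_qk, d_psi)), rng.normal(size=(d_qk, d_psi))
V, u = rng.normal(size=(d_v, d_psi)), rng.normal(size=d_v)
Psi = rng.normal(size=(N, d_psi))                      # support joint features ψ_a
w = rng.uniform(0.1, 1.0, size=N)                      # sample weights
print(linear_attention_predict(rng.normal(size=d_psi), Psi, w, Q, K_mat, V, u))
```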

We prove two statements: **(B1)** exact KRR if the attention maps are fixed and only the readout is trained, and **(B2)** kernel regression with the NTK if the attention parameters are trained in the linearized regime.

#### A.2.1 (B1) Fixed attention, trained linear head ⇒ exact KRR

Assume $Q,K,V$ are fixed functions (pretrained or chosen a priori), hence $\alpha_a(x,c)$ are **deterministic** functions of $(x,c)$ and the support set. Define the induced **feature map**
$$
\varphi(x,c):=\sum_a \alpha_a(x,c)\, v_a \;\in\; \mathbb{R}^{d_v}.
$$
Stack $\varphi_a:=\varphi(x_a,c_a)$ row-wise into $\Phi\in\mathbb{R}^{N\times d_v}$. Training only the readout $u$ with weighted ridge,
$$
\widehat u \in \arg\min_u \ \|W^{1/2}(y-\Phi u)\|_2^2+\lambda \|u\|_2^2
$$
yields $\widehat u=(\Phi^\top W \Phi + \lambda I)^{-1}\Phi^\top W y$ and predictions
$$
\widehat y(x,c)=\varphi(x,c)^\top \widehat u
= \underbrace{\big(W^{1/2}\Phi\,\varphi(x,c)\big)}_{k_{(x,c)}}^\top
\big(W^{1/2}\Phi\Phi^\top W^{1/2}+\lambda I\big)^{-1} W^{1/2} y.
$$
Therefore,
$$
\boxed{\ \widehat y(x,c)=k_{(x,c)}^\top \big(K_W+\lambda I\big)^{-1}W^{1/2}y\ },
\quad K_W:=W^{1/2}\underbrace{(\Phi\Phi^\top)}_{=:K}W^{1/2},
$$
which is exactly **kernel ridge regression** with kernel
$$
k\big((x,c),(x',c')\big)=\langle \varphi(x,c),\varphi(x',c')\rangle.
$$
Because $v_a=V\psi_a$ and $\alpha_a(x,c)\propto w_a\,\psi(x,c)^\top Q^\top K \psi_a$, $\varphi$ is a linear transform of a **weighted average of joint features**; hence the kernel is a dot-product on linear transforms of $\{\psi_a\}$. This proves (B1). ■
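
A numerical check of (B1) under the same toy setup as above (frozen random `Q`, `Km`, `V`; only the readout is trained), showing that ridge on $u$ reproduces the KRR prediction with kernel $\langle\varphi,\varphi'\rangle$:

```python
import numpy as np

rng = np.random.default_rng(5)
N, d_psi, d_qk, d_v, lam = 12, 6, 4, 5, 0.2
Psi = rng.normal(size=(N, d_psi))        # support joint features ψ_a
w = rng.uniform(0.2, 1.0, size=N)
y = rng.normal(size=N)
Q = rng.normal(size=(d_qk, d_psi))       # frozen attention maps
Km = rng.normal(size=(d_qk, d_psi))
V = rng.normal(size=(d_v, d_psi))

def features(psi_query):
    """Induced feature map φ(x,c) = Σ_a α_a(x,c) V ψ_a (attention frozen)."""
    s = w * ((Psi @ Km.T) @ (Q @ psi_query))
    alpha = s / s.sum()
    return alpha @ (Psi @ V.T)

Phi = np.stack([features(Psi[a]) for a in range(N)])   # rows φ_a
Wh = np.diag(np.sqrt(w))

# Train only the readout u by weighted ridge (primal form).
u_hat = np.linalg.solve(Phi.T @ np.diag(w) @ Phi + lam * np.eye(d_v),
                        Phi.T @ (w * y))

# Equivalent KRR with kernel K = Φ Φᵀ and query vector k = W^{1/2} Φ φ(x,c).
phi_q = features(rng.normal(size=d_psi))
K_W = Wh @ Phi @ Phi.T @ Wh
pred_krr = (Wh @ Phi @ phi_q) @ np.linalg.solve(K_W + lam * np.eye(N), Wh @ y)
assert np.isclose(phi_q @ u_hat, pred_krr)
```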


#### A.2.2 (B2) Training attention in the linearized/NTK regime ⇒ kernel regression with NTK

Now let $\theta=(Q,K,V,u)$ be trainable, and suppose training uses squared loss with gradient flow (or sufficiently small steps) starting from initialization $\theta_0$. The **linearized model** around $\theta_0$ is the first-order Taylor expansion
$$
\widehat y_\theta(x,c)\;\approx\;\widehat y_{\theta_0}(x,c)+\nabla_\theta \widehat y_{\theta_0}(x,c)^\top (\theta-\theta_0)
=: \widehat y_{\theta_0}(x,c) + \phi_{\mathrm{NTK}}(x,c)^\top (\theta-\theta_0),
$$
where $\phi_{\mathrm{NTK}}(x,c):=\nabla_\theta \widehat y_{\theta_0}(x,c)$ are the **tangent features**. Standard NTK results (for squared loss, gradient flow, and linearization-validity conditions) imply that the learned function equals **kernel regression with the NTK**:
$$
k_{\mathrm{NTK}}\big((x,c),(x',c')\big)
:= \big\langle \phi_{\mathrm{NTK}}(x,c),\,\phi_{\mathrm{NTK}}(x',c')\big\rangle,
$$
i.e., predictions have the KRR form with kernel $K_{\mathrm{NTK}}$ on the training set (and explicit ridge if used, or implicit regularization via early stopping).

It remains to identify the structure of $\phi_{\mathrm{NTK}}$ for our **linear attention** block and show it lies in the span of **linear transforms of joint features**. Differentiating
$\widehat y(x,c)=u^\top \sum_a \alpha_a(x,c)\, V\psi_a$ at $\theta_0$ yields four groups of terms:

- **Readout path ($u$).** $\partial \widehat y/\partial u = \sum_a \alpha_a(x,c)\, V\psi_a = \varphi_0(x,c)$. This is linear in $\{\psi_a\}$.

- **Value path ($V$).** $\partial \widehat y/\partial V = \sum_a \alpha_a(x,c)\, u\,\psi_a^\top$. This contributes terms of the form $(u\otimes I)\sum_a \alpha_a(x,c)\psi_a$, i.e., linear in $\{\psi_a\}$.

- **Query/key paths ($Q,K$).** For linear attention with scores $s_a=w_a\,\psi(x,c)^\top Q^\top K \psi_a$ and normalized $\alpha_a=s_a/\sum_b s_b$, derivatives of $\alpha_a$ w.r.t. $Q$ and $K$ are linear combinations of $\psi(x,c)$ and $\{\psi_a\}$:
$$
\frac{\partial \alpha_a}{\partial Q}\propto
\sum_b \big[\delta_{ab}-\alpha_b(x,c)\big]\,
w_a w_b \big( K\psi_a\,\psi(x,c)^\top \big),
\qquad
\frac{\partial \alpha_a}{\partial K}\propto
\sum_b \big[\delta_{ab}-\alpha_b(x,c)\big]\,
w_a w_b \big( \psi(x,c)\,\psi_a^\top Q^\top \big),
$$
and hence $\partial \widehat y/\partial Q$, $\partial \widehat y/\partial K$ are finite linear combinations of tensors each bilinear in $\psi(x,c)$ and some $\psi_a$. Contracting with $u$ and $V$ produces terms *linear* in $\psi(x,c)$ and linear in the set $\{\psi_a\}$.

Collecting all components, the tangent feature map can be written as
$$
\phi_{\mathrm{NTK}}(x,c)=\mathcal{L}\big(\psi(x,c),\{\psi_a\}\big),
$$
where $\mathcal{L}$ is a fixed linear operator determined by $\theta_0$, $W$, and the normalization rule for attention. Consequently, the NTK takes the **dot-product** form
$$
k_{\mathrm{NTK}}\big((x,c),(x',c')\big)=
\Psi(x,c)^\top\, \mathcal{M}\, \Psi(x',c'),
$$
for some positive semidefinite matrix $\mathcal{M}$ and a finite-dimensional feature stack $\Psi$ that concatenates linear transforms of $\psi(x,c)$ and of the support-set $\{\psi_a\}$. In particular, $k_{\mathrm{NTK}}$ is a dot-product kernel on **linear transforms of the joint features** (possibly augmented by normalization-dependent combinations). Therefore, training the linear-attention ICL model in the linearized regime equals kernel regression with such a kernel—completing (B2). ■
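
The tangent features can also be formed numerically. The sketch below computes $\phi_{\mathrm{NTK}}=\nabla_\theta \widehat y_{\theta_0}$ for the toy linear-attention model by central finite differences and assembles the empirical NTK Gram matrix; the step size, shapes, and normalization rule are assumptions made for illustration, not part of the proposition:

```python
import numpy as np

rng = np.random.default_rng(6)
N, d_psi, d_qk, d_v = 6, 5, 3, 4
Psi = rng.normal(size=(N, d_psi))           # support joint features ψ_a
w = rng.uniform(0.2, 1.0, size=N)

def predict(theta, psi_query):
    Q, Km, V, u = theta
    s = w * ((Psi @ Km.T) @ (Q @ psi_query))
    alpha = s / s.sum()
    return u @ (alpha @ (Psi @ V.T))

theta0 = [rng.normal(size=(d_qk, d_psi)), rng.normal(size=(d_qk, d_psi)),
          rng.normal(size=(d_v, d_psi)), rng.normal(size=d_v)]

def tangent_features(psi_query, eps=1e-5):
    """φ_NTK(x,c) = ∇_θ ŷ_θ0(x,c), flattened, via central differences."""
    grads = []
    for i, P in enumerate(theta0):
        g = np.zeros_like(P)
        it = np.nditer(P, flags=["multi_index"])
        for _ in it:
            idx = it.multi_index
            theta_p = [p.copy() for p in theta0]; theta_p[i][idx] += eps
            theta_m = [p.copy() for p in theta0]; theta_m[i][idx] -= eps
            g[idx] = (predict(theta_p, psi_query)
                      - predict(theta_m, psi_query)) / (2 * eps)
        grads.append(g.ravel())
    return np.concatenate(grads)

# Empirical NTK over the support set: K_ab = <φ_NTK(a), φ_NTK(b)> is PSD.
T = np.stack([tangent_features(Psi[a]) for a in range(N)])
K_ntk = T @ T.T
assert np.all(np.linalg.eigvalsh(K_ntk) > -1e-8)

# Readout-path block of φ_NTK equals φ_0(x,c) = Σ_a α_a V ψ_a (first bullet above).
psi_q = Psi[0]
g_u = tangent_features(psi_q)[-d_v:]        # last d_v entries = ∂ŷ/∂u
Q0, K0, V0, _ = theta0
s = w * ((Psi @ K0.T) @ (Q0 @ psi_q)); alpha = s / s.sum()
assert np.allclose(g_u, alpha @ (Psi @ V0.T), atol=1e-6)
```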

---

### A.3 Proof of Corollary 1

In both A.1 and A.2, predictions have the KRR form
$$
\widehat y(x,c)=k_{(x,c)}^\top \big(K^\sharp + \lambda I\big)^{-1} \mu,
$$
where $K^\sharp$ is a positive semidefinite kernel matrix computed over the support set (e.g., $K_W=W^{1/2}ZZ^\top W^{1/2}$ in A.1 or $W^{1/2}\Phi\Phi^\top W^{1/2}$ / $K_{\mathrm{NTK}}$ in A.2), $k_{(x,c)}$ is the associated query vector, and $\mu=W^{1/2}y$ (or an equivalent reweighting).
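
All of these instances share one prediction rule, which can be written as a single helper (a sketch only; the caller supplies whichever $K^\sharp$, $k_{(x,c)}$, and $\mu$ the construction at hand defines):

```python
import numpy as np

def krr_predict(K_sharp, k_query, mu, lam):
    """ŷ = k_queryᵀ (K_sharp + λI)^{-1} μ for any PSD kernel matrix over the support set."""
    return k_query @ np.linalg.solve(K_sharp + lam * np.eye(len(mu)), mu)
```

The constructions in A.1 and A.2 differ only in how `K_sharp`, `k_query`, and `mu` are assembled; the operations below act on those inputs.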

- **Retrieval $R(c)$ / gating.** Changing the support set $S(c)$ (e.g., via a retriever or a gating policy) **removes or adds rows/columns** in $K^\sharp$ and entries in $k_{(x,c)}$. This is equivalent to changing the **empirical measure** over which the kernel smoother is computed (i.e., which samples contribute and how).
