diff --git a/content/90.back-matter.md b/content/90.back-matter.md
index 65b0339..2ed7263 100644
--- a/content/90.back-matter.md
+++ b/content/90.back-matter.md
@@ -12,19 +12,19 @@ This appendix gives full proofs for Proposition 1 and Corollary 1. We keep the w
### A.0 Preliminaries and identities

-- **Joint features.** For any pair $(x,c)$, define
-$$
+- **Joint features.** For any pair $(x,c)$, define
+ $$
\psi(x,c):=x\otimes \phi(c)\in\mathbb{R}^{d_x d_c}.
$$
For each indexed training example $a$ (standing in for $(i,j)$), write $\psi_a:=\psi(x_a,c_a)$.

- **Design/labels/weights.** Stack $N=\sum_i m_i$ training rows:
$$
 Z\in\mathbb{R}^{N\times d_x d_c}\ \text{ with rows } Z_a=\psi_a^\top,\qquad
y\in\mathbb{R}^{N},\qquad W=\mathrm{diag}(w_a)\in\mathbb{R}^{N\times N},\ w_a\ge 0.
$$
- Define the (unweighted) Gram matrix \(K:=ZZ^\top\) and the weighted Gram
+ Define the (unweighted) Gram matrix $K:=ZZ^\top$ and the weighted Gram
$$
K_W:=W^{1/2} K\, W^{1/2} \;=\; W^{1/2} Z Z^\top W^{1/2}.
$$
@@ -36,15 +36,15 @@
$$
\langle \mathrm{vec}(B),\,x\otimes z\rangle = x^\top B z.
$$
-- **Weighted ridge solution.** For any \(X\in\mathbb{R}^{N\times p}\), ridge objective
+- **Weighted ridge solution.** For any $X\in\mathbb{R}^{N\times p}$, ridge objective
$$
\min_\beta \ \|W^{1/2}(y-X\beta)\|_2^2+\lambda\|\beta\|_2^2
$$
- has unique minimizer \(\widehat\beta=(X^\top W X+\lambda I)^{-1}X^\top W y\) and equivalent dual form
+ has unique minimizer $\widehat\beta=(X^\top W X+\lambda I)^{-1}X^\top W y$ and equivalent dual form
$$
\widehat\beta = X^\top W^{1/2}\big(W^{1/2}XX^\top W^{1/2}+\lambda I\big)^{-1}W^{1/2}y.
$$
- Predictions for a new feature vector \(x_\star\) equal
+ Predictions for a new feature vector $x_\star$ equal
$$
\widehat f(x_\star)=x_\star^\top \widehat\beta
\;=\;
@@ -54,15 +54,21 @@
$$
$$
This is **kernel ridge regression** (KRR) with kernel $K_W=W^{1/2}XX^\top W^{1/2}$ and query vector $k_\star=W^{1/2}X x_\star$.
+
---

### A.1 Proof of Proposition 1(A): explicit varying-coefficients ⇔ weighted KRR on joint features

-Assume the linear, squared-loss setting with $y=\langle \theta(c),x\rangle+\varepsilon$ and $\mathbb{E}[\varepsilon]=0\). Let the varying-coefficients model be \(\theta(c)=B\,\phi(c)\) with \(B\in\mathbb{R}^{d_x\times d_c}$ and ridge penalty $\lambda\|B\|_F^2$.
+Assume the linear, squared-loss setting with $y=\langle \theta(c),x\rangle+\varepsilon$ and $\mathbb{E}[\varepsilon]=0$.
+Let the varying-coefficients model be $\theta(c)=B\,\phi(c)$ with $B\in\mathbb{R}^{d_x\times d_c}$ and ridge penalty $\lambda\|B\|_F^2$.

-**Step 1 (reduce to ridge in joint-feature space).** Vectorize \(B\) as \(\beta=\mathrm{vec}(B)\in\mathbb{R}^{d_x d_c}\). By the identity above,
+**Step 1 (reduce to ridge in joint-feature space).**
+Vectorize $B$ as $\beta=\mathrm{vec}(B)\in\mathbb{R}^{d_x d_c}$.
+By the identity above,
$$
-x_a^\top B\,\phi(c_a)=\langle \beta,\, x_a\otimes \phi(c_a)\rangle = \langle \beta,\,\psi_a\rangle .
+x_a^\top B\,\phi(c_a)
+= \langle \beta,\, x_a\otimes \phi(c_a)\rangle
+= \langle \beta,\,\psi_a\rangle.
$$
Thus the weighted objective specialized from (★) is
$$
@@ -71,116 +77,121 @@
$$
which is exactly weighted ridge with design $X\equiv Z$.
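+
+As a quick numerical sanity check of this reduction (an added illustration, not part of the proof; it uses `numpy`, a row-major `vec`, and small arbitrary dimensions), the sketch below verifies the Kronecker identity $x^\top B\,\phi(c)=\langle \mathrm{vec}(B),\,x\otimes\phi(c)\rangle$, the agreement of the primal weighted-ridge solution with its kernel (dual) form, and the factorization $K_{ab}=\langle x_a,x_b\rangle\,\langle\phi(c_a),\phi(c_b)\rangle$ used in Step 3 below:
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+d_x, d_c, N, lam = 3, 4, 20, 0.1            # illustrative sizes, not from the text
+
+# Joint feature psi(x, c) = x (kron) phi(c); with row-major vec,
+# <vec(B), x kron phi(c)> equals x^T B phi(c).
+x, phi_c = rng.normal(size=d_x), rng.normal(size=d_c)
+B = rng.normal(size=(d_x, d_c))
+assert np.allclose(x @ B @ phi_c, B.ravel() @ np.kron(x, phi_c))
+
+# Weighted ridge on the joint design Z (rows psi_a^T) versus its kernel form.
+Z = rng.normal(size=(N, d_x * d_c))
+y = rng.normal(size=N)
+w = rng.uniform(0.5, 2.0, size=N)           # sample weights w_a >= 0
+W = np.diag(w)
+Wh = np.diag(np.sqrt(w))                    # W^{1/2}
+
+# Primal: beta_hat = (Z^T W Z + lam I)^{-1} Z^T W y.
+beta = np.linalg.solve(Z.T @ W @ Z + lam * np.eye(d_x * d_c), Z.T @ W @ y)
+
+# Dual / KRR: y_hat = k^T (K_W + lam I)^{-1} W^{1/2} y with K_W = W^{1/2} Z Z^T W^{1/2}.
+K_W = Wh @ Z @ Z.T @ Wh
+psi_query = rng.normal(size=d_x * d_c)      # psi(x, c) of an arbitrary query
+k_query = Wh @ Z @ psi_query
+pred_primal = psi_query @ beta
+pred_kernel = k_query @ np.linalg.solve(K_W + lam * np.eye(N), Wh @ y)
+assert np.allclose(pred_primal, pred_kernel)
+
+# Joint kernel factorizes: <psi_a, psi_b> = <x_a, x_b> * <phi(c_a), phi(c_b)>.
+x_a, x_b = rng.normal(size=d_x), rng.normal(size=d_x)
+phi_a, phi_b = rng.normal(size=d_c), rng.normal(size=d_c)
+assert np.allclose(np.kron(x_a, phi_a) @ np.kron(x_b, phi_b), (x_a @ x_b) * (phi_a @ phi_b))
+print("A.1 identities verified numerically.")
+```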
-**Step 2 (closed form and prediction).** By the ridge solution,
+**Step 2 (closed form and prediction).**
+By the ridge solution,
$$
\widehat\beta=(Z^\top W Z+\lambda I)^{-1} Z^\top W y,
$$
and the prediction at a query $(x,c)$ with joint feature $\psi(x,c)$ is
$$
\widehat y(x,c)=\psi(x,c)^\top \widehat\beta
= \underbrace{\big(W^{1/2} Z\, \psi(x,c)\big)}_{k_{(x,c)}}^\top
\big(W^{1/2} Z Z^\top W^{1/2}+\lambda I\big)^{-1} W^{1/2} y.
$$

-**Step 3 (kernel form).** Since $K:=ZZ^\top$ and $K_W:=W^{1/2} K W^{1/2}$,
+**Step 3 (kernel form).**
+Since $K:=ZZ^\top$ and $K_W:=W^{1/2} K W^{1/2}$,
$$
\boxed{\ \widehat y(x,c)\;=\; k_{(x,c)}^\top \big(K_W+\lambda I\big)^{-1} W^{1/2}y\ }.
$$
-Moreover,
+Moreover, the $(a,b)$-th entry of the kernel matrix $K$ is
$$
K_{ab}=\langle \psi_a,\psi_b\rangle
=\big\langle x_a\otimes \phi(c_a),\,x_b\otimes \phi(c_b)\big\rangle
=\langle x_a,x_b\rangle\cdot \langle \phi(c_a),\phi(c_b)\rangle,
$$
-so (A) is precisely **KRR on joint features** with sample weights $W$. This proves part (A). \■
+so (A) is precisely **KRR on joint features** with sample weights $W$.
+This proves part (A). ■
+
---

### A.2 Proof of Proposition 1(B): linear ICL ⇒ kernel regression

-We analyze a single attention layer operating on the weighted support set \(S(c)\), using **linear** maps for queries, keys, and values:
+We analyze a single attention layer operating on the weighted support set $S(c)$, using **linear** maps for queries, keys, and values:
$$
q(x,c)=Q\,\psi(x,c),\qquad k_a = K\,\psi_a,\qquad v_a=V\,\psi_a,
$$
-with $Q\in\mathbb{R}^{d_q\times d_\psi}$, $K\in\mathbb{R}^{d_k\times d_\psi}$, $V\in\mathbb{R}^{d_v\times d_\psi}$, $d_\psi=d_x d_c$. Let the **unnormalized** attention score for index $a$ be
+with $Q\in\mathbb{R}^{d_q\times d_\psi}$, $K\in\mathbb{R}^{d_k\times d_\psi}$, $V\in\mathbb{R}^{d_v\times d_\psi}$, $d_\psi=d_x d_c$ (this key matrix $K$ is not to be confused with the Gram matrix $K=ZZ^\top$ of A.0). Let the **unnormalized** attention score for index $a$ be
$$
s_a(x,c):=w_a\,\langle q(x,c),k_a\rangle \;=\; w_a\,\psi(x,c)^\top Q^\top K\,\psi_a .
$$
Define normalized weights $\alpha_a(x,c):=s_a(x,c)/\sum_b s_b(x,c)$ (or any fixed positive normalization; the form below is pointwise in $\{\alpha_a\}$). The context representation and scalar prediction are
$$
z(x,c)=\sum_a \alpha_a(x,c)\, v_a,\qquad \widehat y(x,c)=u^\top z(x,c).
$$
We prove two statements: **(B1)** exact KRR if the attention maps are fixed and only the readout is trained, and **(B2)** kernel regression with the NTK if the attention parameters are trained in the linearized regime.

#### A.2.1 (B1) Fixed attention, trained linear head ⇒ exact KRR

-Assume \(Q,K,V\) are fixed functions (pretrained or chosen a priori), hence \(\alpha_a(x,c)\) are **deterministic** functions of \((x,c)\) and the support set. Define the induced **feature map**
+Assume $Q,K,V$ are fixed functions (pretrained or chosen a priori), hence $\alpha_a(x,c)$ are **deterministic** functions of $(x,c)$ and the support set. Define the induced **feature map**
$$
\varphi(x,c):=\sum_a \alpha_a(x,c)\, v_a \;\in\; \mathbb{R}^{d_v}.
$$
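+
+To make this feature map concrete, here is a small sketch (an added illustration under the definitions above, not part of the proof). The helper `attention_features` and the name `K_att` are introduced here; the key map is called `K_att` to avoid a clash with the Gram matrix $K$, and the positive random entries merely keep the linear normalization well-behaved in a toy example. It computes $\varphi(x,c)$ and the Gram matrix $\Phi\Phi^\top$ used just below:
+
+```python
+import numpy as np
+
+def attention_features(psi_query, Psi, w, Q, K_att, V):
+    """varphi(x, c) = sum_a alpha_a(x, c) * V psi_a for fixed linear attention maps."""
+    scores = w * (Psi @ K_att.T @ Q @ psi_query)   # s_a = w_a <Q psi(x,c), K_att psi_a>
+    alpha = scores / scores.sum()                  # linear normalization, as in the text
+    return alpha @ (Psi @ V.T)                     # sum_a alpha_a * (V psi_a), shape (d_v,)
+
+rng = np.random.default_rng(1)
+d_psi, d_k, d_v, N = 12, 5, 6, 8                   # illustrative sizes
+Q = rng.uniform(0.1, 1.0, size=(d_k, d_psi))       # positive entries => positive scores
+K_att = rng.uniform(0.1, 1.0, size=(d_k, d_psi))
+V = rng.normal(size=(d_v, d_psi))
+Psi = rng.uniform(0.1, 1.0, size=(N, d_psi))       # support-set joint features psi_a (rows)
+w = rng.uniform(0.5, 2.0, size=N)                  # sample weights w_a
+
+varphi_query = attention_features(rng.uniform(0.1, 1.0, size=d_psi), Psi, w, Q, K_att, V)
+# Stacking varphi(x_a, c_a) row-wise gives Phi; the induced (B1) kernel is
+# k((x,c),(x',c')) = <varphi(x,c), varphi(x',c')>, i.e. the Gram matrix Phi @ Phi.T.
+Phi = np.stack([attention_features(Psi[a], Psi, w, Q, K_att, V) for a in range(N)])
+print(varphi_query.shape, (Phi @ Phi.T).shape)
+```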
-Stack \(\varphi_a:=\varphi(x_a,c_a)\) row-wise into \(\Phi\in\mathbb{R}^{N\times d_v}\). Training only the readout \(u\) with weighted ridge,
+Stack $\varphi_a:=\varphi(x_a,c_a)$ row-wise into $\Phi\in\mathbb{R}^{N\times d_v}$.
+Training only the readout $u$ with weighted ridge,
$$
\widehat u \in \arg\min_u \ \|W^{1/2}(y-\Phi u)\|_2^2+\lambda \|u\|_2^2
$$
-yields \(\widehat u=(\Phi^\top W \Phi + \lambda I)^{-1}\Phi^\top W y\) and predictions
+yields $\widehat u=(\Phi^\top W \Phi + \lambda I)^{-1}\Phi^\top W y$ and predictions
$$
\widehat y(x,c)=\varphi(x,c)^\top \widehat u
= \underbrace{\big(W^{1/2}\Phi\,\varphi(x,c)\big)}_{k_{(x,c)}}^\top
\big(W^{1/2}\Phi\Phi^\top W^{1/2}+\lambda I\big)^{-1} W^{1/2} y.
$$
Therefore,
$$
\boxed{\ \widehat y(x,c)=k_{(x,c)}^\top \big(K_W+\lambda I\big)^{-1}W^{1/2}y\ },
\quad K_W:=W^{1/2}\underbrace{(\Phi\Phi^\top)}_{=:K}W^{1/2},
$$
which is exactly **kernel ridge regression** with kernel
$$
k\big((x,c),(x',c')\big)=\langle \varphi(x,c),\varphi(x',c')\rangle.
$$
-Because \(v_a=V\psi_a\) and \(\alpha_a(x,c)\propto w_a\,\psi(x,c)^\top Q^\top K \psi_a\), \(\varphi\) is a linear transform of a **weighted average of joint features**; hence the kernel is a dot-product on linear transforms of \(\{\psi_a\}\). This proves (B1). ■
+Because $v_a=V\psi_a$ and $\alpha_a(x,c)\propto w_a\,\psi(x,c)^\top Q^\top K \psi_a$, $\varphi$ is a linear transform of a **weighted average of joint features**; hence the kernel is a dot-product on linear transforms of $\{\psi_a\}$. This proves (B1). ■
+

#### A.2.2 (B2) Training attention in the linearized/NTK regime ⇒ kernel regression with NTK

-Now let \(\theta=(Q,K,V,u)\) be trainable, and suppose training uses squared loss with gradient flow (or sufficiently small steps) starting from initialization \(\theta_0\). The **linearized model** around \(\theta_0\) is the first-order Taylor expansion
+Now let $\theta=(Q,K,V,u)$ be trainable, and suppose training uses squared loss with gradient flow (or sufficiently small steps) starting from initialization $\theta_0$. The **linearized model** around $\theta_0$ is the first-order Taylor expansion
$$
\widehat y_\theta(x,c)\;\approx\;\widehat y_{\theta_0}(x,c)+\nabla_\theta \widehat y_{\theta_0}(x,c)^\top (\theta-\theta_0)
=: \widehat y_{\theta_0}(x,c) + \phi_{\mathrm{NTK}}(x,c)^\top (\theta-\theta_0),
$$
-where \(\phi_{\mathrm{NTK}}(x,c):=\nabla_\theta \widehat y_{\theta_0}(x,c)\) are the **tangent features**. Standard NTK results (for squared loss, gradient flow, and linearization-validity conditions) imply that the learned function equals **kernel regression with the NTK**:
+where $\phi_{\mathrm{NTK}}(x,c):=\nabla_\theta \widehat y_{\theta_0}(x,c)$ are the **tangent features**. Standard NTK results (for squared loss, gradient flow, and linearization-validity conditions) imply that the learned function equals **kernel regression with the NTK**:
$$
k_{\mathrm{NTK}}\big((x,c),(x',c')\big)
:=
\big\langle \phi_{\mathrm{NTK}}(x,c),\,\phi_{\mathrm{NTK}}(x',c')\big\rangle,
$$
-i.e., predictions have the KRR form with kernel \(K_{\mathrm{NTK}}\) on the training set (and explicit ridge if used, or implicit regularization via early stopping).
+i.e., predictions have the KRR form with kernel $K_{\mathrm{NTK}}$ on the training set (and explicit ridge if used, or implicit regularization via early stopping).
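+
+To make the tangent features concrete, the sketch below (an added illustration; `predict`, `unpack`, and `tangent_features` are names introduced here, and the gradient is taken by central finite differences rather than autodiff) evaluates $\phi_{\mathrm{NTK}}=\nabla_\theta\widehat y_{\theta_0}$ for the linear-attention model above and assembles the empirical NTK matrix over the support set:
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(2)
+d_psi, d_k, d_v, N = 6, 3, 4, 5                  # small illustrative sizes
+Psi = rng.uniform(0.1, 1.0, size=(N, d_psi))     # support-set joint features psi_a
+w = rng.uniform(0.5, 2.0, size=N)                # sample weights w_a
+shapes = [(d_k, d_psi), (d_k, d_psi), (d_v, d_psi), (d_v,)]   # theta = (Q, K_att, V, u)
+n_params = sum(int(np.prod(s)) for s in shapes)
+
+def unpack(theta):
+    """Split the flat parameter vector theta into Q, K_att, V, u."""
+    out, i = [], 0
+    for s in shapes:
+        n = int(np.prod(s))
+        out.append(theta[i:i + n].reshape(s))
+        i += n
+    return out
+
+def predict(theta, psi_query):
+    """Scalar prediction u^T sum_a alpha_a(x,c) V psi_a of the linear-attention block."""
+    Q, K_att, V, u = unpack(theta)
+    s = w * (Psi @ K_att.T @ Q @ psi_query)      # s_a = w_a <Q psi(x,c), K_att psi_a>
+    alpha = s / s.sum()                          # linear normalization
+    return u @ (alpha @ (Psi @ V.T))
+
+def tangent_features(theta, psi_query, eps=1e-5):
+    """phi_NTK(x,c) = grad_theta y_hat(x,c), by central finite differences."""
+    g = np.zeros_like(theta)
+    for i in range(theta.size):
+        e = np.zeros_like(theta)
+        e[i] = eps
+        g[i] = (predict(theta + e, psi_query) - predict(theta - e, psi_query)) / (2 * eps)
+    return g
+
+theta0 = rng.uniform(0.1, 1.0, size=n_params)    # initialization theta_0
+# k_NTK((x,c),(x',c')) = <phi_NTK(x,c), phi_NTK(x',c')> evaluated at theta_0.
+T = np.stack([tangent_features(theta0, Psi[a]) for a in range(N)])
+K_ntk = T @ T.T                                  # empirical NTK matrix over the support set
+print(K_ntk.shape, np.linalg.eigvalsh(K_ntk).min() >= -1e-8)   # PSD up to round-off
+```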

-It remains to identify the structure of \(\phi_{\mathrm{NTK}}\) for our **linear attention** block and show it lies in the span of **linear transforms of joint features**. Differentiating
-$\widehat y(x,c)=u^\top \sum_a \alpha_a(x,c)\, V\psi_a$ at $\theta_0$ yields four groups of terms:
+It remains to identify the structure of $\phi_{\mathrm{NTK}}$ for our **linear attention** block and show it lies in the span of **linear transforms of joint features**. Differentiating
+$\widehat y(x,c)=u^\top \sum_a \alpha_a(x,c)\, V\psi_a$ at $\theta_0$ yields four groups of terms:

-- **Readout path ($u$).** \(\partial \widehat y/\partial u = \sum_a \alpha_a(x,c)\, V\psi_a = \varphi_0(x,c)\). This is linear in \(\{\psi_a\}\).
+- **Readout path ($u$).** $\partial \widehat y/\partial u = \sum_a \alpha_a(x,c)\, V\psi_a = \varphi_0(x,c)$. This is linear in $\{\psi_a\}$.

-- **Value path ($V$).** \(\partial \widehat y/\partial V = \sum_a \alpha_a(x,c)\, u\,\psi_a^\top\). This contributes terms of the form \((u\otimes I)\sum_a \alpha_a(x,c)\psi_a\), i.e., linear in \(\{\psi_a\}\).
+- **Value path ($V$).** $\partial \widehat y/\partial V = \sum_a \alpha_a(x,c)\, u\,\psi_a^\top$. This contributes terms of the form $(u\otimes I)\sum_a \alpha_a(x,c)\psi_a$, i.e., linear in $\{\psi_a\}$.

-- **Query/key paths ($Q,K$).** For linear attention with scores \(s_a=w_a\,\psi(x,c)^\top Q^\top K \psi_a\) and normalized \(\alpha_a=s_a/\sum_b s_b\), derivatives of \(\alpha_a\) w.r.t. \(Q\) and \(K\) are linear combinations of \(\psi(x,c)\) and \(\{\psi_a\}\):
+- **Query/key paths ($Q,K$).** For linear attention with scores $s_a=w_a\,\psi(x,c)^\top Q^\top K \psi_a$ and normalized $\alpha_a=s_a/\sum_b s_b$, derivatives of $\alpha_a$ w.r.t. $Q$ and $K$ are linear combinations of $\psi(x,c)$ and $\{\psi_a\}$:
$$
\frac{\partial \alpha_a}{\partial Q}\propto
-\sum_b \big[\delta_{ab}-\alpha_b(x,c)\big]\,
- w_a w_b \big( K\psi_a\,\psi(x,c)^\top \big),
+\sum_b \big[\delta_{ab}-\alpha_a(x,c)\big]\,
+ w_b \big( K\psi_b\,\psi(x,c)^\top \big),
\qquad
\frac{\partial \alpha_a}{\partial K}\propto
-\sum_b \big[\delta_{ab}-\alpha_b(x,c)\big]\,
- w_a w_b \big( \psi(x,c)\,\psi_a^\top Q^\top \big),
+\sum_b \big[\delta_{ab}-\alpha_a(x,c)\big]\,
+ w_b \big( Q\,\psi(x,c)\,\psi_b^\top \big),
$$
- and hence $\partial \widehat y/\partial Q$, $\partial \widehat y/\partial K$ are finite linear combinations of tensors each bilinear in $\psi(x,c)$ and some $\psi_a$. Contracting with \(u\) and \(V\) produces terms *linear* in \(\psi(x,c)\) and linear in the set \(\{\psi_a\}\).
+ and hence $\partial \widehat y/\partial Q$, $\partial \widehat y/\partial K$ are finite linear combinations of tensors each bilinear in $\psi(x,c)$ and some $\psi_a$. Contracting with $u$ and $V$ produces terms *linear* in $\psi(x,c)$ and linear in the set $\{\psi_a\}$.

Collecting all components, the tangent feature map can be written as
$$
\phi_{\mathrm{NTK}}(x,c)=\mathcal{L}\big(\psi(x,c),\{\psi_a\}\big),
$$
-where \(\mathcal{L}\) is a fixed linear operator determined by \(\theta_0\), \(W\), and the normalization rule for attention. Consequently, the NTK takes the **dot-product** form
+where $\mathcal{L}$ is a fixed linear operator determined by $\theta_0$, $W$, and the normalization rule for attention.
+Consequently, the NTK takes the **dot-product** form
$$
k_{\mathrm{NTK}}\big((x,c),(x',c')\big)=
\Psi(x,c)^\top\, \mathcal{M}\, \Psi(x',c'),
$$
for some positive semidefinite matrix $\mathcal{M}$ and a finite-dimensional feature stack $\Psi$ that concatenates linear transforms of $\psi(x,c)$ and of the support-set $\{\psi_a\}$. In particular, $k_{\mathrm{NTK}}$ is a dot-product kernel on **linear transforms of the joint features** (possibly augmented by normalization-dependent combinations). Therefore, training the linear-attention ICL model in the linearized regime equals kernel regression with such a kernel—completing (B2). ■
@@ -192,9 +203,9 @@ for some positive semidefinite matrix $\mathcal{M}$ and a finite-dimensional fea
In both A.1 and A.2, predictions have the KRR form
$$
\widehat y(x,c)=k_{(x,c)}^\top \big(K^\sharp + \lambda I\big)^{-1} \mu,
$$
-where \(K^\sharp\) is a positive semidefinite kernel matrix computed over the support set (e.g., $K_W=W^{1/2}ZZ^\top W^{1/2}$ in A.1 or $W^{1/2}\Phi\Phi^\top W^{1/2}$ / $K_{\mathrm{NTK}}$ in A.2), $k_{(x,c)}$ is the associated query vector, and \(\mu=W^{1/2}y\) (or an equivalent reweighting).
+where $K^{\sharp}$ is a positive semidefinite kernel matrix computed over the support set (e.g., $K_W=W^{1/2}ZZ^\top W^{1/2}$ in A.1 or $W^{1/2}\Phi\Phi^\top W^{1/2}$ / $K_{\mathrm{NTK}}$ in A.2), $k_{(x,c)}$ is the associated query vector, and $\mu=W^{1/2}y$ (or an equivalent reweighting).

- **Retrieval $R(c)$ / gating.** Changing the support set $S(c)$ (e.g., via a retriever or a gating policy) **removes or adds rows/columns** in $K^\sharp$ and entries in $k_{(x,c)}$. This is equivalent to changing the **empirical measure** over which the kernel smoother is computed (i.e., which samples contribute and how).
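+
+  As a minimal sketch of this point (an added illustration; `krr_predict` and the index set `keep` are hypothetical names, and the kernel here is an arbitrary PSD matrix), restricting the support set just subsets the rows/columns of $K^\sharp$ and the entries of $k_{(x,c)}$ and $\mu$ before solving the same system:
+
+  ```python
+  import numpy as np
+
+  def krr_predict(K_sharp, k_query, mu, lam):
+      """KRR-form prediction y_hat = k^T (K_sharp + lam I)^{-1} mu."""
+      return k_query @ np.linalg.solve(K_sharp + lam * np.eye(K_sharp.shape[0]), mu)
+
+  rng = np.random.default_rng(3)
+  N, lam = 10, 0.1
+  A = rng.normal(size=(N, 7))
+  K_sharp = A @ A.T                     # any PSD kernel matrix over the support set
+  k_query = rng.normal(size=N)          # query vector k_{(x,c)}
+  mu = rng.normal(size=N)               # e.g. W^{1/2} y
+
+  full = krr_predict(K_sharp, k_query, mu, lam)
+
+  # A retriever / gating policy that keeps only the indices in `keep` simply deletes the
+  # other rows/columns of K_sharp and the matching entries of k_query and mu.
+  keep = np.array([0, 2, 3, 7, 9])      # hypothetical retrieved subset R(c)
+  restricted = krr_predict(K_sharp[np.ix_(keep, keep)], k_query[keep], mu[keep], lam)
+  print(full, restricted)
+  ```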