聊对大模型可解释性的理解

IMG_7155

聊对大模型可解释性的理解

我的第一篇工作将围绕我对大模型可解释性的理解展开。在我进入下一段学习之前,希望通过一些文字来保留和记载我目前对机器学习模型的一些理解。

我想先从一个尽可能理论化的模型定义开始。

一、从监督学习模型的本质说起

对于一个监督学习任务,我们有样本空间 $\mathcal X$ 和标签空间 $\mathcal Y$。监督学习模型的本质就是一个函数,一个从样本空间映射到标签空间中的函数:

而我们所期望的完美模型,或者说目标模型,应当是所有这样函数集合中的某个特定子集,或者某个满足真实规律的目标函数。我们的目标,便是使用一个服从某分布的采样器,从样本空间中有限次采样,并获取带噪声的 label,使用这些数据去构建一个拟合函数,使它在函数空间里渐进逼近这个目标子集。

从泛函的角度来理解,我们可以在函数空间上定义一个刻画函数好坏的量:population risk $R$。它不一定严格是一个 norm,但它承担了类似“距离目标函数有多远”的作用。自然地,我们希望零点或者最小点定义在目标函数上,即理想情况下有

这样,我们的目标便是尽可能降低当前函数的风险。

遗憾的是,我们不知道,也不可能知道准确的 $R$ 的表达式。这个风险函数的零点本身就蕴含着我们的目标函数。因此我们只能构建一个 $R$ 的近似,即 empirical risk。这里的 empirical risk 可以通过我们采到的样本来构建:

一个直觉是:我们的模型在已知数据上预测得越好,我们就认为它在 population 上的预测也会更好。因此,我们经验性地认为,如果一个函数能够很好地 fit 我们的训练数据,那么这样的函数就是我们希望得到的拟合函数。

然而,数据的标签有噪音,满足训练数据的函数又太多。我们得到的这些函数集合应当与目标函数集合距离比较近,但我们不能任意地取 $f$。否则,一个极端不连续、只在训练点上表现很好的函数,也可能成为训练集上的优秀拟合函数之一。类似狄利克雷函数那样的构造,可以在有限采样点上取得很好的 empirical risk,却完全没有我们想要的泛化意义。

因此,我们需要对函数空间有一定的假设,比如连续性、连续可微性,或者 Lipschitz 连续性。另外,对函数的表达也要有明确的参数表示,比如

也就是说,我们需要一个可以用于直接表示函数的具体形式。由此定义 hypothesis space $\mathcal F$。这里 $\mathcal F$ 中的任何一个函数均可以用参数 $\theta$ 表达,$\theta$ 可以是多元矢量。

在这样的假设里,我们可以很好地刻画一个模型的误差估计过程。

首先,目标函数不一定出现在我们的 hypothesis space 中。因为我们的 space 只是 all-function space 上的一个极小子集,直觉地可以看成一个零测子集。因此,我们考虑在 hypothesis space 中使 population risk 最小的函数,也就是我们能在 hypothesis space 里表达的、最接近目标函数的函数:

这便是 approximation error,它由 hypothesis space 的表达能力限制产生。

其次,我们的训练数据是有限的,我们无法得到 exact population risk。因此,我们使用有限样本下的 empirical risk。这里往往会被定义为 predicted label 与真实 label 之间的 square loss,或者其他 loss function。我们取在 hypothesis space 中使 empirical risk 最小的函数:

它是我们能在 hypothesis space 里表达的、经验地认为最好的函数。这便是 estimation error,它由样本数量和样本质量的局限性产生。

最后,我们考虑得到 $\hat f_n$ 的过程。无论是显式算出,还是通过训练得到,最终的问题都归结为一个多元函数的优化问题。梯度下降是解决优化问题的方法之一。我们的训练需要更新参数,调整当前函数,以逐渐接近 empirical risk 的最小解。在解决优化问题的过程中产生的误差,就是 optimization error,它由训练时间、算力和优化算法的限制产生。

因此,我们可以将模型与目标函数之间的误差粗略分解为:

对于大模型来说,训练时间和训练方法、数据量和数据质量、模型参数量和模型表达能力,便成为限制模型能力的三个重要因素。

二、数据分布与 IID 假设

对于模型理论的理解,我们还可以再深入一些。

我们对数据采样分布的假设往往是 iid,即独立同分布。这里的 iid 在机器学习理论中至关重要。无论是 estimation error,还是衡量泛化能力,这里的 iid 都是对数据采样过程的极大简化描述,也是最简描述。对于相同的分布,我们往往还会承认它满足一定的正则性假设,比如均值、方差有限。

然而,在部分模型中,iid 假设并不能很好满足。一个经典的例子便是强化学习。

对于 on-policy 的强化学习算法,数据收集来自当前模型。因此,轨迹的采样分布会随着模型能力提升而变化,这显然不是 iid 的。对此,人们有一些改进,比如 replay buffer。它通过临时存储样本,并在后期训练时重新采样,来获取过去模型产生的轨迹数据,从而部分降低数据相关性和分布漂移的问题。

然而,这种方法也不能彻底解决问题。buffer 的容量有限,前期训练无法使用优质数据,旧策略生成的数据又可能与当前策略不匹配。因此,强化学习在数据采样上天然不满足标准 iid 分布,这也成为训练不稳定的因素之一。

对于大语言模型的预训练来说,数据通常是提前收集好的大规模语料,训练时可以从固定语料分布中采样。因此,它在工程上更接近 iid 的监督学习。当然,严格来说,大语言模型的预训练数据也并不是真正完美 iid 的:文档之间存在来源差异,token 之间存在强相关,语料本身也往往是多个分布的混合。但相比 on-policy 强化学习,它至少更接近固定数据分布下的经验风险最小化问题。

三、Hypothesis Space、模型架构与归纳偏置

Hypothesis space 的定义影响模型的表达能力。

一个经典的定理是:对于两层神经网络,在一些连续性假设下,随着参数量或者隐层宽度趋向无穷,它有几乎万能的函数表达能力。也就是说,神经网络在理论上可以逼近非常广泛的连续函数。

举例而言,MLP、CNN、RNN、LSTM、Transformer 等架构,本质上都是对 hypothesis space 的定义,即给出一个由参数生成函数的方案。

一个具有万能表达能力的架构似乎很容易构建。那么,为什么还会诞生如此多不同的架构?为什么不使用 CNN 做自然语言处理?为什么不使用 MLP 做大语言模型?其中细微的差距值得我们探索。

不同模型首先不同的就是表达习惯,也称归纳偏置。我个人认为,它也可以理解为“表达偏执”。

对一个架构而言,它会有自己独自擅长的领域。RNN 和 LSTM 处理 sequence data,CNN 处理图像数据。本质原因在于,数据自身的关联关系要和模型的假设基本匹配。

RNN 和 LSTM 的流式设计,可以在处理序列数据时关注 token 附近的 tokens 表现,并且随着距离增大,关系逐渐减弱。这与自然语言中词语之间关系强度的分布基本相符。

CNN 的卷积核设计,则可以更好捕获二维局部性质。这与图像中近邻像素之间关系更强的特点基本相符。

而 Transformer 是自由度最高的架构之一,有相对最弱的“偏执”。它既有处理序列数据的能力,又有处理图像数据的能力。所以我们说 Transformer is all you need。

可是,Transformer 是 only need 吗?我们可以将所有问题都交给 Transformer 处理吗?并非如此。在一些任务上,CNN 的表现仍然会超过 Transformer。

我个人理解为,这是“架构先验”和“数据后验”对训练共同作用的结果。

模型架构越有特色,预先假设的相关性越强,逻辑关系越明显,架构先验就越强,于是可以更快、更好地学习同类型的数据。而自由的架构,会被鼓励通过数据学习到结构。比如 Transformer 的注意力机制,可以认为是 CNN 和 RNN 那种先验注意力关系,也就是相邻关系的上游设计。在数据充分时,模型可以 grok 出近邻相关的注意力,甚至 grok 出更好的人类无法提前想到的结构。

因此,归纳偏置越强,模型越容易学到匹配的数据,但是上限可能更低;归纳偏置越弱,模型越难学到数据,但是在数据和算力充分时,上限可能更高。

一个极端的例子是 scaling law 的参数拟合与 LLM 预训练。前者只拟合个位数个参数,最终拟合结果极其依赖 scaling law 的结构,具有极强的归纳偏置。比如 Multi-Power Law 中,law 的新形式会显著影响拟合效果。后者依靠大量语料进行训练,捕获复杂语义关系,上限高,但是训练周期长。

因此,归纳偏置的强弱和偏好,会影响模型的表达偏好,进而影响模型在特定数据上的表现。

四、优化算法与训练动力学

优化算法的选择,则影响模型走向 empirical risk 最小函数的物理效率,也就是算力成本。

对于 LLM 而言,一个模型的训练动辄几周、几个月。在确认模型架构和数据之后,找到一个帮助模型迅速收敛的优化器十分关键。

我们对 hypothesis space 上的每一个函数定义 empirical risk,也就是对参数空间中的每一点定义 empirical risk。于是,我们的优化任务变为:在多元参数空间中最小化 empirical risk。梯度下降算法便应运而生。

随机梯度下降和 AdamW 都是梯度下降的改进,旨在经验上更高效地找到低 empirical risk 的区域,并尽量避免停留在很差的局部结构附近。现主流应用的是 AdamW。它通过动量估计、自适应学习率和解耦 weight decay,改变参数在损失地形中的移动方式,使训练在工程上更加稳定有效。

最新的一些优化器,比如 Muon,则考虑使用矩阵结构来处理梯度更新。例如,它会关注梯度矩阵中的主要结构方向,而不只是逐元素地更新参数。直觉上,这类似于通过分解梯度矩阵,关注主要特征值对应的特征空间移动。这一方法可以用于处理高噪音数据,例如高频量化交易数据。当然,这更多还是一种经验层面的理解,需要在具体任务中验证。

模型训练的动力学中,有一个经典现象是 grokking,即“灵感一现”。

它表达的是:在过参数化的 model 上做一些具有结构性的训练任务,比如 mod 加法、乘法、群运算等,模型会先在训练数据上拟合得很好,但是由于它只是记忆数据,而不是总结规律,所以很难泛化到其他结果。随着训练时间增长,模型会在某一个较短时间窗口内产生 validation accuracy 的跃迁,仿佛瞬间“顿悟”了训练任务。

事实上,这可能是由于正则化效应导致的参数平移。

经验地,我们认为,对于一个特定任务,表达这个任务的函数应当是参数稀疏的。它不应该完全利用到所有神经元和所有参数。而且,一个有规律的任务应当具有更简单的表达形式。因此我们认为,在同样拟合数据的情况下,稀疏解、低复杂度解或者更结构化的解,往往具有更强的泛化能力。

机器学习理论中,weight decay、dropout 和 ridge regression 都可以从正则化的角度理解。它们并不完全等同,但在直觉上都有抑制复杂解、偏向简单解或低范数解的作用。因此,我们可以用 ridge regression 的图像来理解这个问题。

在过参数模型中,参数个数远远超过样本个数,因此使 empirical risk 为 0 的参数往往不是一个点,而是构成了 hypothesis space 或参数空间中的一个高维区域。模型会在训练初期迅速收敛到这个区域附近,之后参数会在这个区域附近缓慢移动,利用正则项的梯度逐渐移动到更具有结构性、更稀疏的表达上,进而找到最有规律的 empirical risk zero function。

所以,grokking 并不是真的神秘地突然获得了灵感,而更像是模型在训练动力学中,从 memorization solution 迁移到 algorithmic solution 的过程。

五、强化学习的不稳定性

另一个经典现象是强化学习的不稳定性。

强化学习模型在训练中的 loss curve 明显具有较大波动性,模型崩溃的概率也高于一般机器学习模型。我理解的原因有几个。

首先是任务难度与数据之间的矛盾。强化学习模型往往希望找到一个策略,来完成具有强思维难度的任务,比如德州扑克。而在强化学习的数据收集中,每次 rollout 一轮带来的数据有限,无法像大语言模型那样高效使用互联网信息,或者通过其他方式迅速获取特定任务的大量优质数据。

因此,对于特定的困难任务,强化学习的数据获取困难,导致有效 batch size 较小,数据收集方差较大。再加上数据收集并不是 iid 的,这共同导致强化学习的数据质量不高、数量较少,训练天然具有困难。

其次是模型参数量和目标函数复杂性的限制。由于任务并不简单,策略的微小变化可能导致行动选择发生巨大变化,而行动选择的改变又会导致后续轨迹和奖励发生巨大变化。因此,策略到收益之间的函数关系可能非常不平滑。直觉上,这意味着目标函数的连续性较差,或者有效 Lipschitz 常数较大。

而在机器学习的许多理论推导中,Lipschitz 假设是必要的,其数值也会影响最终误差的上界估计。因此,如果强化学习问题中的有效 Lipschitz 常数很大,或者目标函数本身非常不平滑,那么收敛就会更加困难。

在模型表达上,由于这种弱连续性和高方差,渐进拟合表达的 variance 也会受到挑战,最终导致整体表现不稳定。

六、回到大模型可解释性

这就是我目前对机器学习模型,尤其是大模型可解释性的一些前置理解。

大模型首先是一个函数,但它又不只是一个黑箱函数。我们可以从 hypothesis space、数据分布、归纳偏置、优化动力学和训练现象几个层面去理解它。

如果我们想理解大模型为什么能够表达复杂语义,为什么某些结构会在训练中出现,为什么它有时只是记忆而有时能够学到规律,那么这些理论视角都是必要的起点。

大模型可解释性要解释的,最终也许不是某一个单独参数,也不是某一次输出,而是模型如何在训练过程中形成结构,如何在高维参数空间中找到可以泛化的表示,以及这些表示如何组合成我们观察到的能力。