为什么ChatGPT可以进行数学运算和逻辑推理?变形金刚如何“思考”?

简介:为什么现在流行的ChatGPT可以做数学运算和逻辑推理?让我们阅读这篇文章来了解基本原理!
Transformer 模型是 AI 系统的基础。 “变压器如何工作”的核心结构图无数。
但是这些图并没有提供计算该模型的框架的直观表示。如果研究人员对变压器的工作原理感兴趣,直观地了解其工作机制是非常有用的。
像 Transformers 一样思考本文提出了一个 Transformer 类的计算框架,可以直接计算和模仿 Transformer 计算。使用RASP编程语言,每个程序都被编译成一个特殊的转换器。
在这篇博客中,我用 Python 重现了 RASP (RASPy) 的变体。该语言与原始语言大致相同,但还有一些我觉得有趣的变化。通过这些语言,作者盖尔·韦斯的作品提供了一种具有挑战性且有趣的方式来理解它是如何运作的。
!pip install git+https://github.com/srush/RASPy
在我们讨论语言本身之前,让我们看一个使用 Transformer 进行编码的示例。这是一些计算翻转的代码,即反转输入序列。代码本身使用两个转换器层来应用注意力和数学计算来实现此结果。
def flip():
length = (key(1) == query(1)).value(1)
flip = (key(length - indices - 1) == query(indices)).value(tokens)
return flip
flip()

Transformers 作为代码
我们的目标是定义一组最小化 Transformer 表达的计算形式。我们通过类比来描述每种语言的构造及其在 Transformer 中的对应部分。 (正式的语言规范请参见本文末尾论文全文的链接)。
该语言的核心单元是序列操作,将一个序列转换为另一个相同长度的序列。从现在开始我将称这些转变为转变。
输入
在变压器中,基础层是模型的前馈输入。该输入通常包含原始令牌和位置信息。
在代码中,令牌特征表示最简单的转换,返回通过模型传递的令牌。默认的输入序列是“Hello”:
tokens

如果我们想改变转换中的输入,我们使用输入法来传递它。价值。
tokens.input([5, 2, 4, 5, 2, 2])

作为变形金刚,我们无法立即接受这些序列的位置。但是为了模拟位置嵌入,我们可以得到位置的索引:
indices

sop = indices
sop.input("goodbye")

前馈网络
经过输入层后,我们到达前馈网络层。在 Transformer 中,此步骤独立地将数学运算应用于序列的每个元素。
在代码中,我们通过计算变换来表示此步骤。对序列的每个元素执行独立的数学运算。
tokens == "l"

结果是新的转变。当新的输入被重构时,根据重构方法进行计算:
model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])

该操作可以组合多个变换。例如,以上面的标记和索引为例,这里您可以 Transformer 类别可以跟踪各种片段信息:
model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])

(tokens == "l") | (indices == 1)

我们提供了一些辅助函数来使编写转换变得更容易。例如,,其中
提供类似于if
结构的功能。
where((tokens == "h") | (tokens == "l"), tokens, "q")

map
允许我们定义自己的操作,例如将字符串转换为int
。 (用户应该小心可以使用简单神经网络计算的操作)
atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")

函数可以轻松描述这些转换的级联。例如,这里是应用 和 atoi 并加上 2
def atoi(seq=tokens):
return seq.map(lambda x: ord(x) - ord('0'))
op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")

注意力过滤器
的操作。一旦开始应用注意力机制,事情就开始变得有趣。这允许序列的不同元素之间交换信息。
我们首先定义键和查询的概念。可以直接从上面的转换创建键和查询。例如,如果我们要定义一个键,我们将其称为key
。
key(tokens)

与 查询
query(tokens)

标量可用作 我们创建了过滤器来应用键和查询之间的操作。这对应于一个二进制矩阵,指示每个查询要关注哪个键。与 Transformers 不同,这个注意力矩阵不添加权重。 一些示例: 选择器可以通过布尔运算组合。例如,此选择器组合了 before 和 eq ,我们通过在矩阵中包含键和值对来显示它们。 给定一个注意力选择器,我们可以为聚合操作提供一个序列值。我们通过收集选择器选择的正确值来执行聚合。 (请注意:在原始论文中,他们使用平均聚合操作并展示了一种巧妙的结构,其中平均聚合代表求和计算。RASPy 默认使用累积来使其简单并避免碎片。实际上,这意味着raspy 可能会低估层数。基于平均值的模型可能需要两倍的层数) 视觉上我们遵循图结构,查询在左侧,键在顶部,值在底部,输出在右边 一些注意力机制操作甚至不需要输入令牌。例如,为了计算序列长度,我们创建一个“全选”注意过滤器并为其分配一个值。 这里有更复杂的示例,分步如下所示。 (这有点像进行采访) 我们要计算一个序列的相邻值的总和,首先我们向前截断: 然后我们向后截断: 两者相交: 这里是计算累积和的示例。我们引入了调用转换的功能来帮助您进行调试。 此语言支持编译更复杂的转换。他还通过跟踪每一个操作来操作计算层。 这里是2层变换的示例,第一层对应于长度的计算,第二层对应于累积和。 有了这个函数库,我们可以编写来完成复杂的任务。盖尔·韦斯(Gail Weiss)问了我一个极具挑战性的问题来打破这一步:我们可以加载一个可以添加任意长度数字的变压器吗? 例如:使用字符串“19492+23919”我们可以加载正确的输出吗? 如果您想亲自尝试,我们提供了一个您可以亲自尝试的版本。 加载一个序列,其中索引I中的所有元素都有值 将所有标记向右移动。 计算序列的最小值。(这一步变得困难,我们的版本使用2层注意力机制) 用令牌q计算第一个索引(2层) 分离一个序列,然后将两个部分右对齐(2层): 将特殊标记“ "" "00 key♿♿ 或 ♿ 给定 或 广播为基本序列的长度。
query(1)
eq = (key(tokens) == query(tokens))
eq
offset = (key(indices) == query(indices - 1))
offset
before = key(indices) < query(indices)
before
after = key(indices) > query(indices)
after
before & eq
使用注意力机制
(key(tokens) == query(tokens)).value(1)
length = (key(1) == query(1)).value(1)
length = length.name("length")
length
WINDOW=3
s1 = (key(indices) >= query(indices - WINDOW + 1))
s1
s2 = (key(indices) <= query(indices))
s2
offset = (key(indices) == query(indices - 1))
offset
sum2 = sel.value(tokens)
sum2.input([1,3,2,2,2])
def cumsum(seq=tokens):
x = (before | (key(indices) == query(indices))).value(seq)
return x.name("cumsum")
cumsum().input([3, 1, -2, 3, 1])
Layer
x = cumsum(length - indices)
x.input([3, 2, 3, 5])
使用变压器编程
挑战1:选择给定索引
def index(i, seq=tokens):
x = (key(indices) == query(i)).value(seq)
return x.name("index")
index(1)
挑战2:变换
def shift(i=1, default="_", seq=tokens):
x = (key(indices) == query(indices-i)).value(seq, default)
return x.name("shift")
shift(2)
挑战 3:最小化
def minimum(seq=tokens):
sel1 = before & (key(seq) == query(seq))
sel2 = key(seq) < query(seq)
less = (sel1 | sel2).value(1)
x = (key(less) == query(0)).value(seq)
return x.name("min")
minimum()([5,3,2,5,2])
挑战4:第一个索引
def first(q, seq=tokens):
return minimum(where(seq == q, indices, 99))
first("l")
挑战5:右对齐右-对齐填充序列。示例:“
ralign().inputs('xyz___') ='—xyz'
”(级别 2)def ralign(default="-", sop=tokens):
c = (key(sop) == query("_")).value(1)
x = (key(indices + c) == query(indices)).value(sop, default)
return x.name("ralign")
ralign()("xyz__")
挑战 6:分离
def split(v, i, sop=tokens):
mid = (key(sop) == query(v)).value(indices)
if i == 0:
x = ralign("0", where(indices < mid, sop, "_"))
return x
else:
x = where(indices > mid, sop, "0")
return x
split("+", 1)("xyz+zyr")
split("+", 0)("xyz+zyr")
挑战7:滑动
版权声明
本文仅代表作者观点,不代表Code前端网立场。
本文系作者Code前端网发表,如需转载,请注明页面地址。
发表评论:
◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。