数据结构论坛

首页 » 分类 » 分类 » 万字综述,核心开发者全面解读PyTorc
TUhjnbcbe - 2024/7/9 16:13:00

斯坦福大学博士生与Facebook人工智能研究所研究工程师EdwardZ.Yang是PyTorch开源项目的核心开发者之一。他在5月14日的PyTorch纽约聚会上做了一个有关PyTorch内部机制的演讲,本文是该演讲的长文章版本。

大家好!今天我想谈谈PyTorch的内部机制。

这份演讲是为用过PyTorch并且有心为PyTorch做贡献但却被PyTorch那庞大的C++代码库劝退的人提供的。没必要说谎:PyTorch代码库有时候确实让人难以招架。

本演讲的目的是为你提供一份导航图:为你讲解一个「支持自动微分的张量库」的基本概念结构,并为你提供一些能帮你在代码库中寻路的工具和技巧。我预设你之前已经写过一些PyTorch,但却可能还没有深入理解机器学习软件库的编写方式。

本演讲分为两部分:在第一部分中,我首先会全面介绍张量库的各种概念。我首先会谈谈你们知道且喜爱的张量数据类型,并详细讨论这种数据类型究竟能提供什么,这能让我们更好地理解其内部真正的实现方式。

如果你是一位PyTorch高级用户,你可能已经熟悉其中大部分材料了。我们也会谈到「扩展点(extensionpoints)」的三个概念、布局(layout)、设备(device)和数据类型(dtype),这能引导我们思考张量类的扩展的方式。在PyTorch纽约聚会的现场演讲中,我略过了有关自动梯度(autograd)的幻灯片,但我在这里会进行一些讲解。

第二部分会阐述真正用PyTorch写代码时所涉及的基本细节。我会告诉你如何在autograd代码中披荆斩棘、什么代码是真正重要的以及怎样造福他人,我还会介绍PyTorch为你写核(kernel)所提供的所有炫酷工具。

概念

张量

张量是PyTorch中的核心数据结构。对于张量直观上所表示的东西,你可能已有很好的理解:张量是一种包含某种标量类型(比如浮点数和整型数等)的n维数据结构。我们可以将张量看作是由一些数据构成的,还有一些元数据描述了张量的大小、所包含的元素的类型(dtype)、张量所在的设备(CPU内存?CUDA内存?)

另外还有一个你可能没那么熟悉的元数据:步幅(stride)。stride实际上是PyTorch最别致的特征之一,所以值得稍微多讨论它一些。

张量一个数学概念。但要在我们的计算机中表示它,我们必须为它们定义某种物理表示方法。最常用的表示方法是在内存中相邻地放置张量的每个元素(这也是术语「contiguous(邻接)」的来源),即将每一行写出到内存,如上所示。在上面的案例中,我已经指定该张量包含32位的整型数,这样你可以看到每一个整型数都位于一个物理地址中,每个地址与相邻地址相距4字节。为了记住张量的实际维度,我们必须将规模大小记为额外的元数据。

所以这幅图与步幅有什么关系?

假设我想要读取我的逻辑表示中位置张量[0,1]的元素。我该如何将这个逻辑位置转译为物理内存中的位置?步幅能让我们做到这一点:要找到一个张量中任意元素的位置,我将每个索引与该维度下各自的步幅相乘,然后将它们全部加到一起。在上图中,我用蓝色表示第一个维度,用红色表示第二个维度,以便你了解该步幅计算中的索引和步幅。进行这个求和后,我得到了2(零索引的);实际上,数字3正是位于这个邻接数组的起点以下2个位置。

(后面我还会谈到TensorAccessor,这是一个处理索引计算的便利类(convenienceclass)。当你使用TensorAccessor时,不会再操作原始指针,这些计算过程已经为你隐藏了起来。)

步幅是我们为PyTorch用户讲解方法的基本基础。举个例子,假设我想取出一个表示以上张量的第二行的张量:

使用高级的索引支持,我只需写出张量[1,:]就能得到这一行。重要的是:当我这样做时,不会创建一个新张量;而是会返回一个基于底层数据的不同域段(view)的张量。这意味着,如果我编辑该视角下的这些数据,它就会反映在原始的张量中。

在这种情况下,了解如何做到这一点并不算太困难:3和4位于邻接的内存中,我们只需要记录一个说明该(逻辑)张量的数据位于顶部以下2个位置的偏移量(offset)。(每个张量都记录一个偏移量,但大多数时候它为零,出现这种情况时我会在我的图表中省略它。)

演讲时的提问:如果我取张量的一个域段,我该如何释放底层张量的内存?答案:你必须制作该域段的一个副本,由此断开其与原始物理内存的连接。你能做的其它事情实际上并不多。另外,如果你很久之前写过Java,取一个字符串的子字符串也有类似的问题,因为默认不会制作副本,所以子字符串会保留(可能非常大的字符串)。很显然,Java7u6将其固定了下来。

如果我想取第一列,还会更有意思:

当我们查看物理内存时,可以看到该列的元素不是相邻的:两者之间有一个元素的间隙。步幅在这里就大显神威了:我们不再将一个元素与下一个元素之间的步幅指定为1,而是将其设定为2,即跳两步。(顺便一提,这就是其被称为「步幅(stride)」的原因:如果我们将索引看作是在布局上行走,步幅就指定了我们每次迈步时向前多少位置。)

步幅表示实际上可以让你表示所有类型的张量域段;如果你想了解各种不同的可能做法,请参阅

1
查看完整版本: 万字综述,核心开发者全面解读PyTorc