今天聊聊比较重要的torch.fx
,也趁着这次机会把之前的torch.fx
整理笔记,大致分为三篇,分别为三篇:
什么是torch.fx
基于torch.fx做量化
基于torch.fx量化部署到TensorRT
本文第十五,主要介绍torch.fx和基本使用方法。废话不多说,直接开始吧!
什么是Torch.FX
torch.fx
是Pytorch 1.8
出来的一套工具或者说一个库,是做python-to-python code transformation
,大意就是可以把pytorch中的python前向代码转换为你想要的样子,官方介绍如下:
我们应用了ch.f,这是一个完全用Python编写的PyTorch程序和库,并针对高生产力ML从业者进行了优化。上面可以看FX的论文,FX的PRACTICAL PROGRAM CAPTURE AND TRANSFORM DEEP LEARNING IN PYTHON [1]相对而言,知上也有不错的解读[2],这里就不再重复讲述个人了。
核心的关键词是program capture
和transformation library
,这两个很重要的概念。
FX怎么用呢?今天了解一下,我们定义了一个pytorch.nn.module
:
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
很简单地继承于的模块(最开始的pytorch的应该都懂)。其中前向转向函数也记录了这个模块的具体操作逻辑。torch.nn.Module
如果想把这个模块中转向中的部分操作逻辑self.linear(x + self.param).clamp(min=0.0, max=1.0)
的clamp
部分替换为sigmoid
,应该怎么搞呢?
可以操作这些代码,或者说你直接写了很多这样的模块,或者说你在本地写了很多实验了(如果在中改不改),再去比较无聊了。
这时候就需要修改FX,不需要我们自己手动设置这个修改前锋的规则),只需要设置好,使用,然后进入torch.fx
这个实例模型,跑一下代码。:MyModule
self.linear(x + self.param).sigmoid()
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
# 打印查看FX的IR
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param "#users=1] = call_function[target=operator.add"), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add, "#users=1] = call_module[target=linear"), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear, "#users=1] = call_method[target=clamp"), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
# 通过FX生成的代码,可以视为module中的forward代码
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
这样,FX会帮助你修改这个模块,并且修改这个好的model
就和平时一样使用就可以了,注意这里,FX capture了你写的转发代码,然后进行了transform,修改了其中的操作。
这只是很简单很简单的fx的一个功能,当然我们还可以通过fx:
融合两个op,比如conv和bn
去掉某个操作
替换某些操作
在某些操作后插入一些操作或其他操作
等等。
可能大家也会疑惑,这些操作不是很像AI编译器中的PASS ,而操作对象也是神经网络这种DAG(有无环图)。其实吧,FX你可以理解为是编译器,不过这个编译器为什么最终会产生一个说文件python->python
的规则,不同的是,最终的结果还是基于 Pytorch 的 python 代码,也不是一直Python-to-Python (or Module-to-Module) transformation toolkit
是compiler
。
FX当前API已经稳定(在torch-1.10正式发布),使用起来历史包袱不大。
fx的官方介绍:
https://pytorch.org/docs/stable/fx.html
torch.fx 与量化的关系
FX的出现的第一利好是基于Pytorch的量化工具,也是我介绍FX的一个原因。借助FX可以很方便地对pytorch模型进行操作,之前的商汤工具就一个基于fx的量化基准[3 ]。
对于实际而言,无论是PTQ(需要引入观察操作来增加每一层的激活以及权重分发)QTA(需要插入计算到类似来模拟),都会涉及到fx的功能。所以如果想基于Pytorch框架来做量化,建议直接上手torch.fx
。
fxpytorch-1.10
已经中也拿到了等表状态,已经有API,我的ch模型。,使用的版本是Pytorch-1.10
和TensorRT-8.2
。
fx 部分自己修改了下源码,添加了一些操作。这里我是直接把最新发布的 pytorch 中的 fx 部分提取出来,然后 pip 安装torch-1.10.0+cu113-cp38-cp38-linux_x86_64.whl
,用于食用。
与 TorchScript 的区别
真正开始torch.fx
的时候也有这两个最终出现有啥,都是先解析、然后生成几个IR版、基于IR做优化,最后生成一个优化后的模型,然后是python的版本一个是当你用FX的时候,会发现简单的FX并没有什么改变,FX的性质更适合于对这样的功能进行修改,增加操作,比如运行)而torchscript更适合于统计当前模型的性能,并且可以脱离python,仅在C++环境。
借一句大佬的回答:
torch.fx 与 TorchScript 的不同之处在于它是 PyTorch 代码的 Python 到 Python 转换的平台。另一方面,TorchScript 更针对将 PyTorch 程序移出 Python 以进行部署。从这个意义上说,FX 和 TorchScript 是相互正交的,甚至可以相互组合(例如用 FX 转换 PyTorch 程序,然后导出到 TorchScript 进行部署)。
大意就是转换后的FX模式做Python2Python的
转换,因为Torchscript
是为了部署(脱离Python这个,在C++中运行)而做转换。torchscript
,是正交的。
Python 到 Python?
注意是,由 Python 到运行式,FX 并没有和我们使用 Python 生成的直接生成的代码nn.Module
的网络,需要使用 Python 生成的直接代码,可以像不同的eager mode
FXtorchscript
一样,是另一个套runtime(我们跑torchscript的时候其实调用的是一个VM,也就是虚拟机,通过VM在C++中跑通过torchscript导出的模型)。
所以fx转换后的模型类型和nn.Module
一毛一样,所以对nn.Module
能做的,对转换后的模型也能做,我们可以连续套娃:
自己写的Module -> fx后还是Module -> 连续fx变化 -> 得到最终的fx模型
FX的IR和Jit的IR
这俩IR不一样,FX的IR两个Jit的说法,有优点:
FX 运行时间可以集成地中的 Python 地,因为它是因为 FX
prograim representations
会jit.trace
考试。FX的Graph有什么区别,它的
torch.nn.module
IR和没有什么区别,所以说起来用起来更简单,还能提升提升。
这里列一下FX的IR很简单,只有六种,大概就是函数、提取属性、获取输入输出等:
placeholder
表示函数输入。该name
属性指定此值将采用的名称。target
同样是参数的名称。args
持有:1) 没有,或 2) 表示函数输入的默认参数的单个参数。kwargs
是不在乎。占位符对应x
于图形打印输出中的函数参数(例如 )。get_attr
从模块层次结构中检索参数。name
同样是获取结果分配给的名称。target
是参数在模块层次结构中位置的完全限定名称。args
并且kwargs
不在乎call_function
对某些值应用自由函数。name
同样是要分配的值的名称。target
是要应用的函数。args
并kwargs
表示函数的参数,遵循 Python 调用约定call_module
将模块层次结构的forward()
方法中的模块应用于给定的参数。name
和以前一样。target
是模块层次结构中要调用的模块的完全限定名称。args
并kwargs
表示调用模块的参数,包括 self 参数。call_method
调用一个值的方法。name
是一样的。target
是应用于self
参数的方法的字符串名称。args
并kwargs
表示调用模块的参数,包括 self 参数output
在其args[0]
属性中包含跟踪函数的输出。这对应于图表打印输出中的“return”语句。
相对torchscript的IR,FX的可就更简单了,我们理解起来使用起来很简单。
符号示踪寻找
回到一个故事就是故事的那段代码,其中有一行是symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
,这里的核心功能symbolic_trace
,也是FX解析、转换模型的起点。
@compatibility(is_backward_compatible=True)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
enable_cpatching: bool = False) -> GraphModule:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
constructed by recording operations seen while tracing through ``root``.
...
"""
tracer = Tracer(enable_cpatching=enable_cpatching)
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)
首先会创建一个Tracer
类然后使用成员函数trace
我们的torch.nn.Module
。我们在trace这个模型之后,就可以对这个模型进行修改了:
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# 使用 Tracer 类对象去trace模型 m
# 这边是拆开了,这个transform函数就是实现torch.fx.symbolic_trace的功能
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: 这里就可以任意修改模型了,这也是重点
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
修改之后的模型可以直接来用,也可以这样通过graph_module.to_folder
,把这个模型拿出来当做自己的模块去使用(这个之后说的)。整体的流程大概就是:
符号追踪 -> 中间表示 -> 转换 -> Python 代码生成。
各自的功能为:
象征性的
符号跟踪器执行 Python 代码的“符号执行”。它通过代码提供虚假值,称为代理。记录对这些代理的操作。有关符号跟踪的更多信息可以在 symbolic_trace() 和 Tracer 文档中找到。
中间表示
中间表示是符号跟踪期间记录的操作的容器。它由表示函数输入、调用点(函数、方法或 torch.nn.Module 实例)和返回值的节点列表组成。有关 IR 的更多信息可以在 Graph 的文档中找到。IR 是应用转换的格式。
Python代码生成
Python 代码生成使 FX 成为 Python 到 Python(或模块到模块)的转换工具包。对于每个 Graph IR,我们可以创建匹配 Graph 语义的有效 Python 代码。此功能包含在 GraphModule 中,它是一个 torch.nn.Module 实例,它包含一个 Graph 以及从 Graph 生成的 forward 方法。
上面就是FX的三个核心功能。
Proxy/Retracing
是symbolic trace
的Proxy/Retracing
核心。
代理对象是节点包装器,在符号跟踪期间流经程序,并将它们接触到的所有操作(火炬函数调用、方法调用、运算符)记录到不断增长的 FX Graph 中。
如果您正在执行图形转换,您可以将自己的 Proxy 方法包装在原始节点周围,以便您可以使用重载的运算符向图形添加其他内容。
相关结构
FX主要的和,就是其中。这些记录可以为中的一个网络结构,在网络中的一个关键节点(比如节点、 Graph
relu 、add、concat等),这些记录了的方法和输入输出信息,然后可以串起来网络的逻辑。GraphModule
A Graph is a data structure that represents a method on a GraphModule
Graph