绑定手机号
获取验证码
确认绑定
提问
0/255
提问
订阅开课提醒需关注服务号
回答成功
知道了
扫码关注智猩猩服务号登录
请使用微信扫描二维码
扫描二维码分享给微信好友
您已订阅成功,有新课程,我们将第一时间提醒您。
知道了
发送提问成功
回答可在
“我的——我的提问”中查看
知道了
失败
欢迎来智东西
关注我们
智东西
车东西
芯东西
智猩猩
0
0
实践torch.fx——基于Pytorch的模型优化量化神器
分类: AI技术
2022-04-22 17:09:45

今天聊聊比较重要torch.fx,也趁着这次机会把之前的torch.fx整理笔记,大致分为三篇,分别为三篇:

  • 什么是torch.fx

  • 基于torch.fx做量化

  • 基于torch.fx量化部署到TensorRT

本文第十五,主要介绍torch.fx和基本使用方法。废话不多说,直接开始吧!

什么是Torch.FX

torch.fxPytorch 1.8出来的一套工具或者说一个库,是做python-to-python code transformation,大意就是可以把pytorch中的python前向代码转换为你想要的样子,官方介绍如下:

我们应用了ch.f,这是一个完全用Python编写的PyTorch程序和库,并针对高生产力ML从业者进行了优化。上面可以看FX的论文,FX的PRACTICAL PROGRAM CAPTURE AND TRANSFORM DEEP LEARNING IN PYTHON [1]相对而言,知上也有不错的解读[2],这里就不再重复讲述个人了。

核心的关键词是program capturetransformation library,这两个很重要的概念。

FX怎么用呢?今天了解一下,我们定义了一个pytorch.nn.module

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(34))
        self.linear = torch.nn.Linear(45)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

很简单地继承于的模块(最开始的pytorch的应该都懂)。其中前向转向函数也记录了这个模块的具体操作逻辑torch.nn.Module

如果想把这个模块中转向中的部分操作逻辑self.linear(x + self.param).clamp(min=0.0, max=1.0)clamp部分替换为sigmoid,应该怎么搞呢?

可以操作这些代码,或者说你直接写了很多这样的模块,或者说你在本地写了很多实验了(如果在中改不改),再去比较无聊了。

这时候就需要修改FX,不需要我们自己手动设置这个修改前锋的规则),只需要设置好,使用,然后进入torch.fx这个实例模型,跑一下代码。MyModuleself.linear(x + self.param).sigmoid()

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
# 打印查看FX的IR
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param "#users=1] = call_function[target=operator.add"), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add, "#users=1] = call_module[target=linear"), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear, "#users=1] = call_method[target=clamp"), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""


# Code generation - valid Python code
# 通过FX生成的代码,可以视为module中的forward代码
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

这样,FX会帮助你修改这个模块,并且修改这个好的model就和平时一样使用就可以了,注意这里,FX capture了你写的转发代码,然后进行了transform,修改了其中的操作。

这只是很简单很简单的fx的一个功能,当然我们还可以通过fx:

  • 融合两个op,比如conv和bn

  • 去掉某个操作

  • 替换某些操作

  • 在某些操作后插入一些操作或其他操作

等等。

可能大家也会疑惑,这些操作不是很像AI编译器中的PASS ,而操作对象也是神经网络这种DAG(有无环图)。其实吧,FX你可以理解为是编译器,不过这个编译器为什么最终会产生一个说文件python->python的规则,不同的是,最终的结果还是基于 Pytorch 的 python 代码,也不是一直Python-to-Python (or Module-to-Module) transformation toolkitcompiler

FX当前API已经稳定(在torch-1.10正式发布),使用起来历史包袱不大。

fx的官方介绍:

  • https://pytorch.org/docs/stable/fx.html

torch.fx 与量化的关系

FX的出现的第一利好是基于Pytorch的量化工具,也是我介绍FX的一个原因。借助FX可以很方便地对pytorch模型进行操作,之前的商汤工具就一个基于fx的量化基准[3 ]

对于实际而言,无论是PTQ(需要引入观察操作来增加每一层的激活以及权重分发)QTA(需要插入计算到类似来模拟),都会涉及到fx的功能。所以如果想基于Pytorch框架来做量化,建议直接上手torch.fx

fxpytorch-1.10已经中也拿到了等表状态,已经有API,我的ch模型,使用的版本是Pytorch-1.10TensorRT-8.2

fx 部分自己修改了下源码,添加了一些操作。这里我是直接把最新发布的 pytorch 中的 fx 部分提取出来,然后 pip 安装torch-1.10.0+cu113-cp38-cp38-linux_x86_64.whl,用于食用。

与 TorchScript 的区别

真正开始torch.fx的时候也有这两个最终出现有啥,都是先解析、然后生成几个IR版、基于IR做优化,最后生成一个优化后的模型,然后是python的版本一个是当你用FX的时候,会发现简单的FX并没有什么改变,FX的性质更适合于对这样的功能进行修改,增加操作,比如运行)而torchscript更适合于统计当前模型的性能,并且可以脱离python,仅在C++环境。

借一句大佬的回答:

torch.fx 与 TorchScript 的不同之处在于它是 PyTorch 代码的 Python 到 Python 转换的平台。另一方面,TorchScript 更针对将 PyTorch 程序移出 Python 以进行部署。从这个意义上说,FX 和 TorchScript 是相互正交的,甚至可以相互组合(例如用 FX 转换 PyTorch 程序,然后导出到 TorchScript 进行部署)。

大意就是转换后的FX模式做Python2Python的转换,因为Torchscript是为了部署(脱离Python这个,在C++中运行)而做转换。torchscript,是正交的。

Python 到 Python?

注意是,由 Python 到运行式,FX 并没有和我们使用 Python 生成的直接生成的代码nn.Module的网络,需要使用 Python 生成的直接代码,可以像不同的eager modeFXtorchscript一样,是另一个套runtime(我们跑torchscript的时候其实调用的是一个VM,也就是虚拟机,通过VM在C++中跑通过torchscript导出的模型)。

所以fx转换后的模型类型和nn.Module一毛一样,所以对nn.Module能做的,对转换后的模型也能做,我们可以连续套娃:

  • 自己写的Module -> fx后还是Module -> 连续fx变化 -> 得到最终的fx模型

FX的IR和Jit的IR

这俩IR不一样,FX的IR两个Jit的说法,有优点:

  • FX 运行时间可以集成地中的 Python 地,因为它是因为 FXprograim representationsjit.trace考试。

  • FX的Graph有什么区别,它的torch.nn.moduleIR和没有什么区别,所以说起来用起来更简单,还能提升提升。

这里列一下FX的IR很简单,只有六种,大概就是函数、提取属性、获取输入输出等:

  • placeholder表示函数输入。name属性指定此值将采用的名称。target同样是参数的名称。args持有:1) 没有,或 2) 表示函数输入的默认参数的单个参数。kwargs是不在乎。占位符对应x于图形打印输出中的函数参数(例如 )。

  • get_attr从模块层次结构中检索参数。name同样是获取结果分配给的名称。target是参数在模块层次结构中位置的完全限定名称。args并且kwargs不在乎

  • call_function对某些值应用自由函数。name同样是要分配的值的名称。target是要应用的函数。argskwargs表示函数的参数,遵循 Python 调用约定

  • call_module将模块层次结构的forward()方法中的模块应用于给定的参数。name和以前一样。target是模块层次结构中要调用的模块的完全限定名称。argskwargs表示调用模块的参数,包括 self 参数

  • call_method调用一个值的方法。name是一样的。target是应用于self参数的方法的字符串名称。argskwargs表示调用模块的参数,包括 self 参数

  • output在其args[0]属性中包含跟踪函数的输出。这对应于图表打印输出中的“return”语句。

相对torchscript的IR,FX的可就更简单了,我们理解起来使用起来很简单。

符号示踪寻找

回到一个故事就是故事的那段代码,其中有一行是symbolic_traced : torch.fx.GraphModule = symbolic_trace(module),这里的核心功能symbolic_trace,也是FX解析、转换模型的起点。

@compatibility(is_backward_compatible=True)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
                   enable_cpatching: bool = False)
 -> GraphModule:

    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ...
    """

    tracer = Tracer(enable_cpatching=enable_cpatching)
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)

首先会创建一个Tracer类然后使用成员函数trace我们的torch.nn.Module。我们在trace这个模型之后,就可以对这个模型进行修改了:

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer)
 -> torch.nn.Module:

    # Step 1: Acquire a Graph representing the code in `m`
    # 使用 Tracer 类对象去trace模型 m
    # 这边是拆开了,这个transform函数就是实现torch.fx.symbolic_trace的功能
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: 这里就可以任意修改模型了,这也是重点
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

修改之后的模型可以直接来用,也可以这样通过graph_module.to_folder,把这个模型拿出来当做自己的模块去使用(这个之后说的)。整体的流程大概就是:

符号追踪 -> 中间表示 -> 转换 -> Python 代码生成。

各自的功能为:

  • 象征性的

符号跟踪器执行 Python 代码的“符号执行”。它通过代码提供虚假值,称为代理。记录对这些代理的操作。有关符号跟踪的更多信息可以在 symbolic_trace() 和 Tracer 文档中找到。

  • 中间表示

中间表示是符号跟踪期间记录的操作的容器。它由表示函数输入、调用点(函数、方法或 torch.nn.Module 实例)和返回值的节点列表组成。有关 IR 的更多信息可以在 Graph 的文档中找到。IR 是应用转换的格式。

  • Python代码生成

Python 代码生成使 FX 成为 Python 到 Python(或模块到模块)的转换工具包。对于每个 Graph IR,我们可以创建匹配 Graph 语义的有效 Python 代码。此功能包含在 GraphModule 中,它是一个 torch.nn.Module 实例,它包含一个 Graph 以及从 Graph 生成的 forward 方法。

上面就是FX的三个核心功能。

Proxy/Retracingsymbolic traceProxy/Retracing核心。

代理对象是节点包装器,在符号跟踪期间流经程序,并将它们接触到的所有操作(火炬函数调用、方法调用、运算符)记录到不断增长的 FX Graph 中。

如果您正在执行图形转换,您可以将自己的 Proxy 方法包装在原始节点周围,以便您可以使用重载的运算符向图形添加其他内容。

相关结构

FX主要的和,就是其中。这些记录可以为中的一个网络结构,在网络中的一个关键节点(比如节点、 Graphrelu 、add、concat等),这些记录了的方法和输入输出信息,然后可以串起来网络的逻辑。GraphModuleA Graph is a data structure that represents a method on a GraphModuleGraph