极市导读: This article takes a detailed look at PyTorch's jit module. It covers usage examples of the two export paths (trace and script), the form of the IR, a source-level reading of how each path produces the IR, and a brief introduction to IR optimization.
First, torch.jit.trace. Exporting a torchvision ResNet-18 looks like this:

import torch
import torchvision.models as models

resnet = torch.jit.trace(models.resnet18(), torch.rand(1, 3, 224, 224))
output = resnet(torch.ones(1, 3, 224, 224))
print(output)
resnet.save('resnet.pt')

The IR of the traced model looks like this:
graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
%1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
%1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
%1468 : __torch__.torch.nn.modules.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
%1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
....
%1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
%1202 : int = prim::Constant[value=1]()
%1203 : int = prim::Constant[value=-1]()
%input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
%1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
return (%1557)
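The exported file can then be loaded back without the original Python model definition. A minimal sketch (not from the original article, using only the resnet.pt saved above and the standard torch.jit API):

import torch

loaded = torch.jit.load('resnet.pt')            # deserialize the ScriptModule
out = loaded(torch.ones(1, 3, 224, 224))        # run it without the Python class
print(loaded.graph)                             # the same kind of IR as printed above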
torch.jit.trace takes the model you want to export and a valid example input. Its principle is just what the name says: it traces the model's inference process, records one by one the operations the model performs on the input, maps them to the corresponding IR operations, and thus obtains the IR of the model's original forward.
The catch with tracing is that only the path actually taken by the example input gets recorded:

def test(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

ftrace = torch.jit.trace(test, (torch.ones(1)))
y = torch.ones(1) * 5
print(ftrace(y))
# result: tensor(2.)
# because the example input only went through the else branch,
# the trace always takes that branch
script, on the other hand, does capture the control flow:

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(foo.graph)
print(foo(torch.Tensor([0]), torch.Tensor([1])))
print(foo(torch.Tensor([1]), torch.Tensor([0])))

graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  # the control statement is indeed captured
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)

tensor([1.])
tensor([1.])
torch.jit.script converts the model with a completely different approach from trace: script parses your PyTorch code directly, turns your logic into a syntax tree through syntactic analysis, and then converts that tree into the intermediate representation (IR).
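For contrast, here is a small sketch (added for illustration, not from the original article) of scripting the same kind of data-dependent branch that trace got wrong above; both branches survive, so the result now depends on the actual input:

@torch.jit.script
def test_script(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

print(test_script(torch.ones(1) * 5))   # tensor(1.) -- the if branch is taken at run time
print(test_script(torch.ones(1)))       # tensor(2.)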
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# torch.jit.trace produces a ScriptModule's conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
scripted_module = torch.jit.script(MyModule())
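A usage sketch (not from the original article): the scripted module runs like the original and can be serialized for use without the Python class definition.

out = scripted_module(torch.rand(1, 1, 16, 16))
scripted_module.save("mymodule.pt")

Beyond the builtin operators, custom C++ operators can also be made visible to TorchScript by registering them with TORCH_LIBRARY: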
TORCH_LIBRARY(my_ops, m) {
m.def("warp_perspective", warp_perspective);
}
Once registered this way, the operator (here my_ops::warp_perspective) is usable from both eager mode and TorchScript; for the full walkthrough, refer to the official tutorial on extending TorchScript with custom C++ operators.
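A sketch of how such a registered op would be used from Python, assuming the C++ library above has been built into a shared library (the library path below is illustrative):

import torch

torch.ops.load_library("build/libwarp_perspective.so")

@torch.jit.script
def compute(x, y):
    # the custom op is called and captured in the IR like any builtin aten op
    return torch.ops.my_ops.warp_perspective(x, y) + 1.0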
Next, the IR itself. The TorchScript IR is built from a handful of core structures, such as Graph, Block, Node, Value and Type, all of which can be identified in the annotated graph of foo below:
# %x.1 is a Value
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  # aten::max is a Node
  # Tensor: the Type here is a TensorType
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    # Blocks
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)
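These IR objects can also be inspected from Python. A small sketch on the scripted foo above; nodes(), kind(), inputs(), outputs() and debugName() are part of the torch._C.Graph bindings:

for node in foo.graph.nodes():
    print(node.kind(),                                 # e.g. aten::max, aten::gt, prim::If
          [i.debugName() for i in node.inputs()],
          '->',
          [o.debugName() for o in node.outputs()])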
Now let's see how trace is implemented, starting from the Python entry point torch.jit.trace:
def trace(
    func,
example_inputs,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=_python_cu,
):
    # func is an nn.Module instance: trace its forward
if isinstance(func, torch.nn.Module):
return trace_module(
func,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
    # what was passed in is the forward method of some module instance
if (
hasattr(func, "__self__")
and isinstance(func.__self__, torch.nn.Module)
and func.__name__ == "forward"
):
return trace_module(
func.__self__,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
    # a helper for looking up variable names
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
    # C++ entry point
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
    # check whether the traced function diverges from the original func
if check_trace:
if check_inputs is not None:
_check_trace(
check_inputs,
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
else:
_check_trace(
[example_inputs],
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
return traced
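The check_trace / check_inputs arguments seen above surface exactly the control-flow problem from the earlier example: the original function is re-run on the check inputs and compared against the trace, and inputs that take the other branch make _check_trace report the divergence (as a TracerWarning). A sketch, not from the original article:

def test(x):
    if x > 2.0:
        return torch.tensor(1.0)
    return torch.tensor(2.0)

# tracing with torch.ones(1) bakes in the else branch;
# checking with 5 * torch.ones(1) flags the mismatch
ftrace = torch.jit.trace(test, torch.ones(1),
                         check_inputs=[(torch.ones(1) * 5,)])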
The heavy lifting is done by the C++ entry point:
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
On the C++ side this lands in the tracer's core trace function:

std::pair<std::shared_ptr<TracingState>, Stack> trace(
    Stack inputs,
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool strict,
    bool force_outplace,
    Module* self) {
  try {
    auto state = std::make_shared<TracingState>();
    // setTracingState stores this state instance so that later, when ops run,
    // it can be fetched to insert the recorded computation
    setTracingState(state);
    // this state data structure stores the computation traced during forward
    if (self) {
      Value* self_value = state->graph->insertInput(0, "self")
                              ->setType(self->_ivalue()->type());
      gatherParametersAndBuffers(state, self_value, *self, {"__module"});
    }
    for (IValue& input : inputs) {
      input = addInput(state, input, input.type(), state->graph->addInput());
    }
    auto graph = state->graph;
    // bind the variable-name resolution function coming from Python
    getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
    getTracingState()->strict = strict;
    getTracingState()->force_outplace = force_outplace;
    // run forward; as each computation happens it is recorded into the state
    auto out_stack = traced_fn(inputs);

    // Exit a trace, treating 'out_stack' as the outputs of the trace. These
    // are the variables whose values will be computed upon subsequent
    // invocations of the trace.
    size_t i = 0;
    for (auto& output : out_stack) {
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
      // number we need to flip it based on size.
      state->graph->registerOutput(
          state->getOutput(output, out_stack.size() - i));
      i++;
    }
    setTracingState(nullptr);

    if (getInlineEverythingMode()) {
      Inline(*graph);
    }
    FixupTraceScopeBlocks(graph, self);
    NormalizeOps(graph);
    return {state, out_stack};
  } catch (...) {
    tracer::abandon();
    throw;
  }
}
How does the computation get recorded while forward runs? Every c10 operator is wrapped into a jit Operator whose implementation, when tracing is active, creates the corresponding Node, records its inputs and inserts it into the graph:

Operator createOperatorFromC10_withTracingHandledHere(
    const c10::OperatorHandle& op) {
  return Operator(op, [op](Stack& stack) {
    const auto input_size = op.schema().arguments().size();
    const auto output_size = op.schema().returns().size();

    Node* node = nullptr;
    std::shared_ptr<jit::tracer::TracingState> tracer_state;

    // trace the input before unwrapping, otherwise we may lose
    // the input information
    if (jit::tracer::isTracing()) {
      // fetch the tracer_state
      tracer_state = jit::tracer::getTracingState();
      auto symbol = Symbol::fromQualString(op.schema().name());
      const auto& graph = tracer::getTracingState()->graph;
      node = graph->create(symbol, 0);
      tracer::recordSourceLocation(node);
      const auto& args = op.schema().arguments();
      int i = 0;
      // record the args
      for (auto iter = stack.end() - input_size; iter != stack.end();
           ++iter, ++i) {
        // TODO we need to refactor graph APIs (e.g., addInputs)
        // appropriately; after that, we can get rid of the giant if-else
        // block we will clean this tech debt together in the following PRs
        auto type = args[i].type();
        if (type->kind() == TypeKind::OptionalType) {
          if (iter->isNone()) {
            Value* none = graph->insertNode(graph->createNone())->output();
            node->addInput(none);
            continue;
          } else {
            type = type->expect<OptionalType>()->getElementType();
          }
        }
        if (type->isSubtypeOf(TensorType::get())) {
          AT_ASSERT(iter->isTensor());
          tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
        } else if (type->kind() == TypeKind::FloatType) {
          AT_ASSERT(iter->isDouble());
          tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
        } else if (type->kind() == TypeKind::IntType) {
          AT_ASSERT(iter->isInt());
          tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
        } else if (type->kind() == TypeKind::BoolType) {
          AT_ASSERT(iter->isBool());
          tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
        } else if (type->kind() == TypeKind::StringType) {
          AT_ASSERT(iter->isString());
          tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
        } else if (type->kind() == TypeKind::NumberType) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
        } else if (type->kind() == TypeKind::ListType) {
          const auto& elem_type = type->expect<ListType>()->getElementType();
          if (elem_type->isSubtypeOf(TensorType::get())) {
            AT_ASSERT(iter->isTensorList());
            auto list = iter->toTensorVector();
            tracer::addInputs(node, args[i].name().c_str(), list);
          } else if (elem_type->kind() == TypeKind::FloatType) {
            AT_ASSERT(iter->isDoubleList());
            // NB: now, tracer doesn't support tracing double list. We add
            // special handling here, since in our case, we assume that all the
            // doubles in the list are constants
            auto value = iter->toDoubleVector();
            std::vector<Value*> info(value.size());
            for (size_t value_index = 0; value_index < value.size();
                 ++value_index) {
              info[value_index] = graph->insertConstant(value[value_index]);
              tracer::recordSourceLocation(info[value_index]->node());
            }
            node->addInput(
                graph->insertNode(graph->createList(jit::FloatType::get(), info))
                    ->output());
          } else if (elem_type->kind() == TypeKind::IntType) {
            AT_ASSERT(iter->isIntList());
            tracer::addInputs(node, args[i].name().c_str(), iter->toIntVector());
          } else if (elem_type->kind() == TypeKind::BoolType) {
            AT_ASSERT(iter->isBoolList());
            tracer::addInputs(
                node, args[i].name().c_str(), iter->toBoolList().vec());
          } else {
            throw std::runtime_error(
                "unsupported input list type: " + elem_type->str());
          }
        } else if (iter->isObject()) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
        } else {
          throw std::runtime_error("unsupported input type: " + type->str());
        }
      }
      // insert the node into the graph
      graph->insertNode(node);

      jit::tracer::setTracingState(nullptr);
    }
    .....
Now the script path. For a plain function, torch.jit.script's Python side looks like this:

def script(obj, optimize=None, _frames_up=0, _rcb=None):
    ...
    # function branch
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    # check for overloads
    _check_directly_compile_overloaded(obj)
    # was it scripted before already?
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn
    # build the ast
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    # C++ entry point: produce the IR from the ast
    fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))
    # Forward docstrings
    fn.__doc__ = obj.__doc__
    # cache it
    _set_jit_function_cache(obj, fn)
    return fn
get_jit_def obtains the source code and parses it:

def get_jit_def(fn, def_name, self_name=None):
    # collect some information about the source code
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    # dedent_src is the string containing the function to be scripted
    dedent_src = dedent(source)
    # call Python's ast package to parse the string into a Python ast
    py_ast = ast.parse(dedent_src)
    # get the Python type comment, if any
    type_line = torch.jit.annotations.get_type_line(source)
    # ctx holds all the meta information of the function
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    fn_def = py_ast.body[0]
    # build_def converts the Python ast into the ast format used by torch.jit
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
What does a Python ast look like? Parse a small example:

import ast

func_def = """
def test(a):
    a = a + 2
    return a + 1
"""
results = ast.parse(func_def)
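Since the original screenshot of the parse result is not reproduced here, ast.dump gives an equivalent view (the exact node names vary slightly across Python versions, e.g. Num vs Constant):

print(ast.dump(results))
# Module(body=[FunctionDef(name='test', args=..., body=[
#     Assign(targets=[Name(id='a', ...)],
#            value=BinOp(left=Name(id='a', ...), op=Add(), right=Num(n=2))),
#     Return(value=BinOp(left=Name(id='a', ...), op=Add(), right=Num(n=1)))], ...)])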
As the dump shows, the right-hand side of `a = a + 2` is a BinOp: concretely an Add whose left is a Name with id `a` and whose right is a Num, i.e. 2.
Next, let's see how build_def converts the Python ast into the ast format torch.jit needs:
def build_def(ctx, py_def, type_line, def_name, self_name=None):
....
return Def(Ident(r, def_name),
decl,
build_stmts(ctx, body))
ctx contains all the information about the source code, and body is the result of the Python ast parse, so build_stmts should contain the answer we are looking for. Taking a + 2 as an example, let's see how it gets converted; this part lives in frontend.py, in StmtBuilder and ExprBuilder:
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)  # the basic ast structures defined for the jit

def build_stmts(ctx, stmts):
    # note that it calls build_stmt
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))

# build_stmt is an instance of StmtBuilder()
build_stmt = StmtBuilder()
build_expr = ExprBuilder()

class Builder(object):
    def __call__(self, ctx, node):
        # dispatch to the build_* method matching the parsed ast node's type;
        # as seen in the dump above, `a + 2` sits inside an `Assign` node,
        # so build_Assign is called
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)

class StmtBuilder(Builder):
    @staticmethod
    def build_Assign(ctx, stmt):
        # stmt.value is a BinOp; build_expr is an instance of ExprBuilder
        # and will call build_BinOp
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }

    @staticmethod
    def build_BinOp(ctx, expr):
        # expr.left is a Name, so build_Name is called
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)
        # map to the agreed string token representing this operation
        op_token = ExprBuilder.binop_map.get(op)
        return BinOp(op_token, lhs, rhs)

Applied to the parsed test function, build_def produces the following torch.jit ast:
(def
(ident test)
(decl
(list
(param
(ident a)
(option)
(option)
(False)))
(option))
(list
(assign
(list (variable (ident a)))
(option
(+
(variable (ident a))
(const 2)))
(option))
(return
(+
(variable (ident a))
(const 1)))))
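If you want to reproduce this dump yourself, the tree views bound from C++ print in this s-expression form; this relies on internal APIs of the PyTorch version discussed here, so treat it as a sketch:

from torch.jit.frontend import get_jit_def

def test(a):
    a = a + 2
    return a + 1

print(get_jit_def(test, 'test'))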
The C++ entry torch._C._jit_script_compile ends up in script_compile_function:
static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
const FunctionDefaults& defaults,
const ResolutionCallback& rcb) {
auto cu = get_python_cu();
  // it looks like the define function of the CompilationUnit returned by get_python_cu does the work
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
/*properties=*/{},
/*propResolvers=*/{},
{def},
{pythonResolver(rcb)},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}
// it turns out get_python_cu just wraps the CompilationUnit
inline std::shared_ptr<CompilationUnit> get_python_cu() {
return py::module::import("torch.jit._state")
.attr("_python_cu")
.cast<std::shared_ptr<CompilationUnit>>();
}
// about the CompilationUnit:
// torch/csrc/jit/api/compilation_unit.h
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> define(
const c10::optional<c10::QualifiedName>& prefix,
const std::vector<Property>& properties,
const std::vector<ResolverPtr>& propResolvers,
const std::vector<Def>& definitions,
const std::vector<ResolverPtr>&
defResolvers, /* determines how we handle free
variables in each definition*/
// if non-null, the first argument to each def, is bound to this value
const Self* self,
// see [name mangling]
bool shouldMangle = false);
// the implementation is in torch/csrc/jit/frontend/ir_emitter.cpp
std::unique_ptr<Function> CompilationUnit::define(
const c10::optional<QualifiedName>& prefix,
const Def& def,
const ResolverPtr& resolver,
const Self* self,
const std::unordered_map<std::string, Function*>& function_table,
bool shouldMangle) const {
auto _resolver = resolver;
.....
auto creator = [def, _resolver, self](Function& method) {
....
    // the core: to_ir
to_ir(def, _resolver, self, method);
};
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::make_shared<Graph>(), creator);
return fn;
}
The key structure is struct to_ir: its inputs include def, i.e. the ast, and _resolver, the name-resolution function passed over from Python. Inside it we can find the crucial part:
to_ir(
const Def& def,
ResolverPtr resolver_,
const Self* self,
Function& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
typeParser_(resolver),
environment_stack(nullptr) {
AT_ASSERT(resolver);
pushFrame(graph->block(), /*starts_def=*/true);
    // emitDef will call emitStatements
method.setSchema(emitDef(def, self, graph->block()));
ConvertToSSA(graph);
CanonicalizeModifiedLoops(graph);
NormalizeOps(graph);
runCleanupPasses(graph);
}
private:
  // in to_ir's private members we can see Function and Graph, the IR components introduced earlier
Function& method;
std::shared_ptr<Graph> graph;
ResolverPtr resolver;
std::unordered_map<int64_t, Value*> integral_constants;
  // emitDef calls emitStatements
FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
......
// body
auto stmts_list = def.statements();
emitStatements(stmts_list.begin(), stmts_list.end());
........
}
void emitStatements(
List<Stmt>::const_iterator begin,
List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
ErrorReport::CallStack::update_pending_range(stmt.range());
switch (stmt.kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_FOR:
emitFor(For(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
.................
break;
default:
throw ErrorReport(stmt)
<< "Unrecognized statement kind " << kindToString(stmt.kind());
}
// Found an exit statement in this block. The remaining statements aren't
// reachable so we don't emit them.
if (exit_blocks.count(environment_stack->block()))
return;
}
}
We can see that, depending on stmt.kind(), execution dispatches into the various emit* functions, inside each of which you can find operations like
graph->insertNode(graph->create(.....));
which is exactly where the IR graph gets built up.
script can also be applied to modules. One way is to subclass torch.jit.ScriptModule and mark methods with @torch.jit.script_method:
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def f(self, x):
return x * x
@torch.jit.script_method
def forward(self, x):
return x + self.f(x)
About script_method:
def script_method(fn):
_rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    # no scripting happens here yet; it just returns a namedtuple containing the ast
return ScriptMethodStub(_rcb, ast, fn)
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
The ScriptMeta metaclass does two things:
1. Remove all the @script_method-decorated attributes from the class, so that what gets accessed later is the scripted function.
2. Wrap the module's __init__ so that, as soon as the self.param / self.module members have been initialized, all script_methods are compiled at once; the forward of the resulting instance has therefore already been replaced.
class ScriptMeta(type):
def __init__(cls, name, bases, attrs): # noqa: B902
# cls is an instance of ScriptMeta, i.e. a class such as ScriptModule
cls._methods: Dict[str, Any] = {}
cls._constants_set = set(getattr(cls, "__constants__", ()))
for base in reversed(bases):
# recall that a traced module also carries a _methods attribute
for k, v in getattr(base, "_methods", {}).items():
cls._methods[k] = v
base_constants = getattr(base, "_constants_set", set())
cls._constants_set = cls._constants_set.union(base_constants)
# collect all methods currently decorated with @script_method into _methods and delete the original attrs;
# they are scripted together after __init__
for k, v in sorted(attrs.items()):
if isinstance(v, ScriptMethodStub):
delattr(cls, k)
cls._methods[v.original_method.__name__] = v
original_init = getattr(cls, "__init__", lambda self: None)
# this makes create_script_module run (i.e. the scripting happen) once __init__ has finished
@functools.wraps(original_init)
def init_then_script(self, *args, **kwargs):
# here self is the instance
num_methods = len(cls._methods)
original_init(self, *args, **kwargs)
added_methods_in_init = len(cls._methods) > num_methods
if type(self) == cls:
# pick the methods that need to be scripted
def make_stubs(module):
cls = type(module)
if hasattr(cls, "_methods"):
return [v for k, v in sorted(cls._methods.items())]
else:
# infer_methods_to_compile is a function that picks which methods to script
return infer_methods_to_compile(module)
# compile all the script_methods together into the _actual_script_module attribute
self.__dict__[
"_actual_script_module"
] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
# Delete the Python attributes that now shadow the ScriptModule
# ones, so that __getattr__ and __setattr__ will properly find
# the scripted versions.
concrete_type = self._actual_script_module._concrete_type
for name in concrete_type.get_attributes():
delattr(self, name)
for name, _ in concrete_type.get_modules():
delattr(self, name)
for name in ("_parameters", "_buffers", "_modules"):
delattr(self, name)
cls.__init__ = init_then_script # type: ignore
return super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
def __get__(self, obj, cls):
return self.__getattr__("forward") # type: ignore
class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
def __init__(self):
super(ScriptModule, self).__init__()
forward = _CachedForward()
# accessing an attribute of the module returns the attribute of _actual_script_module
def __getattr__(self, attr):
if "_actual_script_module" not in self.__dict__:
return super(ScriptModule, self).__getattr__(attr)
return getattr(self._actual_script_module, attr)
def __setattr__(self, attr, value):
if "_actual_script_module" not in self.__dict__:
# Unwrap torch.jit.Attribute into a regular setattr + recording
# the provided type in __annotations__.
#
# This ensures that if we use the attr again in `__init__`, it
# will look like the actual value, not an instance of Attribute.
if isinstance(value, Attribute):
if "__annotations__" not in self.__class__.__dict__:
self.__class__.__annotations__ = {}
self.__annotations__[attr] = value.type
value = value.value
return super(ScriptModule, self).__setattr__(attr, value)
setattr(self._actual_script_module, attr, value)
...
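A quick sanity check of the behaviour described above (a sketch, assuming the corrected MyModule from earlier is in scope): once __init__ returns, the methods are already compiled, so their IR graphs can be inspected directly.

m = MyModule()
print(m.forward.graph)    # forward has been replaced by the scripted version
print(m.f.graph)          # the @script_method-decorated f as well
out = m(torch.ones(2, 2)) # runs x + self.f(x)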
Finally, a brief look at IR optimization. The function below contains a dead loop and a pair of transposes that cancel out; scripting it lets the JIT optimize them away:

from time import time
import torch

def test(x):
    # dead code elimination: y is never used
    for i in range(1000):
        y = x + 1
    for i in range(100):
        # peephole optimization: t(t(x)) == x
        x = x.t()
        x = x.t()
    return x.sum()

opt_test = torch.jit.script(test)
inputs = torch.ones(4, 4).cuda()

s = time()
for i in range(10000):
    test(inputs)
print(time() - s)   # 95 s

s = time()
for i in range(10000):
    opt_test(inputs)
print(time() - s)   # 0.13 s

print(opt_test.graph)
print(opt_test.graph_for(inputs))

95.13823795318604
0.13010907173156738

graph(%x.1 : Tensor):
  %22 : None = prim::Constant()
  %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
  %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
  %x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
    block0(%i : int, %x.10 : Tensor):
      %x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
      %x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
      -> (%13, %x.7)
  %23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%23)

graph(%x.1 : Tensor):
  %1 : None = prim::Constant()
  %2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%2)

The first dump (opt_test.graph) shows that the dead loop over y has already been removed, while the second dump (graph_for(inputs)) shows the graph actually executed for these inputs: the transpose loop has been optimized away as well, leaving only aten::sum.
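These rewrites can also be driven by hand through the (internal, unstable) pass bindings, which is a convenient way to see what each pass does to a graph; a sketch, not from the original article:

scripted = torch.jit.script(test)
g = scripted.graph
torch._C._jit_pass_constant_propagation(g)
torch._C._jit_pass_dce(g)
torch._C._jit_pass_peephole(g)
print(g)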
So where do these optimizations run? Each compiled Function owns a GraphExecutor:
GraphExecutor& get_executor() {
return function_->get_executor();
}
GraphExecutor::GraphExecutor(
const std::shared_ptr<Graph>& graph,
std::string function_name)
: pImpl(
IsNewExecutorEnabled()
? dynamic_cast<GraphExecutorImplBase*>(
new ProfilingGraphExecutorImpl(
graph,
std::move(function_name)))
: dynamic_cast<GraphExecutorImplBase*>(
new GraphExecutorImpl(graph, std::move(function_name)))) {}
std::shared_ptr<Graph> GraphExecutor::graph() const {
return pImpl->graph;
}
const ExecutionPlan& GraphExecutor::getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth) {
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}
std::shared_ptr<GraphExecutorImplBase> pImpl;
.....
About GraphExecutorImplBase, see torch/csrc/jit/runtime/graph_executor.cpp:
const ExecutionPlan& getOrCompile(const Stack& stack) {
.....
auto plan = compileSpec(spec);
}
}
// compileSpec returns an ExecutionPlan
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
GRAPH_DUMP("Optimizing the following function:", opt_graph);
arg_spec_creator_.specializeTypes(*opt_graph, spec);
// Phase 0. Inline functions, then clean up any artifacts that the inliner
// left in that may inhibit optimization
.....
runRequiredPasses(opt_graph);
GRAPH_DEBUG(
"After runRequiredPasses, before ConstantPropagation\n", *opt_graph);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape
// information anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
GRAPH_DEBUG(
"After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
PropagateInputShapes(opt_graph);
GRAPH_DEBUG(
"After PropagateInputShapes, before PropagateRequiresGrad\n",
*opt_graph);
PropagateRequiresGrad(opt_graph);
GRAPH_DEBUG(
"After PropagateRequiresGrad, before runOptimization\n", *opt_graph);
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
// that we can still execute using autograd).
runOptimization(opt_graph);
..... (more optimization passes follow)
return ExecutionPlan(opt_graph, function_name_);
}