TVM的“hello world“基础流程 II

上一篇《TVM的“hello world“基础流程 I》中基于一个最基本的case介绍了TVM中计算的定义与schedule的构建。这篇沿用上一篇中的case,继续介绍接下去的一个重点部分,就是编译。

有了前面构建的schedule之后,接着就需要编译并生成目标代码了。这个工作主要由tvm.build()relay.build()两个函数来完成。它们的区别在于应用目标的范围,前者用于单个算子,后者用于整个网络。由于网络可看作由算子组成,后者会调用前者。本例中是针对单个算子的,因此这里使用的是前者:

tgt = tvm.target.Target(target="llvm", host="llvm")
fadd = tvm.build(s, [A, B, C], tgt, name="vecadd")

其中最主要的build()函数定义在driver/build_module.py文件中。该函数基于给定参数构建出可调用的目标函数。按照官方介绍里的说法,它主要做两个工作 :

  • Lowering:将high-level的循环嵌套结构转换成最终的low-level的IR。
  • Codegen:从low-level的IR生成目标机器代码。

该函数的第一个参数是前面构建出来的schedule,第二个参数是函数的参数列表,第三个参数是target。它提供用于lowering和codegen所需的目标平台信息。代码中对应的Target对象定义在target.*文件中。其构造函数有两个参数,其中第一个参数target指示目标平台的配置。其中的配置项比如:

  • kind: 平台类型,它基本决定了生成的代码是在什么处理器上运行。注册的target kind详细见target_kind.cc,有llvm, c, cuda, nvptx, romc, opencl, metal, vulkan, hexagon等。
  • keys: 如kind是opencl的话,key可以是mali, opencl, gpu。
  • device:对应实际运行的设备,它会添加到keys后面。
  • libs:外部库,如cblas, cudnn, cublas, mkl这些。

另外参数host与target类似,但它用于指示host平台。比如taret平台为cuda的话,毕竟GPU还是不能完全脱离CPU运行,因此还需要host的代码做胶水,如内存分配,kernel启动这些。默认为llvm。

Lowering过程可以单独用tvm.lower()函数完成,如:

m = tvm.lower(s, [A, B, C], name="vecadd")
rt_mod = tvm.build(m, target="llvm")

也可以通过tvm.build()函数完成(因为它一进去就会先调用lower()函数)。lower()函数的主要流程相关代码:

lower(sch, args, name="main", ...) // driver/build_module.py
    // Handle add_lower_pass, if any.
    lower_phases0 = ...
    ...
    // According to the given schedule, form a function (in IRModule).
    mod = form_irmodule(sch, args, ...) // build_module.py
        sch.normalize()
            Schedule::normalize() // schedule_dataflow_rewrite.cc
                InjectInline()
                RebaseNonZeroMinLoop()
                LegalizeInvalidAttach()
        bounds = schedule.InferBound(sch)  
            InferBound() // bound.cc
        stmt = schedule.ScheduleOps(sch, bounds)
            ScheduleOps() // schedule_ops.cc
                body = Stmt()
                // scan init and scan updates
                ...
                for each stage in schedule: // in reverse order
                    body = MakePipeline(stage, dom_map, body, ...)
                SchedulePostProc post_proc
                post_proc.Init(sch)
                return post_proc(body)
        compact = schedule.VerifyCompactBuffer(stmt)
        binds, arg_list = get_binds(args, compact, binds)
        stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, ...)
        // func type: PrimFunc
        func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, ...) // schedule_postproc_to_primfunc.cc
            // Prepare parameters
            ...
            return tie::PrimFunc(params, body, ...)
        // name: vecadd
        func = func.with_attr("global_symbol", name)
        // Set functions
        return tvm.IRModule({name: func})
        
    // Phase 0: InjectPrefetch, StorageFlatten, BF16Legalize, NarrowDataType, Simplify
    pass_list = lower_phase0
    
    // Phase 1: LoopPartition, VectorizeLoop, InjectVirtualThread, InjectDoubleBuffer, StorageRewrite, UnrollLoop
    pass_list += lower_phase1
    
    // Phase 3: Simplify, RemoveNoOp, RewriteUnsafeSelect, HoistIfThenElse
    pass_list += lower_phase2
    
    // Apply the above passes.
    optimize = tvm.transform.Sequential(pass_list)
    mod = optimize(mod) 
    
    // mod type: tvm.ir.module.IRModule
    return mod 

它主要根据参数给的schedule与参数生成对应的IRModule对象(定义在ir/module.h中)。IRModule是软件栈中所有IR变换的基础单元。它维护函数与类型定义。这里的各种pass就是在IRModule上进行并吐出IRModule

在这里插入图片描述

其中几个主要数据结构关系如下:
在这里插入图片描述

lower()函数中有四个阶段,第一个阶段中通过form_irmodule()函数根据给定的schedule生成IRModule对象,然后在这个IRModule对象上应用4轮的pass。这些pass主要分为几个阶段,分别是:

  • Phase 0:使用者自定义的pass。
  • Phase 1:使用者自定义的pass。以及:
    • InjectPrefetch
    • StorageFlatten
    • BF16Legalize
    • NarrowDataType
    • Simplify
  • Phase 2:使用者自定义的pass。以及:
    • LoopPartition
    • VectorizeLoop
    • InjectVirtualThread
    • InjectDoubleBuffer
    • StorageRewrite
    • UnrollLoop
  • Phase 3:使用者自定义的pass。以及:
    • Simplify
    • RemoveNoOp
    • RewriteUnsafeSelect
    • HoistIfThenElse
    • InstrumentBoundCheckers

这此pass其实是编译构建过程中的精华之一。但限于篇幅(其实是我自己也没了解全。。。),以后再进一步讨论。

lower()函数的最后返回经过上面多轮pass优化后的IRModule对象。其中form_irmodule()函数是相对比较复杂的一部分,它主要负责生成最初的IRModule对象,其中几个关键步骤如下:

  1. Schedule::normalize()函数规范化给定的schedule。主要实现在schedule_dataflow_rewrite.cc文件中。它调用以下三个函数。本例比较简单,因此它们实际都没有起什么作用。。。
    1. InjectInline()函数处理算子内联。用到调度原语 compute_inline的话会用到。
    2. RebaseNonZeroMinLoop()函数将循环迭代的最小界置为0。感觉有点canonicalization的意思。
    3. LegalizeInvalidAttach()函数处理在使用调度原语compute_at时且目标迭代又被split或fuse情况下的合法化。
  2. InferBound()函数顾名思义就是边界推导(Bound inference),主要用于推导循环边界。更具体地,就是确定每个IterVar的范围,它返回IterVarRange的映射,即每个循环变量的范围。这个信息在后面的MakeLoopNest()函数中用于确定for循环的范围,和在BuildRealize()函数中设置缓冲的大小。具体可参见官方文档 InferBound Pass
  3. ScheduleOps()函数基于前面经过一些处理后的Schedule对象和推导出来的循环边界产生Stmt对象。它表示一个初始的循环嵌套结构。C++层中的Stmt为所有语句(Statement)的容器。它的子类有LetStmtAttrStmtAssertStmtStoreAllocateSeqStmtIfThenElseEvaluateForWhile等等。该函数会处理schedule的依赖,核心部分是逆向遍历Schedule当中的Stage(对于上面例子中就是先Compute Op,再两个Placeholder Op)。对于每个stage(PlaceholderOp除外),根据其attach type调用相应的逻辑。
    1. 对于上面的例子,Compute Op没有attach在其它计算中,因此它对应Stage的attach type为kGroupRoot,因此这里调用MakePipeline()函数产生Stmt。这步比较关键比较复杂,后面再展开。
    2. 然后通过SchedulePostProc对象(继承自StmtExprMutator)对前面生成的Stmt进行后处理。
  4. get_binds()函数用于绑定buffer。它给每个参数张量分配buffer。如对于上面例子中的A, B, C三个张量,分别通过tvm.tir.decl_buffer()创建buffer并将之与张量绑定。
  5. SchedulePostProcToPrimFunc()函数基于ScheduleOps()产生的Stmt创建PrimFunc对象,它可以被用于TIR优化。PrimFunc代表包含了TIR statement的primitive function,它是low-level的代码表示。
  6. 创建IRModule对象。基于上面生成的对象封装成IRModule对象并返回。一个IRModule可以有多个函数,比较简单的情况下就一个。

上面第ScheduleOps()函数中会调用MakePipeline()函数针对ComputeOp对应Stage,返回一条由Stmt组成的pipeline,其大体流程相关代码如下:

MakePipeline(Stage, unordered_map<IterVar, Range>, Stmt, ...) // schedule_ops.cc
    producer = s->op->BuildProvide(stage, ...) // ComputeOpNode::BuildProvide() in compute_op.cc
        ComputeType ctype = DetectComputeType(this, stage)
        MakeComputeStmt(...) // compute_op.cc
            ComputeLoopNest n = ComputeLoopNest::Create(...) // compute_op.cc
                ComputeLoopNest ret
                // make main loop nest
                ret.main_nest = MakeLoopNest(stage, dom_map, ...) // op_utils.cc
                    vector<vector<Stmt>> nest
                    nest.resize(leaf_iter_vars.size() + 1)
                    for iter_var in leaf_iter_vars:
                        nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op))
                        nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op))
                ...
            n.init_nest.emplace_back(MakeIfNest(n.init_predicates))
            n.main_nest.emplace_back(MakeIfNest(n.main_predicates))
            if has reduce_axis:
                ...
            else:
                vector<Stmt> provides
                ...
                // Array<Stmt> -> SeqStmt
                Stmt provide = SeqStmt::Flatten(provides) // stmt.h
                provide = MergeNest(n.main_nest, provide) // ir_utils.cc
                return Substitute(provide, n.main_vmap) // stmt_functor.cc
    Stmt pipeline = producer
    pipeline = s->op->BuildRealize(stage, dom_map, pipeline) 
        // set the sizes of allocated buffers
        BaseComputeOpNode::BuildRealize(stage, realize_map, body) // compute_op.cc
            Stmt realize = body
            realize = tir::ProducerRealize(...)
    pipeline = AttrStmt(s->op, tir::attr::realize_scope, ..., pipeline)
    return pipeline

MakePipeline()函数主要步骤如下:

  1. ComputeOpNode::BuildProvide()函数主要创建ComputeOp对应的循环嵌套对应的那些Stmt对象并串成pipeline。
    1. 首先用DetectComputeType()函数检测计算类型。它遍历当前Stage的所有当前有效IterVar对象,并根据它们的属性判断计算类型,对于上面的简单例子这里为ComputeType::kNormal
    2. 然后根据类型调用相应函数创建Stmt对象。这里对应地是调用MakeComputeStmt()函数。
      1. 根据Stage对象和边界推导的结果通过ComputeLoopNest::Create()函数创建ComputeLoopNest对象。该对象表示循环嵌套,它几个主要成员:

        • init_predicatesmain_predicates:类型为vector<PrimExpr>。表示每个循环的边界判断,调用MakeBoundCheck()函数来生成。
        • init_nestmain_nest:类型为vector<vector<Stmt>>。 其中main_nest是最主要的表示循环嵌套的对象,对于上面的例子,经过split后这里包含两个for循环。
      2. 根据main_predicates创建对应的Stmt(如有),用于在循环中判断该predicate是否成立,并添加到main_nest结构中。

      3. 根据有无reduce axis走不同的path。如果没有的话(如本例),对于ComputeOpbody中的每一个输出,创建ProducerStore对象,再通过MergeNest()函数将之与主嵌套main_nest合并。

      4. 通过Substitute()函数基于main_vmap(在MakeLoopNest()函数中准备)进行替换。

  2. 如schedule中设置了double buffer(如s[A].double_buffer),则添加对应的AttrStmt。它通过增大额外的buffer来达到达到计算与访存的重叠。本例中没用到。
  3. 如传入的consumer有定义且不是no op(指无定义、const init的EvaluateNode,或者是长度为0的SeqStmtNode),则添加SeqStmtproducerconsumer串连起来。本例中也不适用。
  4. 调用BuildRealize()函数。对于每个输出的张量,在pipeline中加入ProducerRealize节点。
  5. 最后,在pipeline中添加AttrStmt节点标注操作的范围,并返回该pipeline。

对于前面vecadd的例子,得到的pipeline大致如下示意图:
在这里插入图片描述

整个lower()函数后完成后的IR(TIR)打印出来如下:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024], []),
             A: Buffer(A_2: Pointer(float32), float32, [1024], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i.outer: int32, 0, 16) {
    for (i.inner: int32, 0, 64) {
      C_2[((i.outer*64) + i.inner)] = ((float32*)A_2[((i.outer*64) + i.inner)] + (float32*)B_2[((i.outer*64) + i.inner)])
    }
  }
}

Lowering完成后,接下去就是build了。Build的主要流程相关代码如下:

build() # driver/build_module.py
    input_mod = lower(inputs, args, ...) 

    mod_host_all = tvm.IRModule()

    for tar, input_mod in target_input_mod.items():
        # build the lowered functions for a device with the given compilation
        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
            # input_mod type: IRModule
            mod_mixed = input_mod 
            # Apply passes:  ThreadSync, InferFragment, LowerThreadAllreduce, MakePackedAPI, SplitHostDevice
            ...
            # Device optimizations: Filter, LowerWarpMemory, ,Simplify, LowerDeviceStorageAccessInfo, LowerIntrin
            ...
            mod_dev = opt_device(mod_mixed) # IRModule
            # Host optimization: LowerTVMBuiltin, LowerDeviceStorageAccessInfo, CustomDataType, LowerIntrin, CombineContextCall
            ...
            mod_host = opt_host(mod_mixed) # IRModule
            
            # Build IRModule into Module
            # If there are dev functions
            rt_mod_dev = codegen.build_module(mod_dev, target) # target/codegen.py
                _ffi_api.Build(mod, target) # codegen.py
            # mod_host type: IRModule, rt_mod_dev type: Module
            return mod_host, rt_mod_dev 
        mod_host_all.update(mod_host)
            # Insert functions in another Module to current one
            _ffi_api.Module_Update()
                IRModuleNode::Update() # ir/module.cc
        device_modules.append(mdev)
    # Generate a unified host module (type: runtime.Module)
    rt_mod_host = codegen.build_module(mod_host_all, target_host)
        # Create LLVMModuleNode and return the corresponding Module
        _ffi_api.Build(mod, target) # target/codegen.cc
    # Import all modules
    for mdev in device_modules:
        rt_mod_host.import_module(mdev)
            _LIB.TVMModImport(mod, dep) # c_runtime_api.cc
                GetModuleNode(mod)->Import(...) # runtime/module.cc
                    imports_.emplace_back(...)
    return rt_mod_host # runtime.module.Module

target_input_mod包含了前面lowering输出的需要编译的IRModule及相应的target信息。比如LLVM(CPU)为target,就是:{"llvm -keys=cpu -link-params=0", IRModule}。如cuda为target,可能就是{“cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", IRModule}。对于简单的case(如本文这个),target_input_mod只包含一个元素,_build_for_device()函数返回host端的IRModule,以及target端的Module(如是cuda平台的话C++层对应CUDAModuleNode对象)。然后将host端IRModule生成一个统一的host模块,再将前面生成的对应target的Module(如有)导入其中。

这里,其中mod_host_allmod_host的类型为tvm.ir.module.IRModulert_mod_hostmdev的类型为tvm.runtime.module.Module。注意mdev只有当目标为非CPU(如GPU等)平台时才会有,当target为llvm(即for CPU)时mdev为空。

这个流程大体示意图如下:
在这里插入图片描述

其中比较核心和重要的部分是Build()函数,实现在codegen.cc文件中。它会调用到具体后端的编译函数,进行目标代码生成。如cuda平台的话对应函数定义在build_cuda_on.cc文件中,llvm的话在llvm_module.cc文件中。以llvm后端为例,其主要流程相关代码为:

TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module { 
        auto n = make_object<LLVMModuleNode>();
        n->Init(mod, target); // llvm_module.cc
            InitializeLLVM();
                llvm::InitializeAllTargetInfos();
                llvm::InitializeAllTargets();
                ...
            unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(...) // codegen_llvm.cc
                // Call the corresponding codegen backend according to the target.
                const PackedFunc* f = runtime::Registry::Get("tvm.codegen.llvm.target_" + target);
                handle = (*f)() 
                return unique_ptr<CodeGenLLVM>(handle);
                
            vector<PrimFunc> funcs;
            for kv : mod->functions:
                ...
                f = Downcast<PrimFunc>(kv.second);
                if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc))
                    entry_func = global_symbol.value();
                funcs.push_back(f);
            cg->Init("TVMMod", ...);
                CodeGenCPU::Init() // codegen_cpu.cc
                    CodeGenLLVM::Init() // codegen_llvm.cc
                    
            for f in funcs:
                cg->AddFunction(f); // codegen_cpu.cc
                    CodeGenLLVM::AddFunction();
                        AddFunctionInternal(f);
                            llvm::FunctionType* ftype = llvm::FunctionType::get(...);
                            // kGlobalSymbol: "global_symbol"
                            global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
                            function_ = llvm::Function::Create(...);
                            llvm::BasicBlock* entry = llvm::BasicBlock::Create(..., function_);
                            IRBuilder::SetInsertPoint(entry);
                            this->VisitStmt(f->body);
                            builder_->CreateRet(ConstInt32(0));
            if entry_func.length() != 0:
                cg->AddMainFunction(entry_func); // codegen_cpu.cc
                    // tvm_module_main : "__tvm_main__"
                    llvm::GlobalVariable* global = new llvm::GlobalVariable(*module_, ..., tvm_module_main);
                    global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name))
                    global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
            module_ = cg->Finish(); // CodeGenCPU::Finish() in codegen_cpu.cc
                CodeGenLLVM::Finish(); // codegen_llvm.cc
                    CodeGenCPU::AddStartupFunction();
                        function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,"__tvm_module_startup", module_.get());
                        llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
                        llvm::appendToGlobalCtors(*module_, function_, 65535);
                        builder_->CreateRet(nullptr);
                    CodeGenLLVM::Optimize(); // codegen_llvm.cc
                        // Function pass manager
                        FPassManager fpass(module_.get());
                        // Module pass manager
                        MPassManager mpass;
                        mpass.add(llvm::createTargetTransformInfoWrapperPass(getTargetIRAnalysis()));
                        fpass.add(llvm::createTargetTransformInfoWrapperPass(getTargetIRAnalysis()));
                        llvm::PassManagerBuilder builder;
                        builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, ...);
                        builder.LoopVectorize = true; 
                        builder.SLPVectorize = true; 
                        ...
                        // Run the function passes
                        for mod in module_:
                            fpass.run(mod);
                        fpass.doFinalization();
                        // Run the module passes.
                        mpass.run(*module_);
        return runtime::Module(n);
    });

该函数中先创建LLVMModuleNode对象,然后调用它的Init()函数进行初始化,最后封装成Module对象返回。其中的Init()函数主要是将之前生成的TIR转为LLVM IR。它主要分几步:

  1. InitializeLLVM()函数初始化LLVM环境。这里边主要是例行调用LLVM的一大坨初始化函数。

  2. 创建用于代码生成的CodeGenLLVM对象。这里由于target字符串为x86-64,因此工厂函数名为tvm.codegen.llvm.target_x86-64。该工厂函数中创建CodeGenX86_64对象。因为继承关系为CodeGenX86_64 -> CodeGenCPU -> CodeGenLLVM,所以返回的是CodeGenLLVM的指针。

  3. 类型为IRModule的参数mod中的functions成员包含了该模块中的函数。这一步中将这些函数存于类型PrimFunc的数组funcs中。对于标为入口函数(kIsEntryFunc)的函数,记录在entry_func变量中。

  4. 接下来初始化前面创建的CodeGenX86_64对象。先调用CodeGenCPU::Init(),它里边又会调用到CodeGenLLVM::Init()。前者主要创建一坨TVM运行时类型与函数。后者创建一些llvm中用于codegen的对象,如IRBuilderllvm::Modulellvm::MDBuilder

  5. 对前面放入funcs数组的每个函数,调用CodeGenCPU::AddFunction()函数进行代码生成。对本文涉及的case只有一个函数就是vecadd()

    1. 首先产生llvm::Functionllvm::BasicBlock对象,分别对应函数与基本块。前面在loewr()函数中将函数的名为global_symbol的属性设为相应的函数名(如vecadd)。这里将该属性取出,作为生成函数的链接时的symbol。
    2. 通过VisitStmt()函数遍历IRModule中的各节点并转为LLVM中对应的数据结构,生成LLVM IR。这是最关键的一步了。前面费了老大劲构建起的TIR主要就是为了这里的转换。举例来说,对于ForNode就会调用CodeGenLLVM::VisitStmt_(ForNode *op)函数。它继而会调用CreateSerialFor()函数来产生相应的LLVM IR。在优化pass中的MakePackedAPImake_packed_api.cc)会添加一个AttrStmt,它对应一个值为目标函数名加_compute_后缀的compute_scope。这样,在code generation时 CodeGenCPU::CreateComputeScope()函数(为什么加compute_scope在该函数的注释中有提到)被调用。因此,最终的binary(可通过fadd.export_library("vecadd.so")语句导出)中大概会是这个样子:
      在这里插入图片描述
  6. AddMainFunction()函数设置主函数。如上面的例子中只有一个函数vecadd(),它也是主函数。这个symbol会放在runtime::symbol::tvm_module_main(即__tvm_main__)这个全局变量中。我们可以拿编译好binary验证这一点。用objdump命令dump导出的so文件,可以看到如下这段。如果将里边的0x766563616464的16进制转为ASCII,就是主函数的symbol名:vecadd。

0000000000003c87 <__tvm_main__>:    
    3c87:   76 65                   jbe    3cee <__GNU_EH_FRAME_HDR+0x5e>
    3c89:   63 61 64                movslq 0x64(%rcx),%esp
    3c8c:   64                      fs     
  1. 最后,调用CodeGenCPU::Finish()函数将LLVM IR生成后端代码。它实际调用CodeGenLLVM::Finish()函数,它会调用CodeGenLLVM::Finish()函数。它主要调用CodeGenCPU::AddStartupFunction()函数和CodeGenLLVM::Optimize()函数。前者创建_tvm_module_startup函数,然后将一些需要启动时调用的函数填入。后者主要利用LLVM pass做一些优化。主要是向量化和函数内联。llvm中两种自动向量化。具体可参见Auto-Vectorization in LLVM

其实,到这里编译还没有完全结束,只是构建好了LLVM的module。到这里,剩下的事情就是交给LLVM来编译生成可执行的binary了。真正生成可执行的binary是在第一次运行时通过LazyInitJIT()函数完成。 运行时会调用到LLVMModuleNode::GetFunction()函数。当它发现还未生成可执行binary时,会调用LazyInitJIT()函数。该函数通过llvm::ExecutionEngine将前面产生的llvm::Module编译成真正的(能在机器上跑的)binary。然后GetFunctionAddr()函数从中获得相应的函数指针,用于执行。