优化大型语言模型
随着大型语言模型 (LLM) 已成为许多不同领域的热门研究课题,将其部署到云和边缘设备已成为一项具有挑战性的任务。在本教程中,我们将演示如何使用 Apache TVM 优化大型语言模型。我们将使用来自 Hugging Face 的预训练 TinyLlama 模型,并将其部署到各种设备上。
回顾整体流程
整体流程包括以下步骤
构建或导入模型:构建神经网络模型或从其他框架(例如 PyTorch、ONNX)导入预训练模型,并创建 TVM IRModule,其中包含编译所需的所有信息,包括用于计算图的高级 Relax 函数,以及用于张量程序的低级 TensorIR 函数。
执行可组合的优化:执行一系列优化转换,例如图优化、张量程序优化和库调度。
构建和通用部署:将优化的模型构建为可部署模块到通用运行时,并在不同的设备(如 CPU、GPU 或其他加速器)上执行它。
构建模型架构
我们将使用来自 Hugging Face 的预训练 TinyLlama 模型。然而,通常我们只从 Hugging Face 加载预训练权重,而不是模型架构。我们需要自己构建模型架构。Apache TVM 准备了一个类似 PyTorch 的 API 来构建模型架构。我们可以使用该 API 来构建模型架构。
import dataclasses
import enum
import os
from pathlib import Path
from pprint import pprint
from typing import List, Optional
import tvm
from tvm import dlight, relax, te, tir
from tvm.relax import register_pipeline
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache
from tvm.runtime import ShapeTuple
首先,我们需要定义模型配置。配置包括模型的关键参数,例如隐藏层大小、中间层大小等。为了方便起见,这里我们为 TinyLlama 模型专门定义了一个常量配置。
@dataclasses.dataclass
class LlamaConfig:
hidden_size: int = 2048
intermediate_size: int = 5632
num_attention_heads: int = 32
num_hidden_layers: int = 22
rms_norm_eps: float = 1e-05
vocab_size: int = 32000
rope_theta: int = 10000
context_window_size: int = 2048
prefill_chunk_size: int = 2048
num_key_value_heads: int = 4
head_dim: int = 64 # hidden_size // num_attention_heads
dev = tvm.device("cuda", 0)
target = tvm.target.Target.from_device(dev)
接下来,我们定义分页 KV 缓存的 RoPE 模式。RoPE 模式用于将相对位置编码 (RoPE) 应用于查询和键张量。RoPE 模式可以设置为 NONE、NORMAL 或 INLINE。如果 RoPE 模式为 NONE,则 KV 缓存不会将 RoPE 应用于查询和键张量。如果 RoPE 模式为 NORMAL,则 RoPE 将在将键张量添加到缓存之前应用于键张量。如果 RoPE 模式为 INLINE,则 RoPE 将在注意力内核中动态地应用于查询和键张量。
class RopeMode(enum.IntEnum):
"""The RoPE mode of the Paged KV cache.
If it is none, the KV cache will not apply RoPE to q and k.
If it is normal, RoPE will be applied to k before adding k to cache.
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
"""
NONE = 0
NORMAL = 1
INLINE = 2
其次,我们定义模型架构。模型架构由三个部分组成
嵌入层:嵌入层将输入 token ID 转换为隐藏状态。
解码器层:解码器层是模型的核心。每个解码器层由一个自注意力层和一个前馈网络 (FFN) 层组成。
输出层:输出层将隐藏状态转换为 logits。
首先我们定义 FFN 层。请注意,以下 FFN 层是优化的实现,我们将 gate 和 up 投影融合到一个内核中。FFN 层的朴素实现是:FFN(x) = down_proj(silu(gate(x)) * up(x))
我们可以将 gate
和 up
投影组合到一个内核中以获得更好的性能。优化的实现是
concat_x = gate_up(x)
gate_x, up_x = split(concat_x, 2, axis=-1)
FFN(x) = down_proj(silu(gate_x) * up_x)
class LlamaFFN(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.gate_up_proj = nn.Linear(
in_features=config.hidden_size,
out_features=2 * config.intermediate_size,
bias=False,
)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x: Tensor):
concat_x1_x2 = self.gate_up_proj(x)
x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
return self.down_proj(op.silu(x1) * x2)
然后我们定义自注意力层。自注意力层由三个部分组成
QKV 投影:QKV 投影将输入隐藏状态转换为查询、键和值张量。
注意力:注意力层计算注意力分数并应用 softmax 运算。
输出投影:输出投影将注意力输出转换为隐藏状态。
我们对自注意力层的不同部分执行优化
QKV 投影:我们利用 QKV 投影的水平融合,并将它们融合到一个内核中。
注意力:我们利用注意力的水平融合,并将 QKV 投影融合在一起
class LlamaAttention(nn.Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: LlamaConfig):
self.head_dim = config.head_dim
self.num_q_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
# horizontal fusion on QKV projection
self.qkv_proj = nn.Linear(
in_features=config.hidden_size,
out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,
bias=False,
)
self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)
def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads
b, s, _ = hidden_states.shape
# QKV Projection
qkv = self.qkv_proj(hidden_states)
qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
# Attention
output = op.reshape(
paged_kv_cache.attention_with_fused_qkv(
layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5
),
(b, s, h_q * d),
)
# Output Projection
return self.o_proj(output)
最后,我们用 FFN 和自注意力层定义模型架构。
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
rms_norm_eps = config.rms_norm_eps
self.self_attn = LlamaAttention(config)
self.mlp = LlamaFFN(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
hidden_states += self.self_attn(
self.input_layernorm(hidden_states), paged_kv_cache, layer_id
)
hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
return hidden_states
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
assert config.hidden_size % config.num_attention_heads == 0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = input_embed
for layer_id, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCasualLM(nn.Module):
def __init__(self, config: LlamaConfig):
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.num_hidden_layers = config.num_hidden_layers
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
self.rope_theta = config.rope_theta
self.dtype = "float32"
def to(self, dtype: Optional[str] = None):
super().to(dtype=dtype)
if dtype is not None:
self.dtype = dtype
def embed(self, input_ids: Tensor):
return self.model.embed_tokens(input_ids)
def get_logits(self, hidden_states: Tensor):
logits = self.lm_head(hidden_states)
if logits.dtype != "float32":
logits = logits.astype("float32")
return logits
def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
def _index(x: te.Tensor): # x[:-1,:]
b, s, d = x.shape
return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")
hidden_states = self.model(input_embed, paged_kv_cache)
hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
logits = self.get_logits(hidden_states)
return logits, paged_kv_cache
def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = self.model(input_embed, paged_kv_cache)
logits = self.get_logits(hidden_states)
return logits, paged_kv_cache
def create_tir_paged_kv_cache(
self,
max_batch_size: tir.Var,
max_total_seq_len: tir.Var,
prefill_chunk_size: tir.Var,
page_size: tir.Var,
) -> PagedKVCache:
return TIRPagedKVCache(
attn_kind="mha",
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
page_size=page_size,
support_sliding_window=0,
layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]),
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
qk_head_dim=self.head_dim,
v_head_dim=self.head_dim,
mla_original_qk_head_dim=0,
mla_original_v_head_dim=0,
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
rope_scaling={},
rope_ext_factors=relax.PrimValue(0),
rotary_dim=self.head_dim,
dtype=self.dtype,
target=target,
enable_disaggregation=False,
)
def get_default_spec(self):
mod_spec = {
"embed": {
"input_ids": nn.spec.Tensor(["seq_len"], "int32"),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"prefill": {
"input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
"paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"decode": {
"input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
"paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"create_tir_paged_kv_cache": {
"max_batch_size": int,
"max_total_seq_len": int,
"prefill_chunk_size": int,
"page_size": int,
"$": {
"param_mode": "none",
"effect_mode": "none",
},
},
}
return nn.spec.ModuleSpec.from_raw(mod_spec, self)
将模型导出到 Relax IRModule
定义模型架构后,我们可以将模型导出到 Relax IRModule。为了演示,我们只展示模型架构和参数的一部分。
model_config = LlamaConfig()
model = LlamaForCasualLM(model_config)
model.to("float16")
mod, named_params = model.export_tvm(spec=model.get_default_spec())
prefill_str = mod["prefill"].script()
print(*prefill_str.split("\n")[3:20], sep="\n") # Only show the first 10 lines for demonstration
print(" ...")
print("\nParameters:")
pprint(named_params[:5]) # Only show the first 5 parameters for demonstration
@R.function
def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32000, 2048), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((32000, 2048), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32000), dtype="float32"), R.Object):
seq_len = T.int64()
R.func_attr({"num_input": 2})
with R.dataflow():
model_embed_tokens_weight1: R.Tensor((32000, 2048), dtype="float16") = packed_params[0]
model_layers_0_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[1]
model_layers_0_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[2]
model_layers_0_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[3]
model_layers_0_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[4]
model_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[5]
model_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[6]
model_layers_1_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[7]
model_layers_1_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[8]
model_layers_1_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[9]
model_layers_1_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[10]
model_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[11]
...
Parameters:
[('model.embed_tokens.weight', Tensor([32000, 2048], "float16")),
('model.layers.0.self_attn.qkv_proj.weight', Tensor([2560, 2048], "float16")),
('model.layers.0.self_attn.o_proj.weight', Tensor([2048, 2048], "float16")),
('model.layers.0.mlp.gate_up_proj.weight', Tensor([11264, 2048], "float16")),
('model.layers.0.mlp.down_proj.weight', Tensor([2048, 5632], "float16"))]
定义优化流水线
我们定义了一系列优化 pass 来优化模型。优化流水线是专门为 LLM 设计的。
@register_pipeline("opt_llm")
def _pipeline( # pylint: disable=too-many-arguments
ext_mods: List[nn.ExternModule] = None,
):
ext_mods = ext_mods or []
@tvm.transform.module_pass(opt_level=0)
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
seq = tvm.transform.Sequential(
[
# Phase 1. Passes on high-level operator graph
# We can enable cublas for further optimization
relax.transform.FuseTransposeMatmul(),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
relax.transform.LegalizeOps(),
relax.transform.AnnotateTIROpPattern(),
relax.transform.FoldConstant(),
relax.transform.FuseOps(),
relax.transform.FuseTIR(),
# Phase 3. Passes on TIR
relax.transform.DeadCodeElimination(),
# Phase 4. Low-level Optimizations
dlight.ApplyDefaultSchedule(
dlight.gpu.Matmul(),
dlight.gpu.GEMV(),
dlight.gpu.Reduction(),
dlight.gpu.GeneralReduction(),
dlight.gpu.Fallback(),
),
# Phase 5. Lowering to VM bytecode
relax.transform.RewriteDataflowReshape(),
relax.transform.ToNonDataflow(),
relax.transform.RemovePurityChecking(),
relax.transform.CallTIRRewrite(),
relax.transform.StaticPlanBlockMemory(),
relax.transform.RewriteCUDAGraph(),
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
relax.transform.AttachExternModules(ext_mods),
]
)
mod = seq(mod)
return mod
return _pipeline
with target:
ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
vm = relax.VirtualMachine(ex, dev)
准备模型权重
我们从 Hugging Face 加载预训练权重并准备模型权重。预训练权重以 Hugging Face 格式存储。我们需要加载权重并准备模型参数。
注意
请注意,我们不会在本教程中执行以下代码,因为预训练权重在 CI 环境中不可用。
IS_IN_CI = os.getenv("CI", "") == "true"
HF_WEIGHT_PATH = None
# HF_WEIGHT_PATH = Path("/path/to/TinyLlama-1.1B-Chat-v1.0/")
if not IS_IN_CI:
import numpy as np
import safetensors.torch
import torch
if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists():
raise ValueError("Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.")
# Torch format weights
param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / "model.safetensors", device="cpu")
# Numpy format weights
param_dict = {
k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy()
for k, v in param_dict.items()
}
named_params = dict(named_params)
for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
param_dict[f"{attn}.qkv_proj.weight"] = np.concatenate(
[
param_dict.pop(f"{attn}.q_proj.weight"), # Pop the old parameters to save memory
param_dict.pop(f"{attn}.k_proj.weight"),
param_dict.pop(f"{attn}.v_proj.weight"),
],
axis=0,
)
# Add gates in MLP
mlp = f"model.layers.{i}.mlp"
param_dict[f"{mlp}.gate_up_proj.weight"] = np.concatenate(
[
param_dict.pop(f"{mlp}.gate_proj.weight"),
param_dict.pop(f"{mlp}.up_proj.weight"),
],
axis=0,
)
# Convert params into ndarray
params = [
tvm.nd.array(param_dict[k].astype("float16"), device=dev) for k in named_params.keys()
]
部署编译后的模型
在模型和权重准备就绪后,我们可以将编译后的模型部署到目标设备。语言模型推理包括两个步骤:预填充和解码。预填充步骤用于处理输入 token 并存储 KVCache。解码步骤用于生成 token,直到生成结束 token。
分词
第一步是 token 化输入提示,并将 token 嵌入到隐藏状态中。分词和嵌入与原始模型相同。我们使用 HF tokenizer 来 token 化输入提示,并将 token 嵌入到隐藏状态中。请注意,不同的模型需要不同的分词和提示格式,请参考模型文档以获取正确的分词和提示格式。
if not IS_IN_CI:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH)
messages = [
{"role": "user", "content": "What's your name?"},
]
prompt = tokenizer.apply_chat_template(messages)
input_len = len(prompt)
# Load prompt tokens into TVM ndarray on the target device
tokens = tvm.nd.array(np.array(prompt).astype("int32"), device=dev)
创建 KVCache
在开始推理之前,我们需要创建 KVCache。KVCache 用于存储注意力层的键和值张量。Apache TVM 提供了 PagedKVCache 来存储键和值张量。我们使用指定的参数创建 PagedKVCache。
嵌入
下一步是将 token 嵌入到隐藏状态中。我们使用在 Relax IRModule 中编译的 embed 函数将 token 嵌入到隐藏状态中。
nd_view_func = tvm.get_global_func("vm.builtin.reshape")
def embed(tokens, params):
_embed = vm["embed"](tokens, params)
# Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]
_embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0], _embed.shape[1]]))
return _embed
预填充
在运行前向传播之前,我们首先获取一些辅助函数进行准备。
add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward")
由于我们正在创建一个新序列,我们需要调用 add_sequence_func 来初始化请求。此外,我们需要调用 begin_forward_func 来启动前向传播,并调用 end_forward_func 来结束前向传播。
现在我们有了来自预填充步骤的输出 logits。logits 用于通过采样生成 token。让我们从 logits 中采样 token。
在本教程中,我们简化了采样过程,并选择概率最高的 token。在实践中,我们应该基于概率分布对 token 进行采样。此外,为了使教程简洁,我们在 CPU 上执行采样过程。
def sample_token(logits):
logits_np = logits.numpy()
return np.argmax(logits_np)
if not IS_IN_CI:
last_token = sample_token(logits)
output_tokens = [last_token]
解码
在预填充步骤之后,我们可以开始解码步骤。解码步骤用于生成 token,直到生成结束 token。我们使用在 Relax IRModule 中编译的 decode 函数来生成 token。
if not IS_IN_CI:
print("The generated token:")
while last_token != tokenizer.eos_token_id:
tokens = tvm.nd.array(np.array([last_token]).astype("int32"), device=dev)
hidden_states = embed(tokens, params)
begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1]))
logits, kv_cache = vm["decode"](hidden_states, kv_cache, params)
end_forward_func(kv_cache)
last_token = sample_token(logits)
output_tokens.append(last_token)
print(tokenizer.decode(output_tokens))