出发点
上一篇解析了Chatglm2-6b的模型架构,并和Chatglm-6b进行对比,但是留下了几个问题(哭)这一篇的目的是讲明白attention和rotaryEmbedding,解决问题,并实现整体目标,完全替代modeling_chatglm.py,并将代码缩减到一半儿。
selfattention
class SelfAttention(torch.nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
super(SelfAttention, self).__init__()
self.layer_number = max(1, layer_number)
self.projection_size = config.kv_channels * config.num_attention_heads# 128*32=4096 hidden_size
# Per attention head and per partition values.
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads# 128 每个attention头的hidden_size
self.num_attention_heads_per_partition = config.num_attention_heads# 32 attention头数
self.num_multi_query_groups_per_partition = config.multi_query_group_num# 2 分了多少组
self.qkv_hidden_size = (
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
)# 4096+2*128*2=4608 qkv对应的hidden_size
# 稍微解释一下为什么不是4096*3,因为这里使用了GQA的思想,下文会简单介绍一下
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device, **_config_to_kwargs(config)
)
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,device=device, **_config_to_kwargs(config))
def forward(
self, hidden_states, rotary_pos_emb, kv_cache=None, use_cache=True
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)
key_layer = key_layer.view(
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
value_layer = value_layer.view(
value_layer.size()[:-1]
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
key_layer = torch.cat((cache_k, key_layer), dim=0)
value_layer = torch.cat((cache_v, value_layer), dim=0)
if use_cache:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
key_layer = key_layer.unsqueeze(-2)
key_layer = key_layer.expand(
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
)
key_layer = key_layer.contiguous().view(
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)# GQA的操作:重复多次到原始尺寸,即32,128
value_layer = value_layer.unsqueeze(-2)
value_layer = value_layer.expand(
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
)
value_layer = value_layer.contiguous().view(
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)# GQA的操作:重复多次到原始尺寸,即32,128
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(query_layer, key_layer, value_layer)# 核心操作attention,和Chatglm-6b中attention_fn是一样的
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
return output, kv_cache
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
可以看出来思想也比较朴素,MHA中query、key、value都是一对一的,这样虽然效果好,但是caches太多了。MQA中只有一组key和value,和多个query相对应,caches减少了,但是效果会不好。那GQA则取个平均,有g组key和value,每一组key和value都重复几次和query相对应。
GQA提供了MHA到MQA的自然过渡,当g=h时就是MHA,g=1时就是MQA,当1<g<h时,它只将KV Cache压缩到g/h,压缩率不如MQA,但同时也提供了更大的自由度,效果上更有保证。
这里也贴一下Fast Transformer Decoding: One Write-Head is All You Need
那这里就解决了两个问题:
- multi_query_group_num是GQA中要分组的数量
- kv_channels对应的是query、key、value每个头的hidden_size
coreattention
class CoreAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number):
super(CoreAttention, self).__init__()
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling# 对query、key层是否要进行缩放,实际是要缩放的
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32# softmax的精度要使用fp32
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
self.hidden_size_per_partition = config.kv_channels * config.num_attention_heads# 128*32
self.hidden_size_per_attention_head = config.kv_channels# 128
self.num_attention_heads_per_partition = config.num_attention_heads# 32
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)# sqrt(d)的操作
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer, value_layer):
pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if query_layer.shape[2] == key_layer.shape[2]:# 只会在生成第一个token的时候,走这条路
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
is_causal=True)# 从这里可以看出来Chatglm2-6b完全就是一个decoder only的模型
else:# 这时候query的长度是1,key的长度是总token的长度
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
None)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
# [b, np, sq, sk]
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
device=query_layer.device
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)# Chatglm-6b中将alpha放在了前面,让query单独除了一下,没啥结果上的差别
# 关于torch.baddbmm多说一句,因为beta=0,所以input选择empty没啥问题,反正要被跳过
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
if self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
if attention_scores.shape[2] == attention_scores.shape[3]:
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
device=attention_scores.device, dtype=torch.bool)
attention_mask.tril_()
attention_mask = ~attention_mask
else:
attention_mask = None
"""
重点看一下这一小段代码,当sq=sk时(即query长度和key长度一致时,给了一个attention_mask)
此时的attention_mask其实就是一个上三角为True、下三角为False的矩阵
结合后面的 attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) 这一句的操作
就是将上三角的scores值置为负无穷,这妥妥的就是decoder-only嘛
当sq!=sk时,attention_mask即为空,即预测第二个token时,此时query长度为1,而key长度带着之前的cache,所以长度>1,此时不相等,attention_mask为空,后续也就没有啥操作了
"""
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type_as(value_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
这里多写一句,代码中有关于self.coeff的操作,即layer_number
在代码中self.norm_factor=self.coeff *math.sqrt(self.hidden_size_per_attention_head)
在计算attention_scores中除以了self.coeff *math.sqrt(self.hidden_size_per_attention_head)
然后在计算softmax之前又将attention_scores乘以了self.coeff
那不就相当于只是除以了math.sqrt(self.hidden_size_per_attention_head)嘛????
不知道为什么要有这个操作,感觉怪怪的,最主要的是不知道目的,有了解的可以解释一下,谢谢
之前Chatglm-6b的代码中就有这样的操作,当时没注意到(汗),这里的代码是直接删去了这个操作,完全没影响的。
当然了因为在pytorch_major_version >= 2中其实是没有和layer_number相关的操作,这个时候应该就能明白这个操作是无用的了。
RotaryEmbedding
class RotaryEmbedding(nn.Module):
def __init__(self, dim, original_impl=False, device=None, dtype=None):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
self.register_buffer("inv_freq", inv_freq)
self.dim = dim
self.original_impl = original_impl
def forward_impl(
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
):
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).float()
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
# this is to mimic the behaviour of complex32, else we will get different results
if dtype in (torch.float16, torch.bfloat16, torch.int8):
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
return cache
def forward(self, max_seq_len, offset=0):
return self.forward_impl(
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
)
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2# 32*2
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]# [:64],[64:] 将输入根据隐藏层维度,拆分得到两部分,只针对前部分x计算旋转位置信息
# truncate to support variable sizes
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
# [q_0,q_1][q_2,q_3]
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
# [cos0,sin0][cos1,sin1]
x_out2 = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
# 对应复数的实部q_0*cos(m\theta)-q_1*sin(m\theta)
# [q0, q2, ] *[cos0, cos1] - [q1, q3, ] *[sin0, sin1]
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
# 对应复数的虚部q_1*cos(m\theta)+q_0*sin(m\theta)
# [q1, q3, ] *[cos0, cos1] + [q0, q2, ] *[sin0, sin1]
],
-1,
)
# q0cos0-q1sin0
# q1cos0+q0sin0
# q2cos1-q3sin1
# q3cos1+q2sin1
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
这里就可以解释位置Embedding中传入的dim为什么是rotary_dim // 2了,因为它只对一半的hidden_size进行了位置编码,这也是很迷的一项操作,我没看到什么很好的解释,有了解原因的,欢迎指导,谢谢
最后一点代码量
到此基本就写完了代码,最后补充上两个函数和一点import
""" PyTorch ChatGLM model. """
import math
import copy
import re
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.modeling_utils import PreTrainedModel
from configuration_chatglm import ChatGLMConfig
def _config_to_kwargs(args):
common_kwargs = {
"dtype": args.torch_dtype,
}
return common_kwargs
class ChatGLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
is_parallelizable = False
config_class = ChatGLMConfig
base_model_prefix = "transformer"
_no_split_modules = ["GLMBlock"]
把这些代码保存成chatglm.py,放在chatglm2-6b的代码中,就可以正常使用了,使用方法和chatglm-6b是一样的
from chatglm import *
from transformers import AutoTokenizer
model_path = "/usr/downloads/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = ChatGLMForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True).half().cuda()
prompt = '你好'
response = model.chat(tokenizer, prompt)
代码量在650行,原始代码量是1280,减少一半的代码的小目标基本实现(成功)
参数量
简单分析一下参数量,其实从模型结构里就能很明白的看出来了,我这里就是记录一下
# word embedding
65024*4096*2=532676608
# 最后一层后面的LN
4096
# 下面几个是每层都有的
# query_key_value
4608*4096=18874368
# query_key_value.bias
4608
# dense
4096*4096=16777216
# LN
2*4096
# dense_h_to_4h
4096*27392=112197632
# dense_4h_to_h
13696*4096=56098816
# 28层
(18874368+4608+16777216+2*4096+112197632+56098816)*28=5710903296
5710903296+532676608+4096=6243584000
# 可以看出来主要的参数还是在word Embedding和dense_h_to_4h
结束语
这次解析了chatglm2-6b的代码,将代码缩减到650行,并分析了与chatglm-6b的区别,其实从结构里就可以看出来,它已经不是GLM的架构了,完全是一个decoder only的结构。改为了使用了RMSNorm、使用了GQA缩减caches、激活函数使用swiglu,基本就是这些了。
补充一点:经过查看代码,发现chatglm3-6b和chatglm2-6b的代码基本一模一样,只有在tokenizer处理输入的时候和返回response的时候有一点不一样,所以就不对chatglm3-6b做单独的介绍了。