ChatGLM2-6B from Input to Output: a Code Walkthrough (Part 2)


Motivation

The previous post analyzed the ChatGLM2-6B model architecture and compared it with ChatGLM-6B, but it left a few questions open. The goal of this post is to explain attention and RotaryEmbedding properly, close those questions, and finish the overall goal: fully replacing modeling_chatglm.py while cutting the code roughly in half.

SelfAttention


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads  # 128 * 32 = 4096, i.e. hidden_size

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads  # 128, hidden size of each attention head
        self.num_attention_heads_per_partition = config.num_attention_heads  # 32, number of attention heads

        self.num_multi_query_groups_per_partition = config.multi_query_group_num  # 2, number of KV groups
        self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )  # 4096 + 2 * 128 * 2 = 4608, the combined hidden size of Q, K and V
        # A quick note on why this is not 4096 * 3: GQA (grouped-query attention) is used here, briefly introduced below.
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config))

    def forward(
            self, hidden_states, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, np * hn + 2 * g * hn]  (GQA, so not np * 3 * hn)
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )
        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )


        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )  # the GQA step: repeat each KV group until the keys are back to the full shape, i.e. (32, 128)
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )  # the GQA step: repeat each KV group until the values are back to the full shape, i.e. (32, 128)

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer)  # the core attention computation, identical to attention_fn in ChatGLM-6B

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
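
To make the split and the GQA expansion above concrete, here is a minimal shape trace with made-up sq and b; the 4608 / 32 / 2 / 128 values follow the config comments in __init__:

import torch

sq, b = 5, 1                                   # made-up sequence length and batch size
mixed = torch.randn(sq, b, 4608)               # stands in for the output of query_key_value
q, k, v = mixed.split([32 * 128, 2 * 128, 2 * 128], dim=-1)
q = q.view(sq, b, 32, 128)                     # 32 query heads
k = k.view(sq, b, 2, 128)                      # only 2 key groups
# each of the 2 KV groups is shared by 32 // 2 = 16 query heads:
k = k.unsqueeze(-2).expand(-1, -1, -1, 16, -1).contiguous().view(sq, b, 32, 128)
print(q.shape, k.shape)                        # both torch.Size([5, 1, 32, 128])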

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
[figure: GQA]
The idea is fairly plain. In MHA, queries, keys and values are one-to-one; that works well, but the KV cache gets large. In MQA there is a single set of keys and values shared by all query heads; the cache shrinks, but quality suffers. GQA splits the difference: there are g groups of keys and values, and each group is repeated to serve several query heads.
GQA gives a natural transition from MHA to MQA: with g = h it is MHA, with g = 1 it is MQA, and with 1 < g < h the KV cache is compressed to g/h of the original. The compression is weaker than MQA's, but there is more freedom and quality is better preserved. A rough size comparison is sketched right after the list below.
Also worth linking here is the MQA paper, Fast Transformer Decoding: One Write-Head is All You Need.
This resolves two of the earlier questions:

  • multi_query_group_num is the number of KV groups in GQA
  • kv_channels is the per-head hidden size of query, key and value
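
To make the g/h cache compression concrete, here is a rough per-token, per-layer KV-cache size comparison in fp16, using ChatGLM2-6B's 32 heads of dimension 128 (the byte counts are only illustrative):

heads, head_dim, bytes_fp16 = 32, 128, 2
for name, groups in [("MHA", 32), ("GQA (ChatGLM2-6B, g=2)", 2), ("MQA", 1)]:
    kv_bytes = 2 * groups * head_dim * bytes_fp16      # key + value
    print(f"{name}: {kv_bytes} bytes per token per layer ({groups}/{heads} of MHA)")
# MHA: 16384 (32/32), GQA: 1024 (2/32), MQA: 512 (1/32)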

CoreAttention

class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling  # whether to scale the query/key scores; in practice this is enabled
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32  # run softmax in fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        self.hidden_size_per_partition = config.kv_channels * config.num_attention_heads# 128*32
        self.hidden_size_per_attention_head = config.kv_channels# 128
        self.num_attention_heads_per_partition = config.num_attention_heads# 32
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)  # the sqrt(d) scaling factor
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if query_layer.shape[2] == key_layer.shape[2]:  # taken only when generating the first token (the prefill step)
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)  # this shows that ChatGLM2-6B is purely a decoder-only model
            else:  # from the second token on, the query length is 1 while the key length is the total number of tokens so far
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 None)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # preallocting input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )  # ChatGLM-6B placed this alpha earlier and divided the query directly instead; there is no difference in the result
            # One more note on torch.baddbmm: since beta=0.0, passing an uninitialized (torch.empty) tensor as input is fine, because it is skipped anyway

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()

            if attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            else:
                attention_mask = None
            """
            重点看一下这一小段代码,当sq=sk时(即query长度和key长度一致时,给了一个attention_mask)
            此时的attention_mask其实就是一个上三角为True、下三角为False的矩阵
            结合后面的 attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) 这一句的操作
            就是将上三角的scores值置为负无穷,这妥妥的就是decoder-only嘛
            当sq!=sk时,attention_mask即为空,即预测第二个token时,此时query长度为1,而key长度带着之前的cache,所以长度>1,此时不相等,attention_mask为空,后续也就没有啥操作了

            """
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)
            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
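
To convince myself that the manual tril mask in the fallback branch really matches the is_causal=True fast path, here is a small check with made-up shapes (needs PyTorch >= 2):

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)   # [b, np, sq, hn], made-up sizes
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

scores = q @ k.transpose(-1, -2) / 8 ** 0.5
mask = ~torch.ones(5, 5, dtype=torch.bool).tril()                # True on the upper triangle
manual = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-6))                # True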

One more note: the original code contains some logic around self.coeff, which is just layer_number.
There, self.norm_factor = self.coeff * math.sqrt(self.hidden_size_per_attention_head),
the attention scores are divided by self.coeff * math.sqrt(self.hidden_size_per_attention_head),
and then multiplied back by self.coeff right before the softmax.
In exact arithmetic that is the same as simply dividing by math.sqrt(self.hidden_size_per_attention_head).
I do not know the purpose of this operation; it feels odd, and if anyone understands it, please explain, thanks.
The ChatGLM-6B code had the same logic and I missed it back then; here it is removed entirely, with no effect on the output.
Note also that the pytorch_major_version >= 2 branch has no layer_number-related logic at all, which already suggests the operation is redundant.
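
As a quick sanity check, here is a tiny numeric sketch (made-up shapes and layer_number = 7) showing that dividing by coeff * sqrt(d) and then multiplying by coeff cancels out:

import math
import torch

d, coeff = 128, 7                                   # head dim and a made-up layer_number
q = torch.randn(4, 16, d)
k = torch.randn(4, 16, d)
scores_scaled = torch.bmm(q, k.transpose(1, 2)) / (coeff * math.sqrt(d)) * coeff
scores_plain = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)
print(torch.allclose(scores_scaled, scores_plain))  # True, up to floating-point noise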

RotaryEmbedding


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2  # 32 * 2 = 64
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]  # split into [:64] and [64:]; only the first half gets the rotary position encoding
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    # [q_0,q_1][q_2,q_3]
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    # [cos0,sin0][cos1,sin1]
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            # the real part: q_0 * cos(m*theta) - q_1 * sin(m*theta)
            # [q0, q2, ...] * [cos0, cos1, ...] - [q1, q3, ...] * [sin0, sin1, ...]
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            # the imaginary part: q_1 * cos(m*theta) + q_0 * sin(m*theta)
            # [q1, q3, ...] * [cos0, cos1, ...] + [q0, q2, ...] * [sin0, sin1, ...]
        ],
        -1,
    )
    # q0cos0-q1sin0
    # q1cos0+q0sin0
    # q2cos1-q3sin1
    # q3cos1+q2sin1
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


This explains why the dim passed into the position embedding is rotary_dim // 2: the rotary encoding is applied to only half of the per-head hidden size. This is another somewhat puzzling choice; I have not seen a convincing explanation, so pointers are welcome. A small shape check is sketched below.
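
Here is that shape check of the half-rotation, reusing the RotaryEmbedding and apply_rotary_pos_emb defined above (it assumes torch is already imported and that rotary_dim = kv_channels = 128, so the embedding is built with dim = 128 // 2 = 64):

rotary = RotaryEmbedding(64, dtype=torch.float32)
rope_cache = rotary(max_seq_len=8)                  # [8, 32, 2]: 32 (cos, sin) pairs per position
x = torch.randn(8, 1, 32, 128)                      # [sq, b, np, hn]
out = apply_rotary_pos_emb(x, rope_cache)
print(rope_cache.shape)                             # torch.Size([8, 32, 2])
print(torch.equal(out[..., 64:], x[..., 64:]))      # True: the last 64 dims pass through unchanged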

The last bit of code

That is basically all of the code. The last additions are two small helpers and a few imports:

""" PyTorch ChatGLM model. """

import math
import copy
import re

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.modeling_utils import PreTrainedModel
from configuration_chatglm import ChatGLMConfig


def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

Save all of this as chatglm.py, drop it into the chatglm2-6b directory, and it can be used normally, in the same way as chatglm-6b:

from chatglm import *
from transformers import AutoTokenizer
model_path = "/usr/downloads/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = ChatGLMForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True).half().cuda()

prompt = '你好'
response = model.chat(tokenizer, prompt)

The result is about 650 lines against the original 1280, so the small goal of halving the code is basically achieved.

Parameter count

A quick look at the parameter count. It can be read straight off the model structure; this is mostly just for the record. A cross-check is sketched after the breakdown.

# word embedding (65024 * 4096 counted twice: the input embedding and the untied output layer)
65024*4096*2=532676608
# the final LayerNorm after the last layer
4096
# the following are per-layer parameters
# query_key_value
4608*4096=18874368
# query_key_value.bias
4608
# dense
4096*4096=16777216
# LN
2*4096
# dense_h_to_4h
4096*27392=112197632
# dense_4h_to_h
13696*4096=56098816

# 28 layers
(18874368+4608+16777216+2*4096+112197632+56098816)*28=5710903296
5710903296+532676608+4096=6243584000
# as expected, most of the parameters sit in the word embedding and dense_h_to_4h
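
A quick cross-check of the total, assuming the model object from the usage snippet above is loaded:

total = sum(p.numel() for p in model.parameters())
print(total)  # should print 6243584000 if the breakdown above is complete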

Closing remarks

This post walked through the ChatGLM2-6B code, trimmed it to about 650 lines, and analyzed the differences from ChatGLM-6B. As the structure already shows, it is no longer the GLM architecture but a pure decoder-only model. The changes come down to RMSNorm, GQA to shrink the KV cache, and the SwiGLU activation; that is basically it.
One more note: the ChatGLM3-6B code turns out to be essentially identical to ChatGLM2-6B; the only differences are in how the tokenizer processes the input and how the response is returned, so ChatGLM3-6B will not get a separate post.

