电竞比分网-中国电竞赛事及体育赛事平台

分享

使用PyTorch從零構(gòu)建Llama 3(附代碼)

 東西二王 2024-11-02

2024-10-16 07:08·數(shù)據(jù)派THU

來源:DeepHub IMBA

本文約12000字,建議閱讀15+分鐘

本文將詳細(xì)指導(dǎo)如何從零開始構(gòu)建完整的Llama 3模型架構(gòu)。

我們上次發(fā)了用PyTorch從零開始編寫DeepSeek-V2的文章后,有小伙伴留言說希望介紹一下Llama 3。那么今天他就來了,本文將詳細(xì)指導(dǎo)如何從零開始構(gòu)建完整的Llama 3模型架構(gòu),并在自定義數(shù)據(jù)集上執(zhí)行訓(xùn)練和推理。

[圖1]:Llama 3架構(gòu)展示訓(xùn)練和推理流程。因?yàn)楣俜絃lama 3論文中未提供相關(guān)圖表。所以此圖為大概架構(gòu)圖,閱讀本文后你應(yīng)能繪制出更為精確的架構(gòu)圖。

本文目標(biāo)

通過本文。你可以了解到:

  1. 深入理解Llama 3模型各組件的底層工作原理。

  2. 編寫代碼構(gòu)建Llama 3的每個(gè)組件,并將它們組裝成一個(gè)功能完整的Llama 3模型。

  3. 編寫代碼使用新的自定義數(shù)據(jù)集訓(xùn)練模型。

  4. 編寫代碼執(zhí)行推理,使Llama 3模型能夠根據(jù)輸入提示生成新文本。

1、輸入模塊

如圖1所示,輸入模塊包含三個(gè)組件:文本/提示、分詞器和嵌入。

輸入模塊內(nèi)部工作流程

讓我們通過下圖了解輸入模塊內(nèi)的工作流程。

[圖2]:輸入模塊流程圖,展示提示、分詞器和嵌入流程。

首先,單個(gè)或批量文本/提示被輸入模型。例如:圖中的"Hello World"。

輸入模型的必須是數(shù)字格式,因?yàn)槟P蜔o法直接處理文本。分詞器將這些文本/提示轉(zhuǎn)換為標(biāo)記ID(詞匯表中標(biāo)記的索引號(hào)表示)。我們將使用Tiny Shakespeare數(shù)據(jù)集構(gòu)建詞匯表并訓(xùn)練模型。Llama 3模型使用TikToken作為分詞器,這是一種子詞分詞器。但是我們這個(gè)實(shí)現(xiàn)將使用字符級(jí)分詞器。這樣做的主要原因是讓我們能夠自行構(gòu)建詞匯表和分詞器,包括編碼和解碼函數(shù),這樣可以深入理解底層工作原理并完全掌控代碼。

每個(gè)標(biāo)記ID將被轉(zhuǎn)換為128維的嵌入向量(原始Llama 3 8B中為4096維)。然后這些嵌入將被傳遞到下一個(gè)解碼器模塊。

輸入模塊代碼實(shí)現(xiàn):

 # 導(dǎo)入必要的庫(kù)  
 import torch  
 from torch import nn  
 from torch.nn import functional as F  


 import math  
 import numpy as np  
 import time  
 from dataclasses import dataclass  
 from typing import Optional, Tuple, List  
 import pandas as pd  
 from matplotlib import pyplot as plt ### 步驟1: 輸入模塊 ###  


 # 使用Tiny Shakespeare數(shù)據(jù)集實(shí)現(xiàn)字符級(jí)分詞器。部分字符級(jí)分詞器代碼參考自Andrej Karpathy的GitHub倉(cāng)庫(kù)
 # (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py)
 # 加載tiny_shakespeare數(shù)據(jù)文件 (https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt)  


 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # 根據(jù)可用性分配設(shè)備為cuda或cpu  


 # 加載tiny_shakespeare數(shù)據(jù)文件  
 with open('tiny_shakespeare.txt', 'r') as f:  
   data = f.read()  


 # 通過提取tiny_shakespeare數(shù)據(jù)中的所有唯一字符準(zhǔn)備詞匯表  
 vocab = sorted(list(set(data)))  


 # 訓(xùn)練Llama 3模型需要額外的標(biāo)記,如<|begin_of_text|>、<|end_of_text|>和<|pad_id|>,將它們添加到詞匯表中  
 vocab.extend(['<|begin_of_text|>','<|end_of_text|>','<|pad_id|>'])  
 vocab_size = len(vocab)  


 # 創(chuàng)建字符與詞匯表中對(duì)應(yīng)整數(shù)索引之間的映射。
 # 這對(duì)于構(gòu)建分詞器的編碼和解碼函數(shù)至關(guān)重要。
 itos = {i:ch for i, ch in enumerate(vocab)}  
 stoi = {ch:i for i, ch in enumerate(vocab)}  


 # 分詞器編碼函數(shù):輸入字符串,輸出整數(shù)列表  
 def encode(s):  
   return [stoi[ch] for ch in s]  


 # 分詞器解碼函數(shù):輸入整數(shù)列表,輸出字符串  
 def decode(l):  
   return ''.join(itos[i] for i in l)  


 # 定義稍后在模型訓(xùn)練中使用的張量標(biāo)記變量  
 token_bos = torch.tensor([stoi['<|begin_of_text|>']], dtype=torch.int, device=device)  
 token_eos = torch.tensor([stoi['<|end_of_text|>']], dtype=torch.int, device=device)  
 token_pad = torch.tensor([stoi['<|pad_id|>']], dtype=torch.int, device=device)  


 prompts = "Hello World"  
 encoded_tokens = encode(prompts)  
 decoded_text = decode(encoded_tokens)  


 ### 輸入模塊代碼測(cè)試 ###  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 print(f"Shakespeare文本字符長(zhǎng)度: {len(data)}")  
 print(f"詞匯表內(nèi)容: {''.join(vocab)}\n")  
 print(f"詞匯表大小: {vocab_size}")  
 print(f"編碼后的標(biāo)記: {encoded_tokens}")  
 print(f"解碼后的文本: {decoded_text}")  
 """  
 ### 測(cè)試結(jié)果: ###  
 """  
 Shakespeare文本字符長(zhǎng)度: 1115394  
 詞匯表內(nèi)容:   
  !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz<|begin_of_text|><|end_of_text|><|pad_id|>  


 詞匯表大小: 68  
 編碼后的標(biāo)記: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]  
 解碼后的文本: Hello World  
 """

2、解碼器模塊

參照?qǐng)D1的架構(gòu)圖,解碼器模塊包含以下子組件:

  • RMS歸一化

  • 旋轉(zhuǎn)位置編碼

  • KV緩存

  • 分組查詢注意力

  • 前饋網(wǎng)絡(luò)

  • 解碼器塊

RMS歸一化(Root Mean Square Normalization)

RMSNorm的必要性

從圖1可以看出,輸入模塊輸出(嵌入向量)經(jīng)過RMSNorm模塊。這是因?yàn)榍度胂蛄烤哂卸鄠€(gè)維度(Llama3-8b中為4096維),可能出現(xiàn)不同范圍的值。這會(huì)導(dǎo)致模型梯度爆炸或消失,從而導(dǎo)致收斂緩慢甚至發(fā)散。而RMSNorm將這些值歸一化到一定范圍,有助于穩(wěn)定和加速訓(xùn)練過程。這使得梯度具有更一致的幅度,從而加快模型收斂。

RMSNorm的工作原理

[圖3]:對(duì)形狀為[3,3]的輸入嵌入應(yīng)用RMSNorm

類似于層歸一化,RMSNorm沿嵌入特征或維度應(yīng)用。上圖中的嵌入形狀為[3,3],意味著每個(gè)標(biāo)記有3個(gè)維度。

示例:對(duì)第一個(gè)標(biāo)記X1的嵌入應(yīng)用RMSNorm:

X1標(biāo)記在每個(gè)維度上的值(x11、x12和x13)分別除以所有這些值的均方根。公式如圖3所示。

為避免除以零并保證數(shù)值穩(wěn)定性,在均方根中加入一個(gè)小常數(shù)E(Epsilon)。乘以一個(gè)縮放參數(shù)Gamma (Y)。每個(gè)特征都有一個(gè)獨(dú)特的Gamma參數(shù)(如圖中d1維度的Y1、d2維度的Y2和d3維度的Y3),這是一個(gè)學(xué)習(xí)參數(shù),可以向上或向下縮放以進(jìn)一步穩(wěn)定歸一化。gamma參數(shù)初始化為1(如上面的計(jì)算所示)。

如示例所示,嵌入值原本較大且分布范圍寬。應(yīng)用RMSNorm后,值變小且范圍縮小。計(jì)算使用實(shí)際的RMSNorm函數(shù)完成。

RMSNorm相比層歸一化的優(yōu)勢(shì)

如上例所示沒有計(jì)算任何均值或方差,而這在層歸一化中是必需的。所以RMSNorm通過避免計(jì)算均值和方差減少了計(jì)算開銷。根據(jù)作者的研究,RMSNorm在不影響準(zhǔn)確性的同時(shí)提供了性能優(yōu)勢(shì)。

RMSNorm代碼實(shí)現(xiàn):

 # 步驟2: 解碼器模塊  
 # 注:由于Llama 3模型由Meta開發(fā),為了與他們的代碼庫(kù)保持一致并考慮未來兼容性,
 # 我將使用Meta GitHub上的大部分代碼,并進(jìn)行必要的修改以實(shí)現(xiàn)我們的目標(biāo)。


 # 定義參數(shù)數(shù)據(jù)類:我們將在模型構(gòu)建、訓(xùn)練和推理過程中使用這些參數(shù)。
 # 注:為了更快地看到訓(xùn)練和推理結(jié)果,而不是專注于高準(zhǔn)確性,我們對(duì)大多數(shù)參數(shù)采用較低的值,
 # 這些值在Llama 3模型中設(shè)置得更高。 @dataclass  
 class ModelArgs:  
     dim: int = 512              # 嵌入維度  
     n_layers: int = 8           # 模型解碼器塊的數(shù)量  
     n_heads: int = 8            # 查詢嵌入的頭數(shù)  
     n_kv_heads: int = 4         # 鍵和值嵌入的頭數(shù)  
     vocab_size: int = len(vocab) # 詞匯表長(zhǎng)度  
     multiple_of: int = 256        # 用于計(jì)算前饋網(wǎng)絡(luò)維度  
     ffn_dim_multiplier: Optional[float] = None  # 用于計(jì)算前饋網(wǎng)絡(luò)維度  
     norm_eps: float = 1e-5                       # RMSNorm計(jì)算的默認(rèn)Epsilon值  
     rope_theta: float = 10000.0   # RePE計(jì)算的默認(rèn)theta值  


     max_batch_size: int = 10     # 最大批量大小  
     max_seq_len: int = 256         # 最大序列長(zhǎng)度  


     epochs: int = 2500             # 總訓(xùn)練迭代次數(shù)  
     log_interval: int = 10        # 打印日志和損失值的間隔數(shù)     
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # 根據(jù)可用性分配設(shè)備為cuda或cpu 


 ## 步驟2a: RMSNorm  


 class RMSNorm(nn.Module):  
   def __init__(self, dim: int, eps: float = 1e-6):  
     super().__init__()  
     device = ModelArgs.device  
     self.eps = eps  
     # 縮放參數(shù)gamma,初始化為1,參數(shù)數(shù)量等于dim的大小  
     self.weight = nn.Parameter(torch.ones(dim).to(device))  


   def _norm(self, x):  
     return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(device)  


   def forward(self, x):  
     #形狀: x[bs,seq,dim]  
     output = self._norm(x.float()).type_as(x)  


     #形狀: x[bs,seq,dim] -> x_norm[bs,seq,dim]  
     return output * self.weight  


 ### RMSNorm代碼測(cè)試 ###  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)  
 rms_norm = RMSNorm(dim=ModelArgs.dim)  
 x_norm = rms_norm(x)  


 print(f"x的形狀: {x.shape}")  
 print(f"x_norm的形狀: {x_norm.shape}")  
 """  
 ### 測(cè)試結(jié)果: ###  
 """  
 x的形狀: torch.Size([10, 256, 512])  
 x_norm的形狀: torch.Size([10, 256, 512])  
 """

旋轉(zhuǎn)位置編碼(Rotary Positional Encoding, RoPE)

回顧之前的步驟,我們已將輸入文本轉(zhuǎn)換為嵌入,并對(duì)嵌入應(yīng)用了RMSNorm。然而,這里存在一個(gè)問題:假設(shè)輸入文本是"I love apple"或"apple love I",模型會(huì)將兩個(gè)句子視為相同并以相同方式學(xué)習(xí)。這是因?yàn)榍度胫袥]有為模型定義順序信息。因此對(duì)于任何語言模型來說,保持標(biāo)記的順序至關(guān)重要。在Llama 3模型架構(gòu)中,引入了旋轉(zhuǎn)位置編碼(RoPE)來定義句子中每個(gè)標(biāo)記的位置,這不僅維護(hù)了順序,還保留了句子中標(biāo)記的相對(duì)位置信息。

旋轉(zhuǎn)位置編碼的工作原理

RoPE是一種位置編碼方法,它通過添加絕對(duì)位置信息以及包含標(biāo)記之間的相對(duì)位置信息來編碼嵌入,從而維護(hù)句子中標(biāo)記的順序。它通過使用一個(gè)特殊的旋轉(zhuǎn)矩陣來旋轉(zhuǎn)給定的嵌入來執(zhí)行編碼操作。這種利用旋轉(zhuǎn)矩陣的簡(jiǎn)潔而強(qiáng)大的數(shù)學(xué)推導(dǎo)是RoPE的核心。

[圖4]:應(yīng)用于2維向量的旋轉(zhuǎn)矩陣

上圖展示了旋轉(zhuǎn)矩陣應(yīng)用于2維向量的情況。Llama 3模型中的維度數(shù)是4096,遠(yuǎn)高于此。我們?cè)敿?xì)介紹如何對(duì)更高維度的嵌入應(yīng)用旋轉(zhuǎn)。

[圖5]:RoPE應(yīng)用于嵌入的示例

嵌入的旋轉(zhuǎn)涉及每個(gè)嵌入位置(m)值和theta (θ)對(duì)每對(duì)嵌入維度的乘法。這就是RoPE如何通過實(shí)現(xiàn)旋轉(zhuǎn)矩陣來捕獲絕對(duì)位置和相對(duì)位置信息的方式。

注意:在執(zhí)行旋轉(zhuǎn)之前,需要將旋轉(zhuǎn)矩陣轉(zhuǎn)換為極坐標(biāo)形式,并將嵌入向量轉(zhuǎn)換為復(fù)數(shù)。旋轉(zhuǎn)完成后,旋轉(zhuǎn)后的嵌入需要轉(zhuǎn)換回實(shí)數(shù)以進(jìn)行注意力操作。另外RoPE僅應(yīng)用于查詢和鍵嵌入,不適用于值嵌入。

RoPE的代碼實(shí)現(xiàn):

 ## 步驟2b: RoPE實(shí)現(xiàn)  
 def precompute_freqs_cis(dim:int, seq_len: int, theta: float=10000.0):  
   # 計(jì)算每對(duì)維度的Theta值,即dim/2  
   device = ModelArgs.device  
   freqs = 1.0 / (theta ** (torch.arange(0, dim, 2,device=device)[:(dim//2)].float()/dim))  


   # 計(jì)算序列中位置(m)的范圍  
   t = torch.arange(seq_len, dtype=torch.float32, device=device)  


   # freqs給出序列中所有標(biāo)記位置的Theta值范圍  
   freqs = torch.outer(t, freqs).to(device)  


   # 這是需要轉(zhuǎn)換為極坐標(biāo)形式的旋轉(zhuǎn)矩陣,以便對(duì)嵌入執(zhí)行旋轉(zhuǎn)  
   freqs_cis = torch.polar(torch.ones_like(freqs).to(device), freqs).to(device)  
   return freqs_cis  


 def reshape_for_broadcast(freqs_cis, x):  
   ndim = x.ndim  
   assert 0<=1<ndim  
   assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "freqs_cis的最后兩個(gè)維度必須與x匹配"  
   shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]  
   return freqs_cis.view(*shape)  


 def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor)->Tuple[torch.Tensor, torch.Tensor]:  
   device = ModelArgs.device  
   # 同時(shí)對(duì)查詢和鍵嵌入應(yīng)用旋轉(zhuǎn)位置編碼  
   # 首先:xq和xk嵌入的最后一個(gè)維度需要重塑為一對(duì)。因?yàn)樾D(zhuǎn)矩陣應(yīng)用于每對(duì)維度。  
   # 其次:將xq和xk轉(zhuǎn)換為復(fù)數(shù),因?yàn)樾D(zhuǎn)矩陣只適用于復(fù)數(shù)  
   xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device)    #xq_:[bsz, seq_len, n_heads, head_dim/2]  
   xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device)    #xk_:[bsz, seq_len, n_heads, head_dim/2]  


   # 旋轉(zhuǎn)矩陣(freqs_cis)在seq_len(dim=1)和head_dim(dim=3)維度上應(yīng)與嵌入匹配  
   # 此外,freqs_cis的形狀應(yīng)與xq和xk相同,因此將freqs_cis的形狀從[seq_len,head_dim]改變?yōu)閇1,seq_len,1,head_dim]  
   freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  


   # 最后,通過與freqs_cis相乘執(zhí)行旋轉(zhuǎn)操作。  
   # 旋轉(zhuǎn)完成后,將xq_out和xk_out轉(zhuǎn)換回實(shí)數(shù)并返回  
   xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device) #xq_out:[bsz, seq_len, n_heads, head_dim]  
   xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device) #xk_out:[bsz, seq_len, n_heads, head_dim]  
   return xq_out.type_as(xq), xk_out.type_as(xk)  


 ### RoPE代碼測(cè)試 ###  
 # 注:x_norm在RMSNorm測(cè)試中計(jì)算,這里用于測(cè)試。  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 head_dim = ModelArgs.dim//ModelArgs.n_heads  
 wq = nn.Linear(ModelArgs.dim, ModelArgs.n_heads * head_dim, bias=False, device=device)  
 wk = nn.Linear(ModelArgs.dim, ModelArgs.n_kv_heads * head_dim, bias=False, device=device)  
 xq = wq(x_norm)  
 xk = wk(x_norm)  
 print(f"xq.shape: {xq.shape}")  
 print(f"xk.shape: {xk.shape}")  


 xq = xq.view(xq.shape[0],xq.shape[1],ModelArgs.n_heads, head_dim)  
 xk = xk.view(xk.shape[0],xk.shape[1],ModelArgs.n_kv_heads, head_dim)  
 print(f"xq.re-shape: {xq.shape}")  
 print(f"xk.re-shape: {xk.shape}")  


 freqs_cis = precompute_freqs_cis(dim=head_dim, seq_len=ModelArgs.max_seq_len)  
 print(f"freqs_cis.shape: {freqs_cis.shape}")  


 xq_rotate, xk_rotate = apply_rotary_emb(xq, xk, freqs_cis)  
 print(f"xq_rotate.shape: {xq_rotate.shape}")  
 print(f"xk_rotate.shape: {xk_rotate.shape}")  
 """  
 ### 測(cè)試結(jié)果: ###  
 """  
 xq.shape: torch.Size([10, 256, 512])  
 xk.shape: torch.Size([10, 256, 256])  
 xq.re-shape: torch.Size([10, 256, 8, 64])  
 xk.re-shape: torch.Size([10, 256, 4, 64])  
 freqs_cis.shape: torch.Size([256, 32])  
 xq_rotate.shape: torch.Size([10, 256, 8, 64])  
 xk_rotate.shape: torch.Size([10, 256, 4, 64])  
 """

KV緩存(僅用于推理)

在Llama 3架構(gòu)中,推理階段引入了KV緩存的概念,用于以鍵和值緩存的形式存儲(chǔ)先前生成的標(biāo)記。這些緩存用于計(jì)算自注意力以生成下一個(gè)標(biāo)記。只緩存鍵和值標(biāo)記,而不緩存查詢標(biāo)記,因此稱為KV緩存。

KV緩存的必要性

讓我們通過下圖來理解KV緩存的重要性。

[圖6]:KV緩存實(shí)現(xiàn)

圖中的A塊:在生成output3標(biāo)記時(shí),仍在計(jì)算先前的輸出標(biāo)記(output1, output2),這是不必要的。這在注意力計(jì)算期間導(dǎo)致了額外的矩陣乘法,顯著增加了計(jì)算資源的使用。

圖中的B塊:輸出標(biāo)記替換了查詢嵌入中的輸入標(biāo)記。KV緩存存儲(chǔ)了先前生成的標(biāo)記。在注意力分?jǐn)?shù)計(jì)算期間,我們只需要使用查詢中的1個(gè)標(biāo)記,并使用鍵和值緩存中的先前標(biāo)記。這將矩陣乘法從A塊的3x3減少到B塊的1x3,減少了約66%。在實(shí)際應(yīng)用中,對(duì)于巨大的序列長(zhǎng)度和批量大小,這將顯著減少計(jì)算資源的使用。

分組查詢注意力

分組查詢注意力與之前模型(如Llama 1)中使用的多頭注意力相似,唯一的區(qū)別在于為查詢、鍵和值”使用單獨(dú)的頭。分配給查詢的頭數(shù)是鍵和值頭數(shù)的n倍。讓我們通過圖表來進(jìn)一步理解。

[圖7]:分組查詢注意力和多頭注意力對(duì)比

在給定的圖中,多頭注意力在所有查詢、鍵和值中都有相等數(shù)量的頭,即n_heads = 8。

分組查詢注意力塊有8個(gè)查詢頭(n_heads)和4個(gè)鍵和值頭(n_kv_heads),這是查詢頭數(shù)量的一半。

分組查詢注意力的優(yōu)勢(shì)

盡管多頭注意力已經(jīng)表現(xiàn)出色,引入分組查詢注意力是有其特定原因。我們先回顧KV緩存,KV緩存確實(shí)大大減少了計(jì)算資源的使用。但是隨著KV緩存存儲(chǔ)越來越多的先前標(biāo)記,內(nèi)存使用會(huì)顯著增加。這對(duì)模型性能和計(jì)算成本都不利。所以引入了分組查詢注意力。減少K和V的頭數(shù)會(huì)減少需要存儲(chǔ)的參數(shù)數(shù)量,從而減少內(nèi)存使用。多項(xiàng)測(cè)試結(jié)果表明,使用這種方法模型的準(zhǔn)確性仍保持在相近的范圍內(nèi)。

注意力模塊的代碼實(shí)現(xiàn):

 ## 注意力模塊 [步驟2c: KV緩存; 步驟2d: 分組查詢注意力]  
 ## 如前所述,命名約定遵循原始Meta LLama3 GitHub  


 class Attention(nn.Module):  
   def __init__(self, args: ModelArgs):  
     super().__init__()  
     self.args = args  
     # 嵌入維度  
     self.dim = args.dim  
     # 分配給查詢的頭數(shù)  
     self.n_heads = args.n_heads  
     # 分配給鍵和值的頭數(shù)。如果為"None",則數(shù)量與查詢相同。  
     self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads  
     # 每個(gè)頭相對(duì)于模型維度的維度  
     self.head_dim = args.dim // args.n_heads  
     # 重復(fù)次數(shù),以使鍵、值頭數(shù)與查詢頭數(shù)匹配  
     self.n_rep = args.n_heads // args.n_kv_heads  


     # 初始化鍵、查詢、值和輸出的權(quán)重。注意q和kv的權(quán)重out_feature值基于其頭數(shù)  
     self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)  
     self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)  
     self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)  
     self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)  


     # 初始化緩存以在開始時(shí)存儲(chǔ)鍵、值 (KV緩存實(shí)現(xiàn))  
     self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)  
     self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)  


   def forward(self, x: torch.Tensor, start_pos, inference):  
     # 輸入嵌入的形狀: [bsz,seq_len,dim]  
     bsz, seq_len, _ = x.shape  
     # 掩碼將在"訓(xùn)練"期間使用,由于使用KV緩存,"推理"不需要掩碼。
     mask = None  


     xq = self.wq(x)  #x[bsz,seq_len,dim]*wq[dim,n_heads * head_dim] -> q[bsz,seq_len,n_heads * head_dim]  
     xk = self.wk(x)  #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> k[bsz,seq_len,n_kv_heads * head_dim]  
     xv = self.wv(x)  #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> v[bsz,seq_len,n_kv_heads * head_dim]  


     # 根據(jù)頭數(shù)重塑查詢、鍵和值 (分組查詢注意力實(shí)現(xiàn))  
     xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)      #xq[bsz,seq_len,n_heads, head_dim]  
     xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)   #xk[bsz,seq_len,n_kv_heads, head_dim]  
     xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)   #xv[bsz,seq_len,n_kv_heads, head_dim]  


     # 模型 - 推理模式: kv-cache僅在推理模式下啟用  
     if inference:  
       # 計(jì)算序列中每個(gè)位置的旋轉(zhuǎn)矩陣  
       freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2)  
       # 在推理過程中,我們應(yīng)該只取從當(dāng)前標(biāo)記位置開始的旋轉(zhuǎn)矩陣范圍  
       freqs_cis = freqs_cis[start_pos : start_pos + seq_len]  
       # 將RoPE應(yīng)用于查詢和鍵嵌入  
       xq, xk = apply_rotary_emb(xq, xk, freqs_cis)  


       self.cache_k = self.cache_k.to(xq)  
       self.cache_v = self.cache_v.to(xq)  
       # 將鍵和值標(biāo)記嵌入存儲(chǔ)到它們各自的緩存中 [KV緩存實(shí)現(xiàn)]  
       self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk  
       self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv  


       # 為注意力計(jì)算分配所有直到當(dāng)前標(biāo)記位置的先前標(biāo)記嵌入給鍵和值變量  
       keys = self.cache_k[:bsz, :start_pos + seq_len]  
       values = self.cache_v[:bsz, :start_pos + seq_len]  


       # 此時(shí),鍵和值的形狀與查詢嵌入不同,但為了計(jì)算注意力分?jǐn)?shù),它們必須相同  
       # 使用repeat_kv函數(shù)使鍵、值的形狀與查詢形狀相同  
       keys = repeat_kv(keys, self.n_rep)      #keys[bsz,seq_len,n_heads,head_dim]  
       values = repeat_kv(values, self.n_rep)  #values[bsz,seq_len,n_heads,head_dim]  


     # 模式 - 訓(xùn)練模式: 未實(shí)現(xiàn)KV-Cache  
     else:  
       # 計(jì)算旋轉(zhuǎn)矩陣并將RoPE應(yīng)用于訓(xùn)練的查詢和鍵  
       freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len)  


       #xq[bsz,seq_len,n_heads, head_dim], xk[bsz,seq_len,n_heads, head_dim]  
       xq, xk = apply_rotary_emb(xq, xk, freqs_cis)  


       # 使用repeat_kv函數(shù)使鍵、值的形狀與查詢形狀相同  
       #keys[bsz,seq_len,n_heads,head_dim], #values[bsz,seq_len,n_heads,head_dim]  
       keys = repeat_kv(xk, self.n_rep)  
       values = repeat_kv(xv, self.n_rep)  


       # 對(duì)于訓(xùn)練模式,我們將計(jì)算掩碼并稍后應(yīng)用于注意力分?jǐn)?shù)  
       mask = torch.full((seq_len, seq_len),float("-inf"),device=self.args.device)  
       mask = torch.triu(mask, diagonal=1).to(self.args.device)  


     # 為了計(jì)算注意力,我們需要執(zhí)行轉(zhuǎn)置操作來重塑所有查詢、鍵和值,將頭部放在維度1,序列放在維度2  
     xq = xq.transpose(1,2)                  #xq[bsz,n_heads,seq_len,head_dim]  
     keys = keys.transpose(1,2)              #keys[bsz,n_heads,seq_len,head_dim]  
     values = values.transpose(1,2)          #values[bsz,n_heads,seq_len,head_dim]  


     # 計(jì)算注意力分?jǐn)?shù)  
     scores = torch.matmul(xq, keys.transpose(2,3)).to(self.args.device)/math.sqrt(self.head_dim)  
     if mask is not None:  
       scores = scores + mask  


     # 對(duì)注意力分?jǐn)?shù)應(yīng)用softmax  
     scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
     # 注意力分?jǐn)?shù)與值的矩陣乘法  
     output = torch.matmul(scores, values).to(self.args.device)  


     # 我們得到了每個(gè)頭部的上下文嵌入  
     # 所有頭部需要重塑回來并組合,以給出單個(gè)上下文注意力輸出  
     # 形狀變化: output[bsz,n_heads,seq_len,head_dim] -> output[bsz,seq_len, n_heads,head_dim] -> output[bsz,seq_len, n_heads * head_dim]  
     output = output.transpose(1,2).contiguous().view(bsz, seq_len, -1)  


     # 形狀: output [bsz,seq_len,dim]  
     return self.wo(output)  


 # 如果鍵/值頭的數(shù)量少于查詢頭,此函數(shù)使用所需的重復(fù)次數(shù)擴(kuò)展鍵/值嵌入  
 def repeat_kv(x:torch.Tensor, n_rep: int)-> torch.Tensor:  
   bsz, seq_len, n_kv_heads, head_dim = x.shape  
   if n_rep == 1:  
     return x  
   return (  
       x[:,:,:,None,:]  
       .expand(bsz,seq_len,n_kv_heads,n_rep, head_dim)  
       .reshape(bsz,seq_len,n_kv_heads * n_rep, head_dim)  
   )  


 ### 測(cè)試: Repeat_kv函數(shù) ###  
 # 注: xk, x_norm已在RoPE, RMSNorm測(cè)試中計(jì)算,這里用于測(cè)試  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 n_rep = ModelArgs.n_heads // ModelArgs.n_kv_heads  
 keys = repeat_kv(xk, n_rep)  
 print(f"xk.shape: {xk.shape}")  
 print(f"keys.shape: {keys.shape}")  


 ## 測(cè)試: Attention函數(shù)  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  


 attention = Attention(ModelArgs)  
 x_out = attention(x_norm,start_pos=0, inference=False)  
 print(f"x_out.shape: {x_out.shape}")  
 """  
 ### 測(cè)試結(jié)果: ###  
 """  
 xk.shape: torch.Size([10, 256, 4, 64])  
 keys.shape: torch.Size([10, 256, 8, 64])  
 x_out.shape: torch.Size([10, 256, 512])  
 """

前饋網(wǎng)絡(luò) (使用SwiGLU激活函數(shù))

如圖1所示,注意力輸出首先經(jīng)過RMSNorm,然后輸入前饋網(wǎng)絡(luò)。在前饋網(wǎng)絡(luò)中,注意力輸出嵌入會(huì)在其隱藏層中擴(kuò)展到更高維度,學(xué)習(xí)標(biāo)記的更復(fù)雜特征。

為什么選擇SwiGLU而非ReLU

[圖8]:帶有SwiGLU函數(shù)的前饋網(wǎng)絡(luò)

如圖所示,SwiGLU函數(shù)在正軸上的行為與ReLU相似。然而,在負(fù)軸上,SwiGLU輸出一些負(fù)值,這在學(xué)習(xí)較小值時(shí)可能有用,而不是像ReLU那樣在負(fù)軸上為平坦的0。根據(jù)作者的研究,使用SwiGLU的性能優(yōu)于ReLU,因此被選用。

前饋網(wǎng)絡(luò)的代碼實(shí)現(xiàn):

 ## 步驟2e: 前饋網(wǎng)絡(luò) (SwiGLU激活)  
 class FeedForward(nn.Module):  
   def __init__(self, dim:int, hidden_dim:int, multiple_of:int, ffn_dim_multiplier: Optional[float]):  
     super().__init__()  
     # 模型嵌入維度  
     self.dim = dim  


     # 我們必須使用Meta提供的隱藏維度計(jì)算方法,這是該模型的理想設(shè)置  
     # 隱藏維度的計(jì)算方式使其是256的倍數(shù)  
     hidden_dim = int(2 * hidden_dim/3)  
     if ffn_dim_multiplier is not None:  
       hidden_dim = int(ffn_dim_multiplier * hidden_dim)  
     hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  


     # 定義隱藏層權(quán)重  
     self.w1 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)  
     self.w2 = nn.Linear(hidden_dim, self.dim, bias=False, device=device)  
     self.w3 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)  


   def forward(self, x):  
     # 形狀: [bsz,seq_len,dim]  
     return self.w2(F.silu(self.w1(x)) * self.w3(x))  


 ### 測(cè)試: 前饋模塊 ###  
 # 注: x_out已在Attention測(cè)試中計(jì)算,這里用于測(cè)試  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 feed_forward = FeedForward(ModelArgs.dim, 4 * ModelArgs.dim, ModelArgs.multiple_of, ModelArgs.ffn_dim_multiplier)  
 x_out = rms_norm(x_out)  
 x_out = feed_forward(x_out)  
 print(f"前饋輸出: x_out.shape: {x_out.shape}")  
 """  


 ### 測(cè)試結(jié)果: ###  
 """  
 前饋輸出: x_out.shape: torch.Size([10, 256, 512])  
 """

解碼器塊

如圖1所示,解碼器塊由多個(gè)子組件組成,我們?cè)谇懊娴牟糠种幸呀?jīng)實(shí)現(xiàn)了這些組件。以下是解碼器塊內(nèi)進(jìn)行的逐步操作:

1、來自輸入模塊的嵌入首先經(jīng)過注意力-RMSNorm,然后輸入分組查詢注意力模塊。

2、同時(shí),來自輸入模塊的原始嵌入與注意力輸出相加。

3、然后,這個(gè)結(jié)果經(jīng)過前饋-RMSNorm,輸入前饋網(wǎng)絡(luò)模塊。

4、前饋網(wǎng)絡(luò)的輸出再次與步驟2的結(jié)果相加。

5、最終輸出被稱為解碼器輸出。這個(gè)解碼器輸出然后作為輸入傳遞給下一個(gè)解碼器塊。這個(gè)過程在接下來的31個(gè)解碼器塊中重復(fù)。第32個(gè)解碼器塊的最終輸出然后傳遞到輸出模塊。

解碼器塊的代碼實(shí)現(xiàn):

 ## 步驟2f: 解碼器塊。類名為TransformerBlock,以匹配Meta Llama 3代碼庫(kù)  


 class TransformerBlock(nn.Module):  
   def __init__(self, args: ModelArgs):  
     super().__init__()  
     self.args = args  
     # 初始化注意力的RMSNorm  
     self.attention_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)  
     # 初始化注意力類  
     self.attention = Attention(args)  
     # 初始化前饋網(wǎng)絡(luò)的RMSNorm  
     self.ff_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)  
     # 初始化前饋網(wǎng)絡(luò)類  
     self.feedforward = FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier)  


   def forward(self, x, start_pos, inference):  
     # start_pos: 推理模式下的標(biāo)記位置, inference: True表示推理模式,False表示訓(xùn)練模式  
     # 1) 將輸入嵌入傳遞給attention_norm,然后傳遞給注意力模塊  
     # 2) 注意力的輸出與原始輸入(歸一化前)相加  
     h = x + self.attention(self.attention_norm(x), start_pos,inference)  


     # 1) 將注意力輸出傳遞給ff_norm,然后傳遞給前饋網(wǎng)絡(luò)  
     # 2) 前饋網(wǎng)絡(luò)的輸出與注意力輸出(ff_norm前)相加  
     out = h + self.feedforward(self.ff_norm(h))  
     # 形狀: [bsz,seq_len,dim]  
     return out  


 ### 測(cè)試: TransformerBlock ###  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)  
 transformer_block = TransformerBlock(ModelArgs)  
 transformer_block_out = transformer_block(x,start_pos=0, inference=False)  
 print(f"transformer_block_out.shape: {transformer_block_out.shape}")  
 """  


 ### 測(cè)試結(jié)果: ###  
 """  
 transformer_block_out.shape: torch.Size([10, 64, 128])  
 """

3、輸出模塊

最后一個(gè)解碼器塊的輸出將傳入輸出模塊。它首先經(jīng)過RMSNorm處理,然后傳入線性層生成logits。接下來根據(jù)模式的不同,會(huì)執(zhí)行以下兩種操作之一:

如果是推理模式,計(jì)算top_p概率并生成下一個(gè)標(biāo)記。如果達(dá)到最大生成長(zhǎng)度或生成的下一個(gè)標(biāo)記為句子結(jié)束標(biāo)記,則停止生成。

如果是訓(xùn)練模式,使用目標(biāo)標(biāo)簽計(jì)算損失,并重復(fù)訓(xùn)練直到達(dá)到最大epoch數(shù)。

下圖展示了輸出模塊的流程:

[圖9]:Llama 3在訓(xùn)練和推理模式下的輸出流程圖

最終的Llama 3模型實(shí)現(xiàn)

我們將組合三個(gè)模塊(輸入模塊、解碼器模塊和輸出模塊)的所有組件。這就構(gòu)成了我們的完整Llama 3模型。

 ## 步驟3: 輸出模塊  
 # 這是Llama 3模型。類名保持為Transformer以匹配Meta Llama 3模型  


 class Transformer(nn.Module):  
   def __init__(self, params: ModelArgs):  
     super().__init__()  
     # 設(shè)置params變量中的所有ModelArgs  
     self.params = params  
     # 從輸入模塊初始化嵌入類  
     self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)  


     # 初始化解碼器塊并將其存儲(chǔ)在ModuleList中   
     # 這是因?yàn)槲覀兊腖lama 3模型中有4個(gè)解碼器塊 (官方Llama 3有32個(gè)塊)  
     self.layers = nn.ModuleList()  
     for layer_id in range(params.n_layers):  
       self.layers.append(TransformerBlock(args=params))  


     # 為輸出模塊初始化RMSNorm  
     self.norm = RMSNorm(params.dim, eps = params.norm_eps)  


     # 在輸出模塊初始化線性層  
     self.output = nn.Linear(params.dim, params.vocab_size, bias=False)  


   def forward(self, x, start_pos=0, targets=None):  


     # start_pos: 推理模式的標(biāo)記位置, inference: True表示推理模式, False表示訓(xùn)練模式  
     # x是使用分詞器從文本或提示生成的標(biāo)記ID批次  
     # x[bsz, seq_len] -> h[bsz, seq_len, dim]  
     h = self.tok_embeddings(x)  


     # 如果目標(biāo)為None,則激活推理模式并設(shè)置為"True",否則為訓(xùn)練模式"False"  
     inference = targets is None  


     # 嵌入(h)然后將通過所有解碼器塊  
     for layer in self.layers:  
       h = layer(h, start_pos, inference)  


     # 最后解碼器塊的輸出將饋入RMSNorm  
     h = self.norm(h)  


     # 歸一化后,嵌入h將饋入線性層   
     # 線性層的主要任務(wù)是生成將嵌入映射到詞匯表大小的logits  
     # h[bsz, seq_len, dim] -> logits[bsz, seq_len, vocab_size]  
     logits = self.output(h).float()  
     loss = None  


     # 如果目標(biāo)不可用,則為推理模式  
     if targets is None:  
       loss = None  
     # 如果目標(biāo)可用,則為訓(xùn)練模式。計(jì)算損失以進(jìn)行進(jìn)一步的模型訓(xùn)練   
     else:  
       loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1))  


     return logits, loss  


 ### 測(cè)試: Transformer (Llama模型) ###  
 # 取消下面的三重引號(hào)來執(zhí)行測(cè)試  
 """  
 model = Transformer(ModelArgs).to(ModelArgs.device)  
 print(model)  
 """

[圖10]: Llama 3分層架構(gòu)

我們剛剛構(gòu)建的Llama 3模型結(jié)構(gòu)看起來很完整。現(xiàn)在我們可以開始訓(xùn)練過程了。

4、訓(xùn)練Llama 3模型

訓(xùn)練流程在輸出模塊流程圖(圖9)中已經(jīng)展示。在開始訓(xùn)練之前,讓我們先實(shí)現(xiàn)訓(xùn)練代碼。以下代碼塊中包含了必要的解釋。

 ## 步驟4: 訓(xùn)練Llama 3模型:  


 # 使用我們?cè)谳斎肽K部分構(gòu)建的分詞器的encode函數(shù),通過對(duì)整個(gè)tiny_shakespeare數(shù)據(jù)進(jìn)行編碼來創(chuàng)建數(shù)據(jù)集  
 dataset = torch.tensor(encode(data), dtype=torch.int).to(ModelArgs.device)  
 print(f"dataset-shape: {dataset.shape}")  


 # 定義函數(shù)從給定數(shù)據(jù)集生成批次  
 def get_dataset_batch(data, split, args:ModelArgs):  
   seq_len = args.max_seq_len  
   batch_size = args.max_batch_size  
   device = args.device  


   train = data[:int(0.8 * len(data))]  
   val = data[int(0.8 * len(data)): int(0.9 * len(data))]  
   test = data[int(0.9 * len(data)):]  


   batch_data = train  
   if split == "val":  
     batch_data = val  
   elif split == "test":  
     batch_data = test  


   # 從數(shù)據(jù)集中選擇隨機(jī)起點(diǎn),為訓(xùn)練、驗(yàn)證和測(cè)試提供隨機(jī)樣本  
   ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)  
   x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in ix]).long().to(device)  
   y = torch.stack([torch.cat([batch_data[i+1:i+seq_len], token_eos]) for i in ix]).long().to(device)  


   return x, y  


 ### 測(cè)試: get_dataset函數(shù) ###  
 """  
 xs, ys = get_dataset_batch(dataset, split="train", args=ModelArgs)  
 print([(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))])  
 """  


 # 定義evaluate_loss函數(shù)來計(jì)算和存儲(chǔ)訓(xùn)練和驗(yàn)證損失,用于日志記錄和繪圖   @torch.no_grad()  
 def evaluate_loss(model, args:ModelArgs):  
   out = {}  
   model.eval()  


   for split in ["train", "val"]:  
     losses = []  
     for _ in range(10):        
       xb, yb = get_dataset_batch(dataset, split, args)  
       _, loss = model(x=xb, targets=yb)  
       losses.append(loss.item())  
     out[split] = np.mean(losses)  


   model.train()  
   return out  


 # 定義訓(xùn)練函數(shù)來執(zhí)行模型訓(xùn)練  
 def train(model, optimizer, args:ModelArgs):  
     epochs = args.epochs  
     log_interval = args.log_interval  
     device = args.device  
     losses = []     
     start_time = time.time()  


     for epoch in range(epochs):  
         optimizer.zero_grad()  


         xs, ys = get_dataset_batch(dataset, 'train', args)  
         xs = xs.to(device)  
         ys = ys.to(device)  
         logits, loss = model(x=xs, targets=ys)  
         loss.backward()  
         optimizer.step()  


         if epoch % log_interval == 0:  
             batch_time = time.time() - start_time  
             x = evaluate_loss(model, args)  
             losses.append(x)              
             print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")  
             start_time = time.time()  


     # 打印最終驗(yàn)證損失  
     print("驗(yàn)證損失: ", losses[-1]['val'])  
     # 在圖表中顯示間隔損失   
     return pd.DataFrame(losses).plot()

定義完訓(xùn)練函數(shù)。就可以開始訓(xùn)練過程,并在訓(xùn)練完成后觀察結(jié)果。

 ## 開始訓(xùn)練我們的Llama 3模型  
 model = Transformer(ModelArgs).to(ModelArgs.device)  
 optimizer = torch.optim.Adam(model.parameters())  


 train(model, optimizer, ModelArgs)

[圖11] 訓(xùn)練與驗(yàn)證損失圖

上圖顯示了訓(xùn)練和驗(yàn)證損失的變化。訓(xùn)練進(jìn)行了2500個(gè)epoch。使用Google Colab的默認(rèn)GPU和RAM設(shè)置,整個(gè)訓(xùn)練過程大約花費(fèi)了10分鐘,這是相當(dāng)快速的。最后一個(gè)epoch的驗(yàn)證損失為2.19,考慮到我們使用的訓(xùn)練數(shù)據(jù)量和epoch數(shù)量,這個(gè)結(jié)果是可以接受的。要顯著降低損失,我們還需要增加訓(xùn)練數(shù)據(jù)的規(guī)模、提高epoch數(shù)量,并使用更強(qiáng)大的GPU或處理能力。

5、Llama 3模型推理

推理流程在輸出模塊流程圖(圖9)中已經(jīng)展示。讓我們實(shí)現(xiàn)推理代碼。

 ## 步驟5: Llama 3模型推理   # 這個(gè)函數(shù)使用我們構(gòu)建和訓(xùn)練的Llama 3模型,基于提供的提示生成文本序列  


 def generate(model, prompts: str, params: ModelArgs, max_gen_len: int=500, temperature: float = 0.6, top_p: float = 0.9):  

     # prompt_tokens: 用戶輸入文本或提示列表       # max_gen_len: 生成文本序列的最大長(zhǎng)度       # temperature: 用于控制采樣隨機(jī)性的溫度值。默認(rèn)為0.6       # top_p: 從logits采樣prob輸出的top-p概率閾值。默認(rèn)為0.9  
     bsz = 1  # 對(duì)于推理,通常用戶只輸入一個(gè)提示,我們將其作為1個(gè)批次  
     prompt_tokens = token_bos.tolist() + encode(prompts)  
     assert len(prompt_tokens) <= params.max_seq_len, "提示標(biāo)記長(zhǎng)度應(yīng)小于max_seq_len"  
     total_len = min(len(prompt_tokens)+max_gen_len, params.max_seq_len)     

     # 這個(gè)tokens矩陣用于存儲(chǔ)輸入提示和模型生成的所有輸出       # 稍后我們將使用分詞器的decode函數(shù)來解碼這個(gè)token,以文本格式查看結(jié)果  
     tokens = torch.full((bsz,total_len), fill_value=token_pad.item(), dtype=torch.long, device=params.device)  

     # 將提示tokens填入token矩陣  
     tokens[:,:len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long, device=params.device)  

     # 創(chuàng)建一個(gè)prompt_mask_token,用于稍后識(shí)別token是提示token還是填充token       # 如果是提示token則為True,如果是填充token則為False  
     input_text_mask = tokens != token_pad.item()  

     # 現(xiàn)在我們可以從第一個(gè)位置開始,一次使用一個(gè)token從prompt_tokens列表開始推理  
     prev_pos = 0  
     for cur_pos in range(1, total_len):  
       with torch.no_grad():  
         logits, _ = model(x=tokens[:,prev_pos:cur_pos], start_pos=prev_pos)  
       if temperature > 0:        
         probs = torch.softmax(logits[:, -1]/temperature, dim=-1)  
         next_token = sample_top_p(probs, top_p)          
       else:  
         next_token = torch.argmax(logits[:, -1], dim=-1)          


       next_token = next_token.reshape(-1)  

       # 只有在是填充token時(shí)才替換token  
       next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)  
       tokens[:, cur_pos] = next_token  


       prev_pos = cur_pos  
       if tokens[:,cur_pos]==token_pad.item() and next_token == token_eos.item():  
         break  


     output_tokens, output_texts = [], []      


     for i, toks in enumerate(tokens.tolist()):  
       if token_eos.item() in toks:  
         eos_idx = toks.index(token_eos.item())  
         toks = toks[:eos_idx]  


       output_tokens.append(toks)  
       output_texts.append(decode(toks))  
     return output_tokens, output_texts  

 # 對(duì)概率分布執(zhí)行top-p (nucleus) 采樣   # probs (torch.Tensor): 由logits導(dǎo)出的概率分布張量   # p: top-p采樣的概率閾值   # 根據(jù)相關(guān)研究,Top-p采樣選擇累積概率質(zhì)量超過閾值p的最小標(biāo)記集    # 基于選定的標(biāo)記重新歸一化分布  
 def sample_top_p(probs, p):  
     probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True)  
     probs_sum = torch.cumsum(probs_sort, dim=-1)  
     mask = probs_sum - probs_sort > p  
     probs_sort[mask] = 0.0  
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  
     next_token = torch.multinomial(probs_sort, num_samples=1)  
     next_token = torch.gather(prob_idx, -1, next_token)           # 返回從詞匯表中采樣的標(biāo)記索引   
     return next_token

對(duì)新的提示執(zhí)行推理,并檢查生成的輸出:

 ## 對(duì)用戶輸入的提示執(zhí)行推理  
 prompts = "Consider you what services he has done"  
 output_tokens, output_texts = generate(model, prompts, ModelArgs)  
 output_texts = output_texts[0].replace("<|begin_of_text|>", "")  
 print(output_texts)  


 ## 輸出 ##  
 """  
 Consider you what services he has done o eretrane  
 adetranytnn i eey i ade hs rcuh i eey,ad hsatsTns rpae,T  
 eon o i hseflns o i eee ee hs ote i ocal ersl,Bnnlnface  
 o i hmr a il nwye ademto nt i a ere  
 h i ees.  
 Frm oe o etrane o oregae,alh,t orede i oeral  
 """

從結(jié)果可以看出,我們的Llama 3模型能夠?qū)π碌奶崾緢?zhí)行推理并生成文本。雖然考慮到我們使用的訓(xùn)練數(shù)據(jù)量和訓(xùn)練輪數(shù),輸出質(zhì)量并不是很高,但這證明了模型的基本功能是正常的。通過使用更大規(guī)模的訓(xùn)練數(shù)據(jù)和更多的訓(xùn)練輪數(shù),我們將能夠獲得更高質(zhì)量的輸出。

總結(jié)

我們已經(jīng)成功地從零開始構(gòu)建了自己的Llama 3模型。我們不僅實(shí)現(xiàn)了模型的架構(gòu),還成功地進(jìn)行了訓(xùn)練,并能夠執(zhí)行推理以生成新的文本。值得注意的是,我們?cè)谙鄬?duì)有限的計(jì)算資源(Google Colab Notebook提供的免費(fèi)GPU和RAM)下,在較短的時(shí)間內(nèi)完成了這個(gè)過程。

本文中的代碼和方法主要用于教育和研究目的。在實(shí)際應(yīng)用中,可能需要進(jìn)行更多的優(yōu)化和調(diào)整,以達(dá)到生產(chǎn)級(jí)別的性能和效果。

    本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn)。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購(gòu)買等信息,謹(jǐn)防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評(píng)論

    發(fā)表

    請(qǐng)遵守用戶 評(píng)論公約

    類似文章 更多