來源:DeepHub IMBA
本文約12000字,建議閱讀15+分鐘
本文將詳細(xì)指導(dǎo)如何從零開始構(gòu)建完整的Llama 3模型架構(gòu)。
我們上次發(fā)了用PyTorch從零開始編寫DeepSeek-V2的文章后,有小伙伴留言說希望介紹一下Llama 3。那么今天他就來了,本文將詳細(xì)指導(dǎo)如何從零開始構(gòu)建完整的Llama 3模型架構(gòu),并在自定義數(shù)據(jù)集上執(zhí)行訓(xùn)練和推理。

[圖1]:Llama 3架構(gòu)展示訓(xùn)練和推理流程。因?yàn)楣俜絃lama 3論文中未提供相關(guān)圖表。所以此圖為大概架構(gòu)圖,閱讀本文后你應(yīng)能繪制出更為精確的架構(gòu)圖。
本文目標(biāo)
通過本文。你可以了解到:
深入理解Llama 3模型各組件的底層工作原理。
編寫代碼構(gòu)建Llama 3的每個(gè)組件,并將它們組裝成一個(gè)功能完整的Llama 3模型。
編寫代碼使用新的自定義數(shù)據(jù)集訓(xùn)練模型。
編寫代碼執(zhí)行推理,使Llama 3模型能夠根據(jù)輸入提示生成新文本。
1、輸入模塊
如圖1所示,輸入模塊包含三個(gè)組件:文本/提示、分詞器和嵌入。
輸入模塊內(nèi)部工作流程
讓我們通過下圖了解輸入模塊內(nèi)的工作流程。

[圖2]:輸入模塊流程圖,展示提示、分詞器和嵌入流程。
首先,單個(gè)或批量文本/提示被輸入模型。例如:圖中的"Hello World"。
輸入模型的必須是數(shù)字格式,因?yàn)槟P蜔o法直接處理文本。分詞器將這些文本/提示轉(zhuǎn)換為標(biāo)記ID(詞匯表中標(biāo)記的索引號(hào)表示)。我們將使用Tiny Shakespeare數(shù)據(jù)集構(gòu)建詞匯表并訓(xùn)練模型。Llama 3模型使用TikToken作為分詞器,這是一種子詞分詞器。但是我們這個(gè)實(shí)現(xiàn)將使用字符級(jí)分詞器。這樣做的主要原因是讓我們能夠自行構(gòu)建詞匯表和分詞器,包括編碼和解碼函數(shù),這樣可以深入理解底層工作原理并完全掌控代碼。
每個(gè)標(biāo)記ID將被轉(zhuǎn)換為128維的嵌入向量(原始Llama 3 8B中為4096維)。然后這些嵌入將被傳遞到下一個(gè)解碼器模塊。
輸入模塊代碼實(shí)現(xiàn):
# 導(dǎo)入必要的庫(kù)
import torch
from torch import nn
from torch.nn import functional as F
import math
import numpy as np
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List
import pandas as pd
from matplotlib import pyplot as plt ### 步驟1: 輸入模塊 ###
# 使用Tiny Shakespeare數(shù)據(jù)集實(shí)現(xiàn)字符級(jí)分詞器。部分字符級(jí)分詞器代碼參考自Andrej Karpathy的GitHub倉(cāng)庫(kù)
# (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py)
# 加載tiny_shakespeare數(shù)據(jù)文件 (https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt)
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' # 根據(jù)可用性分配設(shè)備為cuda或cpu
# 加載tiny_shakespeare數(shù)據(jù)文件
with open('tiny_shakespeare.txt', 'r') as f:
data = f.read()
# 通過提取tiny_shakespeare數(shù)據(jù)中的所有唯一字符準(zhǔn)備詞匯表
vocab = sorted(list(set(data)))
# 訓(xùn)練Llama 3模型需要額外的標(biāo)記,如<|begin_of_text|>、<|end_of_text|>和<|pad_id|>,將它們添加到詞匯表中
vocab.extend(['<|begin_of_text|>','<|end_of_text|>','<|pad_id|>'])
vocab_size = len(vocab)
# 創(chuàng)建字符與詞匯表中對(duì)應(yīng)整數(shù)索引之間的映射。
# 這對(duì)于構(gòu)建分詞器的編碼和解碼函數(shù)至關(guān)重要。
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}
# 分詞器編碼函數(shù):輸入字符串,輸出整數(shù)列表
def encode(s):
return [stoi[ch] for ch in s]
# 分詞器解碼函數(shù):輸入整數(shù)列表,輸出字符串
def decode(l):
return ''.join(itos[i] for i in l)
# 定義稍后在模型訓(xùn)練中使用的張量標(biāo)記變量
token_bos = torch.tensor([stoi['<|begin_of_text|>']], dtype=torch.int, device=device)
token_eos = torch.tensor([stoi['<|end_of_text|>']], dtype=torch.int, device=device)
token_pad = torch.tensor([stoi['<|pad_id|>']], dtype=torch.int, device=device)
prompts = "Hello World"
encoded_tokens = encode(prompts)
decoded_text = decode(encoded_tokens)
### 輸入模塊代碼測(cè)試 ###
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
print(f"Shakespeare文本字符長(zhǎng)度: {len(data)}")
print(f"詞匯表內(nèi)容: {''.join(vocab)}\n")
print(f"詞匯表大小: {vocab_size}")
print(f"編碼后的標(biāo)記: {encoded_tokens}")
print(f"解碼后的文本: {decoded_text}")
"""
### 測(cè)試結(jié)果: ###
"""
Shakespeare文本字符長(zhǎng)度: 1115394
詞匯表內(nèi)容:
!$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz<|begin_of_text|><|end_of_text|><|pad_id|>
詞匯表大小: 68
編碼后的標(biāo)記: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
解碼后的文本: Hello World
"""2、解碼器模塊
參照?qǐng)D1的架構(gòu)圖,解碼器模塊包含以下子組件:
RMS歸一化
旋轉(zhuǎn)位置編碼
KV緩存
分組查詢注意力
前饋網(wǎng)絡(luò)
解碼器塊
RMS歸一化(Root Mean Square Normalization)
RMSNorm的必要性
從圖1可以看出,輸入模塊輸出(嵌入向量)經(jīng)過RMSNorm模塊。這是因?yàn)榍度胂蛄烤哂卸鄠€(gè)維度(Llama3-8b中為4096維),可能出現(xiàn)不同范圍的值。這會(huì)導(dǎo)致模型梯度爆炸或消失,從而導(dǎo)致收斂緩慢甚至發(fā)散。而RMSNorm將這些值歸一化到一定范圍,有助于穩(wěn)定和加速訓(xùn)練過程。這使得梯度具有更一致的幅度,從而加快模型收斂。
RMSNorm的工作原理

[圖3]:對(duì)形狀為[3,3]的輸入嵌入應(yīng)用RMSNorm
類似于層歸一化,RMSNorm沿嵌入特征或維度應(yīng)用。上圖中的嵌入形狀為[3,3],意味著每個(gè)標(biāo)記有3個(gè)維度。
示例:對(duì)第一個(gè)標(biāo)記X1的嵌入應(yīng)用RMSNorm:
X1標(biāo)記在每個(gè)維度上的值(x11、x12和x13)分別除以所有這些值的均方根。公式如圖3所示。
為避免除以零并保證數(shù)值穩(wěn)定性,在均方根中加入一個(gè)小常數(shù)E(Epsilon)。乘以一個(gè)縮放參數(shù)Gamma (Y)。每個(gè)特征都有一個(gè)獨(dú)特的Gamma參數(shù)(如圖中d1維度的Y1、d2維度的Y2和d3維度的Y3),這是一個(gè)學(xué)習(xí)參數(shù),可以向上或向下縮放以進(jìn)一步穩(wěn)定歸一化。gamma參數(shù)初始化為1(如上面的計(jì)算所示)。
如示例所示,嵌入值原本較大且分布范圍寬。應(yīng)用RMSNorm后,值變小且范圍縮小。計(jì)算使用實(shí)際的RMSNorm函數(shù)完成。
RMSNorm相比層歸一化的優(yōu)勢(shì)
如上例所示沒有計(jì)算任何均值或方差,而這在層歸一化中是必需的。所以RMSNorm通過避免計(jì)算均值和方差減少了計(jì)算開銷。根據(jù)作者的研究,RMSNorm在不影響準(zhǔn)確性的同時(shí)提供了性能優(yōu)勢(shì)。
RMSNorm代碼實(shí)現(xiàn):
# 步驟2: 解碼器模塊
# 注:由于Llama 3模型由Meta開發(fā),為了與他們的代碼庫(kù)保持一致并考慮未來兼容性,
# 我將使用Meta GitHub上的大部分代碼,并進(jìn)行必要的修改以實(shí)現(xiàn)我們的目標(biāo)。
# 定義參數(shù)數(shù)據(jù)類:我們將在模型構(gòu)建、訓(xùn)練和推理過程中使用這些參數(shù)。
# 注:為了更快地看到訓(xùn)練和推理結(jié)果,而不是專注于高準(zhǔn)確性,我們對(duì)大多數(shù)參數(shù)采用較低的值,
# 這些值在Llama 3模型中設(shè)置得更高。 @dataclass
class ModelArgs:
dim: int = 512 # 嵌入維度
n_layers: int = 8 # 模型解碼器塊的數(shù)量
n_heads: int = 8 # 查詢嵌入的頭數(shù)
n_kv_heads: int = 4 # 鍵和值嵌入的頭數(shù)
vocab_size: int = len(vocab) # 詞匯表長(zhǎng)度
multiple_of: int = 256 # 用于計(jì)算前饋網(wǎng)絡(luò)維度
ffn_dim_multiplier: Optional[float] = None # 用于計(jì)算前饋網(wǎng)絡(luò)維度
norm_eps: float = 1e-5 # RMSNorm計(jì)算的默認(rèn)Epsilon值
rope_theta: float = 10000.0 # RePE計(jì)算的默認(rèn)theta值
max_batch_size: int = 10 # 最大批量大小
max_seq_len: int = 256 # 最大序列長(zhǎng)度
epochs: int = 2500 # 總訓(xùn)練迭代次數(shù)
log_interval: int = 10 # 打印日志和損失值的間隔數(shù)
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' # 根據(jù)可用性分配設(shè)備為cuda或cpu
## 步驟2a: RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
device = ModelArgs.device
self.eps = eps
# 縮放參數(shù)gamma,初始化為1,參數(shù)數(shù)量等于dim的大小
self.weight = nn.Parameter(torch.ones(dim).to(device))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(device)
def forward(self, x):
#形狀: x[bs,seq,dim]
output = self._norm(x.float()).type_as(x)
#形狀: x[bs,seq,dim] -> x_norm[bs,seq,dim]
return output * self.weight
### RMSNorm代碼測(cè)試 ###
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
rms_norm = RMSNorm(dim=ModelArgs.dim)
x_norm = rms_norm(x)
print(f"x的形狀: {x.shape}")
print(f"x_norm的形狀: {x_norm.shape}")
"""
### 測(cè)試結(jié)果: ###
"""
x的形狀: torch.Size([10, 256, 512])
x_norm的形狀: torch.Size([10, 256, 512])
"""旋轉(zhuǎn)位置編碼(Rotary Positional Encoding, RoPE)
回顧之前的步驟,我們已將輸入文本轉(zhuǎn)換為嵌入,并對(duì)嵌入應(yīng)用了RMSNorm。然而,這里存在一個(gè)問題:假設(shè)輸入文本是"I love apple"或"apple love I",模型會(huì)將兩個(gè)句子視為相同并以相同方式學(xué)習(xí)。這是因?yàn)榍度胫袥]有為模型定義順序信息。因此對(duì)于任何語言模型來說,保持標(biāo)記的順序至關(guān)重要。在Llama 3模型架構(gòu)中,引入了旋轉(zhuǎn)位置編碼(RoPE)來定義句子中每個(gè)標(biāo)記的位置,這不僅維護(hù)了順序,還保留了句子中標(biāo)記的相對(duì)位置信息。
旋轉(zhuǎn)位置編碼的工作原理
RoPE是一種位置編碼方法,它通過添加絕對(duì)位置信息以及包含標(biāo)記之間的相對(duì)位置信息來編碼嵌入,從而維護(hù)句子中標(biāo)記的順序。它通過使用一個(gè)特殊的旋轉(zhuǎn)矩陣來旋轉(zhuǎn)給定的嵌入來執(zhí)行編碼操作。這種利用旋轉(zhuǎn)矩陣的簡(jiǎn)潔而強(qiáng)大的數(shù)學(xué)推導(dǎo)是RoPE的核心。

[圖4]:應(yīng)用于2維向量的旋轉(zhuǎn)矩陣
上圖展示了旋轉(zhuǎn)矩陣應(yīng)用于2維向量的情況。Llama 3模型中的維度數(shù)是4096,遠(yuǎn)高于此。我們?cè)敿?xì)介紹如何對(duì)更高維度的嵌入應(yīng)用旋轉(zhuǎn)。

[圖5]:RoPE應(yīng)用于嵌入的示例
嵌入的旋轉(zhuǎn)涉及每個(gè)嵌入位置(m)值和theta (θ)對(duì)每對(duì)嵌入維度的乘法。這就是RoPE如何通過實(shí)現(xiàn)旋轉(zhuǎn)矩陣來捕獲絕對(duì)位置和相對(duì)位置信息的方式。
注意:在執(zhí)行旋轉(zhuǎn)之前,需要將旋轉(zhuǎn)矩陣轉(zhuǎn)換為極坐標(biāo)形式,并將嵌入向量轉(zhuǎn)換為復(fù)數(shù)。旋轉(zhuǎn)完成后,旋轉(zhuǎn)后的嵌入需要轉(zhuǎn)換回實(shí)數(shù)以進(jìn)行注意力操作。另外RoPE僅應(yīng)用于查詢和鍵嵌入,不適用于值嵌入。
RoPE的代碼實(shí)現(xiàn):
## 步驟2b: RoPE實(shí)現(xiàn)
def precompute_freqs_cis(dim:int, seq_len: int, theta: float=10000.0):
# 計(jì)算每對(duì)維度的Theta值,即dim/2
device = ModelArgs.device
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2,device=device)[:(dim//2)].float()/dim))
# 計(jì)算序列中位置(m)的范圍
t = torch.arange(seq_len, dtype=torch.float32, device=device)
# freqs給出序列中所有標(biāo)記位置的Theta值范圍
freqs = torch.outer(t, freqs).to(device)
# 這是需要轉(zhuǎn)換為極坐標(biāo)形式的旋轉(zhuǎn)矩陣,以便對(duì)嵌入執(zhí)行旋轉(zhuǎn)
freqs_cis = torch.polar(torch.ones_like(freqs).to(device), freqs).to(device)
return freqs_cis
def reshape_for_broadcast(freqs_cis, x):
ndim = x.ndim
assert 0<=1<ndim
assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "freqs_cis的最后兩個(gè)維度必須與x匹配"
shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor)->Tuple[torch.Tensor, torch.Tensor]:
device = ModelArgs.device
# 同時(shí)對(duì)查詢和鍵嵌入應(yīng)用旋轉(zhuǎn)位置編碼
# 首先:xq和xk嵌入的最后一個(gè)維度需要重塑為一對(duì)。因?yàn)樾D(zhuǎn)矩陣應(yīng)用于每對(duì)維度。
# 其次:將xq和xk轉(zhuǎn)換為復(fù)數(shù),因?yàn)樾D(zhuǎn)矩陣只適用于復(fù)數(shù)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device) #xq_:[bsz, seq_len, n_heads, head_dim/2]
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device) #xk_:[bsz, seq_len, n_heads, head_dim/2]
# 旋轉(zhuǎn)矩陣(freqs_cis)在seq_len(dim=1)和head_dim(dim=3)維度上應(yīng)與嵌入匹配
# 此外,freqs_cis的形狀應(yīng)與xq和xk相同,因此將freqs_cis的形狀從[seq_len,head_dim]改變?yōu)閇1,seq_len,1,head_dim]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# 最后,通過與freqs_cis相乘執(zhí)行旋轉(zhuǎn)操作。
# 旋轉(zhuǎn)完成后,將xq_out和xk_out轉(zhuǎn)換回實(shí)數(shù)并返回
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device) #xq_out:[bsz, seq_len, n_heads, head_dim]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device) #xk_out:[bsz, seq_len, n_heads, head_dim]
return xq_out.type_as(xq), xk_out.type_as(xk)
### RoPE代碼測(cè)試 ###
# 注:x_norm在RMSNorm測(cè)試中計(jì)算,這里用于測(cè)試。
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
head_dim = ModelArgs.dim//ModelArgs.n_heads
wq = nn.Linear(ModelArgs.dim, ModelArgs.n_heads * head_dim, bias=False, device=device)
wk = nn.Linear(ModelArgs.dim, ModelArgs.n_kv_heads * head_dim, bias=False, device=device)
xq = wq(x_norm)
xk = wk(x_norm)
print(f"xq.shape: {xq.shape}")
print(f"xk.shape: {xk.shape}")
xq = xq.view(xq.shape[0],xq.shape[1],ModelArgs.n_heads, head_dim)
xk = xk.view(xk.shape[0],xk.shape[1],ModelArgs.n_kv_heads, head_dim)
print(f"xq.re-shape: {xq.shape}")
print(f"xk.re-shape: {xk.shape}")
freqs_cis = precompute_freqs_cis(dim=head_dim, seq_len=ModelArgs.max_seq_len)
print(f"freqs_cis.shape: {freqs_cis.shape}")
xq_rotate, xk_rotate = apply_rotary_emb(xq, xk, freqs_cis)
print(f"xq_rotate.shape: {xq_rotate.shape}")
print(f"xk_rotate.shape: {xk_rotate.shape}")
"""
### 測(cè)試結(jié)果: ###
"""
xq.shape: torch.Size([10, 256, 512])
xk.shape: torch.Size([10, 256, 256])
xq.re-shape: torch.Size([10, 256, 8, 64])
xk.re-shape: torch.Size([10, 256, 4, 64])
freqs_cis.shape: torch.Size([256, 32])
xq_rotate.shape: torch.Size([10, 256, 8, 64])
xk_rotate.shape: torch.Size([10, 256, 4, 64])
"""KV緩存(僅用于推理)
在Llama 3架構(gòu)中,推理階段引入了KV緩存的概念,用于以鍵和值緩存的形式存儲(chǔ)先前生成的標(biāo)記。這些緩存用于計(jì)算自注意力以生成下一個(gè)標(biāo)記。只緩存鍵和值標(biāo)記,而不緩存查詢標(biāo)記,因此稱為KV緩存。
KV緩存的必要性
讓我們通過下圖來理解KV緩存的重要性。

[圖6]:KV緩存實(shí)現(xiàn)
圖中的A塊:在生成output3標(biāo)記時(shí),仍在計(jì)算先前的輸出標(biāo)記(output1, output2),這是不必要的。這在注意力計(jì)算期間導(dǎo)致了額外的矩陣乘法,顯著增加了計(jì)算資源的使用。
圖中的B塊:輸出標(biāo)記替換了查詢嵌入中的輸入標(biāo)記。KV緩存存儲(chǔ)了先前生成的標(biāo)記。在注意力分?jǐn)?shù)計(jì)算期間,我們只需要使用查詢中的1個(gè)標(biāo)記,并使用鍵和值緩存中的先前標(biāo)記。這將矩陣乘法從A塊的3x3減少到B塊的1x3,減少了約66%。在實(shí)際應(yīng)用中,對(duì)于巨大的序列長(zhǎng)度和批量大小,這將顯著減少計(jì)算資源的使用。
分組查詢注意力
分組查詢注意力與之前模型(如Llama 1)中使用的多頭注意力相似,唯一的區(qū)別在于為查詢、鍵和值”使用單獨(dú)的頭。分配給查詢的頭數(shù)是鍵和值頭數(shù)的n倍。讓我們通過圖表來進(jìn)一步理解。

[圖7]:分組查詢注意力和多頭注意力對(duì)比
在給定的圖中,多頭注意力在所有查詢、鍵和值中都有相等數(shù)量的頭,即n_heads = 8。
分組查詢注意力塊有8個(gè)查詢頭(n_heads)和4個(gè)鍵和值頭(n_kv_heads),這是查詢頭數(shù)量的一半。
分組查詢注意力的優(yōu)勢(shì)
盡管多頭注意力已經(jīng)表現(xiàn)出色,引入分組查詢注意力是有其特定原因。我們先回顧KV緩存,KV緩存確實(shí)大大減少了計(jì)算資源的使用。但是隨著KV緩存存儲(chǔ)越來越多的先前標(biāo)記,內(nèi)存使用會(huì)顯著增加。這對(duì)模型性能和計(jì)算成本都不利。所以引入了分組查詢注意力。減少K和V的頭數(shù)會(huì)減少需要存儲(chǔ)的參數(shù)數(shù)量,從而減少內(nèi)存使用。多項(xiàng)測(cè)試結(jié)果表明,使用這種方法模型的準(zhǔn)確性仍保持在相近的范圍內(nèi)。
注意力模塊的代碼實(shí)現(xiàn):
## 注意力模塊 [步驟2c: KV緩存; 步驟2d: 分組查詢注意力]
## 如前所述,命名約定遵循原始Meta LLama3 GitHub
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# 嵌入維度
self.dim = args.dim
# 分配給查詢的頭數(shù)
self.n_heads = args.n_heads
# 分配給鍵和值的頭數(shù)。如果為"None",則數(shù)量與查詢相同。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 每個(gè)頭相對(duì)于模型維度的維度
self.head_dim = args.dim // args.n_heads
# 重復(fù)次數(shù),以使鍵、值頭數(shù)與查詢頭數(shù)匹配
self.n_rep = args.n_heads // args.n_kv_heads
# 初始化鍵、查詢、值和輸出的權(quán)重。注意q和kv的權(quán)重out_feature值基于其頭數(shù)
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)
# 初始化緩存以在開始時(shí)存儲(chǔ)鍵、值 (KV緩存實(shí)現(xiàn))
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
def forward(self, x: torch.Tensor, start_pos, inference):
# 輸入嵌入的形狀: [bsz,seq_len,dim]
bsz, seq_len, _ = x.shape
# 掩碼將在"訓(xùn)練"期間使用,由于使用KV緩存,"推理"不需要掩碼。
mask = None
xq = self.wq(x) #x[bsz,seq_len,dim]*wq[dim,n_heads * head_dim] -> q[bsz,seq_len,n_heads * head_dim]
xk = self.wk(x) #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> k[bsz,seq_len,n_kv_heads * head_dim]
xv = self.wv(x) #x[bsz,seq_len,dim]*wq[dim,n_kv_heads * head_dim] -> v[bsz,seq_len,n_kv_heads * head_dim]
# 根據(jù)頭數(shù)重塑查詢、鍵和值 (分組查詢注意力實(shí)現(xiàn))
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) #xq[bsz,seq_len,n_heads, head_dim]
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) #xk[bsz,seq_len,n_kv_heads, head_dim]
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) #xv[bsz,seq_len,n_kv_heads, head_dim]
# 模型 - 推理模式: kv-cache僅在推理模式下啟用
if inference:
# 計(jì)算序列中每個(gè)位置的旋轉(zhuǎn)矩陣
freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2)
# 在推理過程中,我們應(yīng)該只取從當(dāng)前標(biāo)記位置開始的旋轉(zhuǎn)矩陣范圍
freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
# 將RoPE應(yīng)用于查詢和鍵嵌入
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
# 將鍵和值標(biāo)記嵌入存儲(chǔ)到它們各自的緩存中 [KV緩存實(shí)現(xiàn)]
self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk
self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv
# 為注意力計(jì)算分配所有直到當(dāng)前標(biāo)記位置的先前標(biāo)記嵌入給鍵和值變量
keys = self.cache_k[:bsz, :start_pos + seq_len]
values = self.cache_v[:bsz, :start_pos + seq_len]
# 此時(shí),鍵和值的形狀與查詢嵌入不同,但為了計(jì)算注意力分?jǐn)?shù),它們必須相同
# 使用repeat_kv函數(shù)使鍵、值的形狀與查詢形狀相同
keys = repeat_kv(keys, self.n_rep) #keys[bsz,seq_len,n_heads,head_dim]
values = repeat_kv(values, self.n_rep) #values[bsz,seq_len,n_heads,head_dim]
# 模式 - 訓(xùn)練模式: 未實(shí)現(xiàn)KV-Cache
else:
# 計(jì)算旋轉(zhuǎn)矩陣并將RoPE應(yīng)用于訓(xùn)練的查詢和鍵
freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len)
#xq[bsz,seq_len,n_heads, head_dim], xk[bsz,seq_len,n_heads, head_dim]
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# 使用repeat_kv函數(shù)使鍵、值的形狀與查詢形狀相同
#keys[bsz,seq_len,n_heads,head_dim], #values[bsz,seq_len,n_heads,head_dim]
keys = repeat_kv(xk, self.n_rep)
values = repeat_kv(xv, self.n_rep)
# 對(duì)于訓(xùn)練模式,我們將計(jì)算掩碼并稍后應(yīng)用于注意力分?jǐn)?shù)
mask = torch.full((seq_len, seq_len),float("-inf"),device=self.args.device)
mask = torch.triu(mask, diagonal=1).to(self.args.device)
# 為了計(jì)算注意力,我們需要執(zhí)行轉(zhuǎn)置操作來重塑所有查詢、鍵和值,將頭部放在維度1,序列放在維度2
xq = xq.transpose(1,2) #xq[bsz,n_heads,seq_len,head_dim]
keys = keys.transpose(1,2) #keys[bsz,n_heads,seq_len,head_dim]
values = values.transpose(1,2) #values[bsz,n_heads,seq_len,head_dim]
# 計(jì)算注意力分?jǐn)?shù)
scores = torch.matmul(xq, keys.transpose(2,3)).to(self.args.device)/math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask
# 對(duì)注意力分?jǐn)?shù)應(yīng)用softmax
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 注意力分?jǐn)?shù)與值的矩陣乘法
output = torch.matmul(scores, values).to(self.args.device)
# 我們得到了每個(gè)頭部的上下文嵌入
# 所有頭部需要重塑回來并組合,以給出單個(gè)上下文注意力輸出
# 形狀變化: output[bsz,n_heads,seq_len,head_dim] -> output[bsz,seq_len, n_heads,head_dim] -> output[bsz,seq_len, n_heads * head_dim]
output = output.transpose(1,2).contiguous().view(bsz, seq_len, -1)
# 形狀: output [bsz,seq_len,dim]
return self.wo(output)
# 如果鍵/值頭的數(shù)量少于查詢頭,此函數(shù)使用所需的重復(fù)次數(shù)擴(kuò)展鍵/值嵌入
def repeat_kv(x:torch.Tensor, n_rep: int)-> torch.Tensor:
bsz, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:,:,:,None,:]
.expand(bsz,seq_len,n_kv_heads,n_rep, head_dim)
.reshape(bsz,seq_len,n_kv_heads * n_rep, head_dim)
)
### 測(cè)試: Repeat_kv函數(shù) ###
# 注: xk, x_norm已在RoPE, RMSNorm測(cè)試中計(jì)算,這里用于測(cè)試
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
n_rep = ModelArgs.n_heads // ModelArgs.n_kv_heads
keys = repeat_kv(xk, n_rep)
print(f"xk.shape: {xk.shape}")
print(f"keys.shape: {keys.shape}")
## 測(cè)試: Attention函數(shù)
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
attention = Attention(ModelArgs)
x_out = attention(x_norm,start_pos=0, inference=False)
print(f"x_out.shape: {x_out.shape}")
"""
### 測(cè)試結(jié)果: ###
"""
xk.shape: torch.Size([10, 256, 4, 64])
keys.shape: torch.Size([10, 256, 8, 64])
x_out.shape: torch.Size([10, 256, 512])
"""前饋網(wǎng)絡(luò) (使用SwiGLU激活函數(shù))
如圖1所示,注意力輸出首先經(jīng)過RMSNorm,然后輸入前饋網(wǎng)絡(luò)。在前饋網(wǎng)絡(luò)中,注意力輸出嵌入會(huì)在其隱藏層中擴(kuò)展到更高維度,學(xué)習(xí)標(biāo)記的更復(fù)雜特征。
為什么選擇SwiGLU而非ReLU

[圖8]:帶有SwiGLU函數(shù)的前饋網(wǎng)絡(luò)
如圖所示,SwiGLU函數(shù)在正軸上的行為與ReLU相似。然而,在負(fù)軸上,SwiGLU輸出一些負(fù)值,這在學(xué)習(xí)較小值時(shí)可能有用,而不是像ReLU那樣在負(fù)軸上為平坦的0。根據(jù)作者的研究,使用SwiGLU的性能優(yōu)于ReLU,因此被選用。
前饋網(wǎng)絡(luò)的代碼實(shí)現(xiàn):
## 步驟2e: 前饋網(wǎng)絡(luò) (SwiGLU激活)
class FeedForward(nn.Module):
def __init__(self, dim:int, hidden_dim:int, multiple_of:int, ffn_dim_multiplier: Optional[float]):
super().__init__()
# 模型嵌入維度
self.dim = dim
# 我們必須使用Meta提供的隱藏維度計(jì)算方法,這是該模型的理想設(shè)置
# 隱藏維度的計(jì)算方式使其是256的倍數(shù)
hidden_dim = int(2 * hidden_dim/3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定義隱藏層權(quán)重
self.w1 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)
self.w2 = nn.Linear(hidden_dim, self.dim, bias=False, device=device)
self.w3 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)
def forward(self, x):
# 形狀: [bsz,seq_len,dim]
return self.w2(F.silu(self.w1(x)) * self.w3(x))
### 測(cè)試: 前饋模塊 ###
# 注: x_out已在Attention測(cè)試中計(jì)算,這里用于測(cè)試
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
feed_forward = FeedForward(ModelArgs.dim, 4 * ModelArgs.dim, ModelArgs.multiple_of, ModelArgs.ffn_dim_multiplier)
x_out = rms_norm(x_out)
x_out = feed_forward(x_out)
print(f"前饋輸出: x_out.shape: {x_out.shape}")
"""
### 測(cè)試結(jié)果: ###
"""
前饋輸出: x_out.shape: torch.Size([10, 256, 512])
"""解碼器塊
如圖1所示,解碼器塊由多個(gè)子組件組成,我們?cè)谇懊娴牟糠种幸呀?jīng)實(shí)現(xiàn)了這些組件。以下是解碼器塊內(nèi)進(jìn)行的逐步操作:
1、來自輸入模塊的嵌入首先經(jīng)過注意力-RMSNorm,然后輸入分組查詢注意力模塊。
2、同時(shí),來自輸入模塊的原始嵌入與注意力輸出相加。
3、然后,這個(gè)結(jié)果經(jīng)過前饋-RMSNorm,輸入前饋網(wǎng)絡(luò)模塊。
4、前饋網(wǎng)絡(luò)的輸出再次與步驟2的結(jié)果相加。
5、最終輸出被稱為解碼器輸出。這個(gè)解碼器輸出然后作為輸入傳遞給下一個(gè)解碼器塊。這個(gè)過程在接下來的31個(gè)解碼器塊中重復(fù)。第32個(gè)解碼器塊的最終輸出然后傳遞到輸出模塊。
解碼器塊的代碼實(shí)現(xiàn):
## 步驟2f: 解碼器塊。類名為TransformerBlock,以匹配Meta Llama 3代碼庫(kù)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# 初始化注意力的RMSNorm
self.attention_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
# 初始化注意力類
self.attention = Attention(args)
# 初始化前饋網(wǎng)絡(luò)的RMSNorm
self.ff_norm = RMSNorm(dim=args.dim, eps = args.norm_eps)
# 初始化前饋網(wǎng)絡(luò)類
self.feedforward = FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier)
def forward(self, x, start_pos, inference):
# start_pos: 推理模式下的標(biāo)記位置, inference: True表示推理模式,False表示訓(xùn)練模式
# 1) 將輸入嵌入傳遞給attention_norm,然后傳遞給注意力模塊
# 2) 注意力的輸出與原始輸入(歸一化前)相加
h = x + self.attention(self.attention_norm(x), start_pos,inference)
# 1) 將注意力輸出傳遞給ff_norm,然后傳遞給前饋網(wǎng)絡(luò)
# 2) 前饋網(wǎng)絡(luò)的輸出與注意力輸出(ff_norm前)相加
out = h + self.feedforward(self.ff_norm(h))
# 形狀: [bsz,seq_len,dim]
return out
### 測(cè)試: TransformerBlock ###
# 取消下面的三重引號(hào)來執(zhí)行測(cè)試
"""
x = torch.randn((ModelArgs.max_batch_size, ModelArgs.max_seq_len, ModelArgs.dim), device=device)
transformer_block = TransformerBlock(ModelArgs)
transformer_block_out = transformer_block(x,start_pos=0, inference=False)
print(f"transformer_block_out.shape: {transformer_block_out.shape}")
"""
### 測(cè)試結(jié)果: ###
"""
transformer_block_out.shape: torch.Size([10, 64, 128])
"""3、輸出模塊
最后一個(gè)解碼器塊的輸出將傳入輸出模塊。它首先經(jīng)過RMSNorm處理,然后傳入線性層生成logits。接下來根據(jù)模式的不同,會(huì)執(zhí)行以下兩種操作之一:
如果是推理模式,計(jì)算top_p概率并生成下一個(gè)標(biāo)記。如果達(dá)到最大生成長(zhǎng)度或生成的下一個(gè)標(biāo)記為句子結(jié)束標(biāo)記,則停止生成。
如果是訓(xùn)練模式,使用目標(biāo)標(biāo)簽計(jì)算損失,并重復(fù)訓(xùn)練直到達(dá)到最大epoch數(shù)。
下圖展示了輸出模塊的流程:

[圖9]:Llama 3在訓(xùn)練和推理模式下的輸出流程圖
最終的Llama 3模型實(shí)現(xiàn)
我們將組合三個(gè)模塊(輸入模塊、解碼器模塊和輸出模塊)的所有組件。這就構(gòu)成了我們的完整Llama 3模型。
## 步驟3: 輸出模塊 # 這是Llama 3模型。類名保持為Transformer以匹配Meta Llama 3模型 class Transformer(nn.Module): def __init__(self, params: ModelArgs): super().__init__() # 設(shè)置params變量中的所有ModelArgs self.params = params # 從輸入模塊初始化嵌入類 self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) # 初始化解碼器塊并將其存儲(chǔ)在ModuleList中 # 這是因?yàn)槲覀兊腖lama 3模型中有4個(gè)解碼器塊 (官方Llama 3有32個(gè)塊) self.layers = nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(args=params)) # 為輸出模塊初始化RMSNorm self.norm = RMSNorm(params.dim, eps = params.norm_eps) # 在輸出模塊初始化線性層 self.output = nn.Linear(params.dim, params.vocab_size, bias=False) def forward(self, x, start_pos=0, targets=None): # start_pos: 推理模式的標(biāo)記位置, inference: True表示推理模式, False表示訓(xùn)練模式 # x是使用分詞器從文本或提示生成的標(biāo)記ID批次 # x[bsz, seq_len] -> h[bsz, seq_len, dim] h = self.tok_embeddings(x) # 如果目標(biāo)為None,則激活推理模式并設(shè)置為"True",否則為訓(xùn)練模式"False" inference = targets is None # 嵌入(h)然后將通過所有解碼器塊 for layer in self.layers: h = layer(h, start_pos, inference) # 最后解碼器塊的輸出將饋入RMSNorm h = self.norm(h) # 歸一化后,嵌入h將饋入線性層 # 線性層的主要任務(wù)是生成將嵌入映射到詞匯表大小的logits # h[bsz, seq_len, dim] -> logits[bsz, seq_len, vocab_size] logits = self.output(h).float() loss = None # 如果目標(biāo)不可用,則為推理模式 if targets is None: loss = None # 如果目標(biāo)可用,則為訓(xùn)練模式。計(jì)算損失以進(jìn)行進(jìn)一步的模型訓(xùn)練 else: loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1)) return logits, loss ### 測(cè)試: Transformer (Llama模型) ### # 取消下面的三重引號(hào)來執(zhí)行測(cè)試 """ model = Transformer(ModelArgs).to(ModelArgs.device) print(model) """

[圖10]: Llama 3分層架構(gòu)
我們剛剛構(gòu)建的Llama 3模型結(jié)構(gòu)看起來很完整。現(xiàn)在我們可以開始訓(xùn)練過程了。
4、訓(xùn)練Llama 3模型
訓(xùn)練流程在輸出模塊流程圖(圖9)中已經(jīng)展示。在開始訓(xùn)練之前,讓我們先實(shí)現(xiàn)訓(xùn)練代碼。以下代碼塊中包含了必要的解釋。
## 步驟4: 訓(xùn)練Llama 3模型:
# 使用我們?cè)谳斎肽K部分構(gòu)建的分詞器的encode函數(shù),通過對(duì)整個(gè)tiny_shakespeare數(shù)據(jù)進(jìn)行編碼來創(chuàng)建數(shù)據(jù)集
dataset = torch.tensor(encode(data), dtype=torch.int).to(ModelArgs.device)
print(f"dataset-shape: {dataset.shape}")
# 定義函數(shù)從給定數(shù)據(jù)集生成批次
def get_dataset_batch(data, split, args:ModelArgs):
seq_len = args.max_seq_len
batch_size = args.max_batch_size
device = args.device
train = data[:int(0.8 * len(data))]
val = data[int(0.8 * len(data)): int(0.9 * len(data))]
test = data[int(0.9 * len(data)):]
batch_data = train
if split == "val":
batch_data = val
elif split == "test":
batch_data = test
# 從數(shù)據(jù)集中選擇隨機(jī)起點(diǎn),為訓(xùn)練、驗(yàn)證和測(cè)試提供隨機(jī)樣本
ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in ix]).long().to(device)
y = torch.stack([torch.cat([batch_data[i+1:i+seq_len], token_eos]) for i in ix]).long().to(device)
return x, y
### 測(cè)試: get_dataset函數(shù) ###
"""
xs, ys = get_dataset_batch(dataset, split="train", args=ModelArgs)
print([(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))])
"""
# 定義evaluate_loss函數(shù)來計(jì)算和存儲(chǔ)訓(xùn)練和驗(yàn)證損失,用于日志記錄和繪圖 @torch.no_grad()
def evaluate_loss(model, args:ModelArgs):
out = {}
model.eval()
for split in ["train", "val"]:
losses = []
for _ in range(10):
xb, yb = get_dataset_batch(dataset, split, args)
_, loss = model(x=xb, targets=yb)
losses.append(loss.item())
out[split] = np.mean(losses)
model.train()
return out
# 定義訓(xùn)練函數(shù)來執(zhí)行模型訓(xùn)練
def train(model, optimizer, args:ModelArgs):
epochs = args.epochs
log_interval = args.log_interval
device = args.device
losses = []
start_time = time.time()
for epoch in range(epochs):
optimizer.zero_grad()
xs, ys = get_dataset_batch(dataset, 'train', args)
xs = xs.to(device)
ys = ys.to(device)
logits, loss = model(x=xs, targets=ys)
loss.backward()
optimizer.step()
if epoch % log_interval == 0:
batch_time = time.time() - start_time
x = evaluate_loss(model, args)
losses.append(x)
print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
start_time = time.time()
# 打印最終驗(yàn)證損失
print("驗(yàn)證損失: ", losses[-1]['val'])
# 在圖表中顯示間隔損失
return pd.DataFrame(losses).plot()定義完訓(xùn)練函數(shù)。就可以開始訓(xùn)練過程,并在訓(xùn)練完成后觀察結(jié)果。
## 開始訓(xùn)練我們的Llama 3模型 model = Transformer(ModelArgs).to(ModelArgs.device) optimizer = torch.optim.Adam(model.parameters()) train(model, optimizer, ModelArgs)

[圖11] 訓(xùn)練與驗(yàn)證損失圖
上圖顯示了訓(xùn)練和驗(yàn)證損失的變化。訓(xùn)練進(jìn)行了2500個(gè)epoch。使用Google Colab的默認(rèn)GPU和RAM設(shè)置,整個(gè)訓(xùn)練過程大約花費(fèi)了10分鐘,這是相當(dāng)快速的。最后一個(gè)epoch的驗(yàn)證損失為2.19,考慮到我們使用的訓(xùn)練數(shù)據(jù)量和epoch數(shù)量,這個(gè)結(jié)果是可以接受的。要顯著降低損失,我們還需要增加訓(xùn)練數(shù)據(jù)的規(guī)模、提高epoch數(shù)量,并使用更強(qiáng)大的GPU或處理能力。
5、Llama 3模型推理
推理流程在輸出模塊流程圖(圖9)中已經(jīng)展示。讓我們實(shí)現(xiàn)推理代碼。
## 步驟5: Llama 3模型推理 # 這個(gè)函數(shù)使用我們構(gòu)建和訓(xùn)練的Llama 3模型,基于提供的提示生成文本序列 def generate(model, prompts: str, params: ModelArgs, max_gen_len: int=500, temperature: float = 0.6, top_p: float = 0.9): # prompt_tokens: 用戶輸入文本或提示列表 # max_gen_len: 生成文本序列的最大長(zhǎng)度 # temperature: 用于控制采樣隨機(jī)性的溫度值。默認(rèn)為0.6 # top_p: 從logits采樣prob輸出的top-p概率閾值。默認(rèn)為0.9 bsz = 1 # 對(duì)于推理,通常用戶只輸入一個(gè)提示,我們將其作為1個(gè)批次 prompt_tokens = token_bos.tolist() + encode(prompts) assert len(prompt_tokens) <= params.max_seq_len, "提示標(biāo)記長(zhǎng)度應(yīng)小于max_seq_len" total_len = min(len(prompt_tokens)+max_gen_len, params.max_seq_len) # 這個(gè)tokens矩陣用于存儲(chǔ)輸入提示和模型生成的所有輸出 # 稍后我們將使用分詞器的decode函數(shù)來解碼這個(gè)token,以文本格式查看結(jié)果 tokens = torch.full((bsz,total_len), fill_value=token_pad.item(), dtype=torch.long, device=params.device) # 將提示tokens填入token矩陣 tokens[:,:len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long, device=params.device) # 創(chuàng)建一個(gè)prompt_mask_token,用于稍后識(shí)別token是提示token還是填充token # 如果是提示token則為True,如果是填充token則為False input_text_mask = tokens != token_pad.item() # 現(xiàn)在我們可以從第一個(gè)位置開始,一次使用一個(gè)token從prompt_tokens列表開始推理 prev_pos = 0 for cur_pos in range(1, total_len): with torch.no_grad(): logits, _ = model(x=tokens[:,prev_pos:cur_pos], start_pos=prev_pos) if temperature > 0: probs = torch.softmax(logits[:, -1]/temperature, dim=-1) next_token = sample_top_p(probs, top_p) else: next_token = torch.argmax(logits[:, -1], dim=-1) next_token = next_token.reshape(-1) # 只有在是填充token時(shí)才替換token next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] = next_token prev_pos = cur_pos if tokens[:,cur_pos]==token_pad.item() and next_token == token_eos.item(): break output_tokens, output_texts = [], [] for i, toks in enumerate(tokens.tolist()): if token_eos.item() in toks: eos_idx = toks.index(token_eos.item()) toks = toks[:eos_idx] output_tokens.append(toks) output_texts.append(decode(toks)) return output_tokens, output_texts # 對(duì)概率分布執(zhí)行top-p (nucleus) 采樣 # probs (torch.Tensor): 由logits導(dǎo)出的概率分布張量 # p: top-p采樣的概率閾值 # 根據(jù)相關(guān)研究,Top-p采樣選擇累積概率質(zhì)量超過閾值p的最小標(biāo)記集 # 基于選定的標(biāo)記重新歸一化分布 def sample_top_p(probs, p): probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort > p probs_sort[mask] = 0.0 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = torch.multinomial(probs_sort, num_samples=1) next_token = torch.gather(prob_idx, -1, next_token) # 返回從詞匯表中采樣的標(biāo)記索引 return next_token
對(duì)新的提示執(zhí)行推理,并檢查生成的輸出:
## 對(duì)用戶輸入的提示執(zhí)行推理
prompts = "Consider you what services he has done"
output_tokens, output_texts = generate(model, prompts, ModelArgs)
output_texts = output_texts[0].replace("<|begin_of_text|>", "")
print(output_texts)
## 輸出 ##
"""
Consider you what services he has done o eretrane
adetranytnn i eey i ade hs rcuh i eey,ad hsatsTns rpae,T
eon o i hseflns o i eee ee hs ote i ocal ersl,Bnnlnface
o i hmr a il nwye ademto nt i a ere
h i ees.
Frm oe o etrane o oregae,alh,t orede i oeral
"""從結(jié)果可以看出,我們的Llama 3模型能夠?qū)π碌奶崾緢?zhí)行推理并生成文本。雖然考慮到我們使用的訓(xùn)練數(shù)據(jù)量和訓(xùn)練輪數(shù),輸出質(zhì)量并不是很高,但這證明了模型的基本功能是正常的。通過使用更大規(guī)模的訓(xùn)練數(shù)據(jù)和更多的訓(xùn)練輪數(shù),我們將能夠獲得更高質(zhì)量的輸出。
總結(jié)
我們已經(jīng)成功地從零開始構(gòu)建了自己的Llama 3模型。我們不僅實(shí)現(xiàn)了模型的架構(gòu),還成功地進(jìn)行了訓(xùn)練,并能夠執(zhí)行推理以生成新的文本。值得注意的是,我們?cè)谙鄬?duì)有限的計(jì)算資源(Google Colab Notebook提供的免費(fèi)GPU和RAM)下,在較短的時(shí)間內(nèi)完成了這個(gè)過程。
本文中的代碼和方法主要用于教育和研究目的。在實(shí)際應(yīng)用中,可能需要進(jìn)行更多的優(yōu)化和調(diào)整,以達(dá)到生產(chǎn)級(jí)別的性能和效果。




