{"id":35919,"date":"2025-01-03T04:00:24","date_gmt":"2025-01-02T20:00:24","guid":{"rendered":"https:\/\/17aitech.com\/?p=35919"},"modified":"2025-01-03T04:00:24","modified_gmt":"2025-01-02T20:00:24","slug":"sebastian-raschka%e6%9c%80%e6%96%b0%e5%8d%9a%e5%ae%a2%ef%bc%9a%e4%bb%8e%e5%a4%b4%e5%bc%80%e5%a7%8b%ef%bc%8c%e7%94%a8llama-2%e6%9e%84%e5%bb%ballama-3-2","status":"publish","type":"post","link":"https:\/\/17aitech.com\/?p=35919","title":{"rendered":"Sebastian Raschka's Latest Blog Post: Building Llama 3.2 from Llama 2, From Scratch"},"content":{"rendered":"<p>Source: <a href=\"https:\/\/www.jiqizhixin.com\/articles\/2024-10-06-3\" target=\"_blank\">Sebastian Raschka's Latest Blog Post: Building Llama 3.2 from Llama 2, From Scratch<\/a><\/p>\n<p>Ten days ago, at Meta Connect 2024, the open-source community gained Llama 3.2 1B and 3B, lightweight models that can run on edge and mobile devices. Both versions are text-only models, but they also support multilingual text generation and tool calling. According to Meta, these models let developers build personalized, general-purpose applications that run locally on the device; such applications offer strong privacy because the data never has to leave the device.<\/p>\n<p>Machine learning researcher Sebastian Raschka has since quickly published a long tutorial, <em>Converting Llama 2 to Llama 3.2 From Scratch<\/em>.<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png\" data-fancybox=\"images\" 
data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png\"><\/a><\/p>\n<ul>\n<li>\n<p>Tutorial link: https:\/\/github.com\/rasbt\/LLMs-from-scratch\/blob\/main\/ch05\/07_gpt_to_llama\/converting-llama2-to-llama3.ipynb<\/p>\n<\/li>\n<\/ul>\n<p>This article is a follow-up to <em>Converting a From-Scratch GPT Architecture to Llama 2<\/em> and shows how to convert Meta's Llama 2 architecture step by step into Llama 3, Llama 3.1, and Llama 3.2. To avoid unnecessary length, the explanations are deliberately kept short and the focus is on the main code.<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-fd1362f466c31343d4ff33406fde9178.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-fd1362f466c31343d4ff33406fde9178.png\"><\/a><\/p>\n<p>Synced has translated and edited the article without changing its original meaning:<\/p>\n<p><strong>1 Converting the Llama model implementation step by step<\/strong><\/p>\n<p>If you are implementing an LLM architecture for the first time, I recommend starting with Chapter 4 of <em>Build a Large Language Model From Scratch<\/em> (https:\/\/github.com\/rasbt\/LLMs-from-scratch\/blob\/0972ded5309c25dc5eecc98b62897d677c6c36c4\/ch04\/01_main-chapter-code\/ch04.ipynb), which walks you step by step through implementing the original GPT architecture.<\/p>\n<p>After that, you can consult <em>Converting a From-Scratch GPT Architecture to Llama 2<\/em> (https:\/\/github.com\/rasbt\/LLMs-from-scratch\/blob\/0972ded5309c25dc5eecc98b62897d677c6c36c4\/ch05\/07_gpt_to_llama\/converting-gpt-to-llama2.ipynb), which implements the Llama-specific components, such as RMSNorm layers, SiLU and SwiGLU activations, RoPE (rotary position embeddings), and the SentencePiece tokenizer.<\/p>\n<p>This notebook takes the Llama 2 architecture and converts it into the Llama 3 architecture by:<\/p>\n<ul>\n<li>\n<p>modifying the rotary embeddings<\/p>\n<\/li>\n<li>\n<p>implementing grouped-query attention<\/p>\n<\/li>\n<li>\n<p>using a customized version of the GPT-4 tokenizer<\/p>\n<\/li>\n<\/ul>\n<p>Afterward, we load the original Llama 3 weights shared by Meta into the architecture.<\/p>\n<p><strong>1.1 Reusing Llama 2 components<\/strong><\/p>\n<p>Llama 2 is actually very similar to Llama 3, as mentioned above and as shown in the figure at the beginning of this article.<\/p>\n<p>This means we can import several building blocks from the Llama 2 notebook with the following code:<\/p>\n<section>\n<pre data-lang=\"python\"><code>import os<\/code>\r\n<code>import sys<\/code>\r\n<code>import io<\/code>\r\n<code>import nbformat<\/code>\r\n<code>import types<\/code>\r\n<code>\r\n<\/code><code>def import_from_notebook():<\/code>\r\n<code>    def import_definitions_from_notebook(fullname, names):<\/code>\r\n<code>        current_dir = os.getcwd()<\/code>\r\n<code>        path = os.path.join(current_dir, fullname + \".ipynb\")<\/code>\r\n<code>        path = os.path.normpath(path)<\/code>\r\n<code>        # Load the notebook<\/code>\r\n<code>        if not os.path.exists(path):<\/code>\r\n<code>            raise FileNotFoundError(f\"Notebook file not found at: {path}\")<\/code>\r\n<code>        with io.open(path, \"r\", encoding=\"utf-8\") as f:<\/code>\r\n<code>            nb = nbformat.read(f, as_version=4)<\/code>\r\n<code>        # Create a module to store the imported functions and classes<\/code>\r\n<code>        mod = types.ModuleType(fullname)<\/code>\r\n<code>        sys.modules[fullname] = mod<\/code>\r\n<code>        # Go through the notebook cells and only execute function or class definitions<\/code>\r\n<code>        for cell in nb.cells:<\/code>\r\n<code>            if cell.cell_type == \"code\":<\/code>\r\n<code>                cell_code = cell.source<\/code>\r\n<code>                for name in names:<\/code>\r\n<code>                    # Check for function or class definitions<\/code>\r\n<code>                    if f\"def {name}\" in cell_code or f\"class {name}\" in cell_code:<\/code>\r\n<code>                        exec(cell_code, mod.__dict__)<\/code>\r\n<code>                        break  # Execute each matching cell only once<\/code>\r\n<code>        return mod<\/code>\r\n<code>\r\n<\/code><code>    fullname = \"converting-gpt-to-llama2\"<\/code>\r\n<code>    names = [\"precompute_rope_params\", \"compute_rope\", \"SiLU\", \"FeedForward\", \"RMSNorm\", \"MultiHeadAttention\"]<\/code>\r\n<code>    return import_definitions_from_notebook(fullname, names)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"python\"><code>imported_module = import_from_notebook()<\/code>\r\n<code># We need to redefine precompute_rope_params<\/code>\r\n<code># precompute_rope_params = getattr(imported_module, \"precompute_rope_params\", None)<\/code>\r\n<code>compute_rope = getattr(imported_module, \"compute_rope\", None)<\/code>\r\n<code>SiLU = getattr(imported_module, \"SiLU\", None)<\/code>\r\n<code>FeedForward = getattr(imported_module, \"FeedForward\", None)<\/code>\r\n<code>RMSNorm = getattr(imported_module, \"RMSNorm\", None)<\/code>\r\n<code># MultiHeadAttention only for comparison purposes<\/code>\r\n<code>MultiHeadAttention = getattr(imported_module, \"MultiHeadAttention\", None)<\/code><\/pre>\n<\/section>\n<p><strong>1.2 Modified RoPE<\/strong><\/p>\n<p>Llama 3 uses RoPE in a similar way to Llama 2; see the RoPE paper (https:\/\/arxiv.org\/abs\/2104.09864).<\/p>\n<p>However, there are some subtle differences between their RoPE settings. Llama 3 now supports up to 8192 tokens, twice as many as Llama 2 (4096).<\/p>\n<p>The RoPE base value (see the formula below) was increased from 10,000 (Llama 2) to 500,000 (Llama 3) in the following formula (adapted from the RoPE paper):<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-d68f70fd2b3404af03183c3832c81e75.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-d68f70fd2b3404af03183c3832c81e75.png\"><\/a><\/p>\n<p>These theta values are a set of predefined parameters that determine the rotation angles in the rotation matrix, where d is the dimensionality of the embedding space.<\/p>\n<p>Increasing the base from 10,000 to 500,000 makes the frequencies (or rotation angles) decay more slowly across the dimensions, which means that the higher dimensions are associated with larger angles than before (essentially, this decompresses the frequencies).<\/p>\n<p>In addition, the code below introduces a freq_config section for adjusting the frequencies. It is not needed for Llama 3 (only for Llama 3.1 and Llama 3.2), so we will revisit freq_config later (it defaults to None and is ignored).<\/p>\n<section>\n<pre 
data-lang=\"python\"><code>import torch<\/code>\r\n<code>\r\n<\/code><code>def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):<\/code>\r\n<code>    assert head_dim % 2 == 0, \"Embedding dimension must be even\"<\/code>\r\n<code>    # Compute the inverse frequencies<\/code>\r\n<code>    inv_freq = 1.0 \/ (theta_base ** (torch.arange(0, head_dim \/\/ 2) \/ (head_dim \/\/ 2)))<\/code>\r\n<code>    ################################ NEW ###############################################<\/code>\r\n<code>    # Frequency adjustments<\/code>\r\n<code>    if freq_config is not None:<\/code>\r\n<code>        low_freq_wavelen = freq_config[\"original_context_length\"] \/ freq_config[\"low_freq_factor\"]<\/code>\r\n<code>        high_freq_wavelen = freq_config[\"original_context_length\"] \/ freq_config[\"high_freq_factor\"]<\/code>\r\n<code>        wavelen = 2 * torch.pi \/ inv_freq<\/code>\r\n<code>        inv_freq_llama = torch.where(<\/code>\r\n<code>            wavelen &gt; low_freq_wavelen, inv_freq \/ freq_config[\"factor\"], inv_freq<\/code>\r\n<code>        )<\/code>\r\n<code>        smooth_factor = (freq_config[\"original_context_length\"] \/ wavelen - freq_config[\"low_freq_factor\"]) \/ (<\/code>\r\n<code>            freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]<\/code>\r\n<code>        )<\/code>\r\n<code>        smoothed_inv_freq = (<\/code>\r\n<code>            (1 - smooth_factor) * (inv_freq \/ freq_config[\"factor\"]) + smooth_factor * inv_freq<\/code>\r\n<code>        )<\/code>\r\n<code>        is_medium_freq = (wavelen &lt;= low_freq_wavelen) &amp; (wavelen &gt;= high_freq_wavelen)<\/code>\r\n<code>        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)<\/code>\r\n<code>        inv_freq = inv_freq_llama<\/code>\r\n<code>    ####################################################################################<\/code>\r\n<code>    # Generate position indices<\/code>\r\n<code>    positions = torch.arange(context_length)<\/code>\r\n<code>    # Compute the angles<\/code>\r\n<code>    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim \/\/ 2)<\/code>\r\n<code>    # Expand angles to match the head_dim<\/code>\r\n<code>    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)<\/code>\r\n<code>    # Precompute sine and cosine<\/code>\r\n<code>    cos = torch.cos(angles)<\/code>\r\n<code>    sin = torch.sin(angles)<\/code>\r\n<code>    return cos, sin<\/code><\/pre>\n<\/section>\n<p>In short, what is new in Llama 3 compared with Llama 2 are the context length and theta base settings:<\/p>\n<section>\n<pre data-lang=\"python\"><code># Instantiate RoPE parameters<\/code>\r\n<code>llama_2_context_len = 4096<\/code>\r\n<code>llama_3_context_len = 8192<\/code>\r\n<code>llama_2_theta_base = 10_000<\/code>\r\n<code>llama_3_theta_base = 500_000<\/code><\/pre>\n<\/section>\n<p>The usage is the same as for Llama 2:<\/p>\n<section>\n<pre data-lang=\"python\"><code># Settings<\/code>\r\n<code>batch_size = 2<\/code>\r\n<code>num_heads = 4<\/code>\r\n<code>head_dim = 16<\/code>\r\n<code>\r\n<\/code><code># Instantiate RoPE parameters<\/code>\r\n<code>cos, sin = precompute_rope_params(<\/code>\r\n<code>    head_dim=head_dim,<\/code>\r\n<code>    theta_base=llama_3_theta_base,<\/code>\r\n<code>    context_length=llama_3_context_len<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code># Dummy query and key tensors of shape (batch_size, num_heads, num_tokens, head_dim)<\/code>\r\n<code>torch.manual_seed(123)<\/code>\r\n<code>queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)<\/code>\r\n<code>keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)<\/code>\r\n<code>\r\n<\/code><code># Apply rotary position embeddings<\/code>\r\n<code>queries_rot = compute_rope(queries, cos, sin)<\/code>\r\n<code>keys_rot = compute_rope(keys, cos, sin)<\/code><\/pre>\n<\/section>\n<p><strong>1.3 
Grouped-query attention<\/strong><\/p>\n<p>In this section, we replace multi-head attention (MHA) with an alternative mechanism called grouped-query attention (GQA). In short, GQA can be thought of as a more compute- and parameter-efficient version of MHA.<\/p>\n<p>In GQA, the number of key and value projections is reduced by sharing them among multiple attention heads. Each attention head still has its own query, but the queries within a group attend to the same set of keys and values.<\/p>\n<p>Below is an example of GQA with 2 key-value groups:<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-428ba6ef8e9f8b880d9ad44f552196ca.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-428ba6ef8e9f8b880d9ad44f552196ca.png\"><\/a><\/p>\n<p>The main idea behind GQA is to reduce the number of unique key-value groups that the queries attend to. This shrinks some of the matrix multiplications in MHA and reduces the number of parameters without significantly reducing modeling performance.<\/p>\n<p>In short, the main change in GQA is that each key-value group has to be repeated to match the number of query heads associated with it. The implementation is as follows:<\/p>\n<section>\n<pre data-lang=\"python\"><code>import torch.nn as nn<\/code>\r\n<code>\r\n<\/code><code>class GroupedQueryAttention(nn.Module):<\/code>\r\n<code>    def 
__init__(<\/code>\r\n<code>            self, d_in, d_out, context_length, num_heads,<\/code>\r\n<code>            num_kv_groups,       # NEW<\/code>\r\n<code>            rope_base=10_000,    # NEW<\/code>\r\n<code>            rope_config=None,    # NEW<\/code>\r\n<code>            dtype=None<\/code>\r\n<code>        ):<\/code>\r\n<code>        super().__init__()<\/code>\r\n<code>        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"<\/code>\r\n<code>        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"<\/code>\r\n<code>\r\n<\/code><code>        self.d_out = d_out<\/code>\r\n<code>        self.num_heads = num_heads<\/code>\r\n<code>        self.head_dim = d_out \/\/ num_heads<\/code>\r\n<code>\r\n<\/code><code>        ############################# NEW  #############################<\/code>\r\n<code>        # self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)<\/code>\r\n<code>        # self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)<\/code>\r\n<code>        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)<\/code>\r\n<code>        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)<\/code>\r\n<code>        self.num_kv_groups = num_kv_groups<\/code>\r\n<code>        self.group_size = num_heads \/\/ num_kv_groups<\/code>\r\n<code>        ################################################################<\/code>\r\n<code>\r\n<\/code><code>        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)<\/code>\r\n<code>        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)<\/code>\r\n<code>\r\n<\/code><code>        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))<\/code>\r\n<code>        cos, sin = precompute_rope_params(<\/code>\r\n<code>            head_dim=self.head_dim,<\/code>\r\n<code>            
theta_base=rope_base,      # NEW<\/code>\r\n<code>            freq_config=rope_config,   # NEW<\/code>\r\n<code>            context_length=context_length<\/code>\r\n<code>        )<\/code>\r\n<code>        self.register_buffer(\"cos\", cos)<\/code>\r\n<code>        self.register_buffer(\"sin\", sin)<\/code>\r\n<code>\r\n<\/code><code>    def forward(self, x):<\/code>\r\n<code>        b, num_tokens, d_in = x.shape<\/code>\r\n<code>\r\n<\/code><code>        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)<\/code>\r\n<code>        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)<\/code>\r\n<code>        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)<\/code>\r\n<code>\r\n<\/code><code>        # Reshape queries, keys, and values<\/code>\r\n<code>        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)<\/code>\r\n<code>\r\n<\/code><code>        ##################### NEW  #####################<\/code>\r\n<code>        # keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)<\/code>\r\n<code>        # values = values.view(b, num_tokens, self.num_heads, self.head_dim)<\/code>\r\n<code>        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)<\/code>\r\n<code>        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)<\/code>\r\n<code>        ################################################<\/code>\r\n<code>\r\n<\/code><code>        # Transpose keys, values, and queries<\/code>\r\n<code>        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)<\/code>\r\n<code>        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)<\/code>\r\n<code>        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)<\/code>\r\n<code>\r\n<\/code><code>        # Apply RoPE<\/code>\r\n<code>        keys = compute_rope(keys, self.cos, self.sin)<\/code>\r\n<code>        
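# Note: the keys are rotated here while they still have only num_kv_groups heads<\/code>\r\n<code>        # (before the repeat_interleave step below). Because RoPE applies the same<\/code>\r\n<code>        # position-dependent rotation to every head, this matches rotating after the repeat, with less work<\/code>\r\n<code>        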
queries = compute_rope(queries, self.cos, self.sin)<\/code>\r\n<code>\r\n<\/code><code>        ##################### NEW  #####################<\/code>\r\n<code>        # Expand keys and values to match the number of heads<\/code>\r\n<code>        # Shape: (b, num_heads, num_tokens, head_dim)<\/code>\r\n<code>        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)<\/code>\r\n<code>        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)<\/code>\r\n<code>        # For example, before repeat_interleave along dim=1 (key-value groups):<\/code>\r\n<code>        #   [K1, K2]<\/code>\r\n<code>        # After repeat_interleave (each key-value group is repeated group_size times):<\/code>\r\n<code>        #   [K1, K1, K2, K2]<\/code>\r\n<code>        # If we used regular repeat instead of repeat_interleave, we'd get:<\/code>\r\n<code>        #   [K1, K2, K1, K2]<\/code>\r\n<code>        ################################################<\/code>\r\n<code>\r\n<\/code><code>        # Compute scaled dot-product attention (aka self-attention) with a causal mask<\/code>\r\n<code>        # Shape: (b, num_heads, num_tokens, num_tokens)<\/code>\r\n<code>        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head<\/code>\r\n<code>\r\n<\/code><code>        # Original mask truncated to the number of tokens and converted to boolean<\/code>\r\n<code>        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]<\/code>\r\n<code>\r\n<\/code><code>        # Use the mask to fill attention scores<\/code>\r\n<code>        attn_scores.masked_fill_(mask_bool, -torch.inf)<\/code>\r\n<code>\r\n<\/code><code>        attn_weights = torch.softmax(attn_scores \/ keys.shape[-1]**0.5, dim=-1)<\/code>\r\n<code>        assert keys.shape[-1] == self.head_dim<\/code>\r\n<code>\r\n<\/code><code>        # Shape: (b, num_tokens, num_heads, 
head_dim)<\/code>\r\n<code>        context_vec = (attn_weights @ values).transpose(1, 2)<\/code>\r\n<code>\r\n<\/code><code>        # Combine heads, where self.d_out = self.num_heads * self.head_dim<\/code>\r\n<code>        context_vec = context_vec.reshape(b, num_tokens, self.d_out)<\/code>\r\n<code>        context_vec = self.out_proj(context_vec)  # optional projection<\/code>\r\n<code>\r\n<\/code><code>        return context_vec<\/code><\/pre>\n<\/section>\n<p>To see how many parameters this saves, consider the following multi-head attention example from the GPT and Llama 2 code:<\/p>\n<section>\n<pre data-lang=\"python\"><code># Settings<\/code>\r\n<code>batch_size = 1<\/code>\r\n<code>context_len = 3000<\/code>\r\n<code>max_context_len = 8192<\/code>\r\n<code>embed_dim = 4096<\/code>\r\n<code>num_heads = 32<\/code>\r\n<code>\r\n<\/code><code>example_batch = torch.randn((batch_size, context_len, embed_dim))<\/code>\r\n<code>\r\n<\/code><code>mha = MultiHeadAttention(<\/code>\r\n<code>    d_in=embed_dim,<\/code>\r\n<code>    d_out=embed_dim,<\/code>\r\n<code>    context_length=max_context_len,<\/code>\r\n<code>    num_heads=num_heads<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>mha(example_batch)<\/code>\r\n<code>\r\n<\/code><code>print(\"W_key:\", mha.W_key.weight.shape)<\/code>\r\n<code>print(\"W_value:\", mha.W_value.weight.shape)<\/code>\r\n<code>print(\"W_query:\", mha.W_query.weight.shape)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"text\"><code>W_key: torch.Size([4096, 4096])<\/code>\r\n<code>W_value: torch.Size([4096, 4096])<\/code>\r\n<code>W_query: torch.Size([4096, 4096])<\/code><\/pre>\n<\/section>\n<p>Now, if we use grouped-query attention instead, with 8 kv groups (Llama 3 8B uses 8 kv groups), we can see that the number of rows in the key and value matrices is reduced by a factor of 4 (because 32 attention heads divided by 8 kv groups is 4):<\/p>\n<section>\n<pre data-lang=\"python\"><code>gqa = GroupedQueryAttention(<\/code>\r\n<code>    d_in=embed_dim,<\/code>\r\n<code>    d_out=embed_dim,<\/code>\r\n<code>    context_length=max_context_len,<\/code>\r\n<code>    num_heads=num_heads,<\/code>\r\n<code>    num_kv_groups=8,<\/code>\r\n<code>    rope_base=llama_3_theta_base<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>gqa(example_batch)<\/code>\r\n<code>\r\n<\/code><code>print(\"W_key:\", gqa.W_key.weight.shape)<\/code>\r\n<code>print(\"W_value:\", gqa.W_value.weight.shape)<\/code>\r\n<code>print(\"W_query:\", gqa.W_query.weight.shape)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"text\"><code>W_key: torch.Size([1024, 4096])<\/code>\r\n<code>W_value: torch.Size([1024, 4096])<\/code>\r\n<code>W_query: torch.Size([4096, 4096])<\/code><\/pre>\n<\/section>\n<p>As a side note, to make grouped-query attention equivalent to standard multi-head attention, you can set the number of key-value groups (num_kv_groups) equal to the number of heads (num_heads).<\/p>\n<p>Finally, let's compare the number of parameters:<\/p>\n<section>\n<pre data-lang=\"python\"><code>print(\"Total number of parameters:\")<\/code>\r\n<code>\r\n<\/code><code>mha_total_params = sum(p.numel() for p in mha.parameters())<\/code>\r\n<code>print(f\"MHA: {mha_total_params:,}\")<\/code>\r\n<code>\r\n<\/code><code>gqa_total_params = sum(p.numel() for p in gqa.parameters())<\/code>\r\n<code>print(f\"GQA: {gqa_total_params:,}\")<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"text\"><code>Total number of parameters:<\/code>\r\n<code>MHA: 67,108,864<\/code>\r\n<code>GQA: 41,943,040<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"python\"><code># Free up memory:<\/code>\r\n<code>del mha<\/code>\r\n<code>del 
gqa<\/code><\/pre>\n<\/section>\n<p><strong>1.4 Updating the TransformerBlock module<\/strong><\/p>\n<p>Next, we update the transformer block. Here, we simply swap MultiHeadAttention for GroupedQueryAttention and add the new RoPE settings:<\/p>\n<section>\n<pre data-lang=\"python\"><code>class TransformerBlock(nn.Module):<\/code>\r\n<code>    def __init__(self, cfg):<\/code>\r\n<code>        super().__init__()<\/code>\r\n<code>        self.att = GroupedQueryAttention(  # MultiHeadAttention(<\/code>\r\n<code>            d_in=cfg[\"emb_dim\"],<\/code>\r\n<code>            d_out=cfg[\"emb_dim\"],<\/code>\r\n<code>            context_length=cfg[\"context_length\"],<\/code>\r\n<code>            num_heads=cfg[\"n_heads\"],<\/code>\r\n<code>            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW<\/code>\r\n<code>            rope_base=cfg[\"rope_base\"],        # NEW<\/code>\r\n<code>            rope_config=cfg[\"rope_freq\"],      # NEW<\/code>\r\n<code>            dtype=cfg[\"dtype\"]<\/code>\r\n<code>        )<\/code>\r\n<code>        self.ff = FeedForward(cfg)<\/code>\r\n<code>        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)<\/code>\r\n<code>        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)<\/code>\r\n<code>\r\n<\/code><code>    def forward(self, x):<\/code>\r\n<code>        # Shortcut connection for attention block<\/code>\r\n<code>        shortcut = x<\/code>\r\n<code>        x = self.norm1(x)<\/code>\r\n<code>        x = self.att(x.to(torch.bfloat16))  # Shape [batch_size, num_tokens, emb_size]<\/code>\r\n<code>        x = x + shortcut  # Add the original input back<\/code>\r\n<code>\r\n<\/code><code>        # Shortcut connection for feed-forward block<\/code>\r\n<code>        shortcut = x<\/code>\r\n<code>        x = self.norm2(x)<\/code>\r\n<code>        x = self.ff(x.to(torch.bfloat16))<\/code>\r\n<code>        
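# Note: the explicit .to(torch.bfloat16) casts in this forward() assume the weights<\/code>\r\n<code>        # were created with cfg[\"dtype\"] = torch.bfloat16; with float32 weights, the linear<\/code>\r\n<code>        # layers would raise a dtype-mismatch error instead<\/code>\r\n<code>        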
x = x + shortcut  # Add the original input back<\/code>\r\n<code>\r\n<\/code><code>        return x<\/code><\/pre>\n<\/section>\n<p><strong>1.5 Defining the model class<\/strong><\/p>\n<p>Fortunately, there is not much to do when setting up the model class; we only need to update the name to Llama3Model:<\/p>\n<section>\n<pre data-lang=\"python\"><code># class Llama2Model(nn.Module):<\/code>\r\n<code>class Llama3Model(nn.Module):<\/code>\r\n<code>    def __init__(self, cfg):<\/code>\r\n<code>        super().__init__()<\/code>\r\n<code>        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])<\/code>\r\n<code>\r\n<\/code><code>        self.trf_blocks = nn.Sequential(<\/code>\r\n<code>            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])<\/code>\r\n<code>\r\n<\/code><code>        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)<\/code>\r\n<code>        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])<\/code>\r\n<code>\r\n<\/code><code>    def forward(self, in_idx):<\/code>\r\n<code>        batch_size, seq_len = in_idx.shape<\/code>\r\n<code>        tok_embeds = self.tok_emb(in_idx)<\/code>\r\n<code>        x = tok_embeds<\/code>\r\n<code>        x = self.trf_blocks(x)<\/code>\r\n<code>        x = self.final_norm(x)<\/code>\r\n<code>        logits = self.out_head(x.to(torch.bfloat16))<\/code>\r\n<code>        return logits<\/code><\/pre>\n<\/section>\n<p><strong>2 Initializing the model<\/strong><\/p>\n<p>Now we can define a Llama 3 configuration (the Llama 2 configuration is shown for comparison):<\/p>\n<section>\n<pre data-lang=\"python\"><code>LLAMA2_CONFIG_7B = 
{<\/code>\r\n<code>    \"vocab_size\": 32_000,    # Vocabulary size<\/code>\r\n<code>    \"context_length\": 4096,  # Context length<\/code>\r\n<code>    \"emb_dim\": 4096,         # Embedding dimension<\/code>\r\n<code>    \"n_heads\": 32,           # Number of attention heads<\/code>\r\n<code>    \"n_layers\": 32,          # Number of layers<\/code>\r\n<code>    \"hidden_dim\": 11_008,    # Size of the intermediate dimension in FeedForward<\/code>\r\n<code>    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory<\/code>\r\n<code>}<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"python\"><code>LLAMA3_CONFIG_8B = {<\/code>\r\n<code>    \"vocab_size\": 128_256,   # NEW: Larger vocabulary size<\/code>\r\n<code>    \"context_length\": 8192,  # NEW: Larger context length<\/code>\r\n<code>    \"emb_dim\": 4096,         # Embedding dimension<\/code>\r\n<code>    \"n_heads\": 32,           # Number of attention heads<\/code>\r\n<code>    \"n_layers\": 32,          # Number of layers<\/code>\r\n<code>    \"hidden_dim\": 14_336,    # NEW: Larger size of the intermediate dimension in FeedForward<\/code>\r\n<code>    \"n_kv_groups\": 8,        # NEW: Key-Value groups for grouped-query attention<\/code>\r\n<code>    \"rope_base\": 500_000,    # NEW: The base in RoPE's \"theta\" was increased to 500_000<\/code>\r\n<code>    \"rope_freq\": None,       # NEW: Additional configuration for adjusting the RoPE frequencies<\/code>\r\n<code>    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory<\/code>\r\n<code>}<\/code><\/pre>\n<\/section>\n<p>With these settings, we can now initialize the Llama 3 8B model.<\/p>\n<p>Note that this requires about 34 GB of memory (for comparison, Llama 2 7B requires about 26 GB).<\/p>\n<section>\n<pre data-lang=\"python\"><code>model = Llama3Model(LLAMA3_CONFIG_8B)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"python\"><code>total_params = 
sum(p.numel() for p in model.parameters())<\/code>\r\n<code>print(f\"Total number of parameters: {total_params:,}\")<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"text\"><code>Total number of parameters: 8,030,261,248<\/code><\/pre>\n<\/section>\n<p>As the output above shows, the model contains about 8 billion parameters. We can also compute the model's memory requirements with the following code:<\/p>\n<section>\n<pre data-lang=\"python\"><code>def model_memory_size(model, input_dtype=torch.float32):<\/code>\r\n<code>    total_params = 0<\/code>\r\n<code>    total_grads = 0<\/code>\r\n<code>    for param in model.parameters():<\/code>\r\n<code>        # Calculate total number of elements per parameter<\/code>\r\n<code>        param_size = param.numel()<\/code>\r\n<code>        total_params += param_size<\/code>\r\n<code>        # Check if gradients are stored for this parameter<\/code>\r\n<code>        if param.requires_grad:<\/code>\r\n<code>            total_grads += param_size<\/code>\r\n<code>\r\n<\/code><code>    # Calculate buffer size (non-parameters that require memory)<\/code>\r\n<code>    total_buffers = sum(buf.numel() for buf in model.buffers())<\/code>\r\n<code>\r\n<\/code><code>    # Size in bytes = (Number of elements) * (Size of each element in bytes)<\/code>\r\n<code>    # We assume parameters and gradients are stored in the same type as input dtype<\/code>\r\n<code>    element_size = torch.tensor(0, dtype=input_dtype).element_size()<\/code>\r\n<code>    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size<\/code>\r\n<code>\r\n<\/code><code>    # Convert bytes to gigabytes<\/code>\r\n<code>    total_memory_gb = total_memory_bytes \/ (1024**3)<\/code>\r\n<code>\r\n<\/code><code>    return total_memory_gb<\/code>\r\n<code>\r\n<\/code><code>print(f\"float32 (PyTorch 
default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")<\/code>\r\n<code>print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")<\/code>\r\n<code>float32 (PyTorch default): 68.08 GB<\/code>\r\n<code>bfloat16: 34.04 GB<\/code><\/pre>\n<\/section>\n<p>Finally, if available, we can also transfer the model to an NVIDIA or Apple Silicon GPU:<\/p>\n<section>\n<pre data-lang=\"bash\"><code>if torch.cuda.is_available():<\/code>\r\n<code>    device = torch.device(\"cuda\")<\/code>\r\n<code>elif torch.backends.mps.is_available():<\/code>\r\n<code>    device = torch.device(\"mps\")<\/code>\r\n<code>else:<\/code>\r\n<code>    device = torch.device(\"cpu\")<\/code>\r\n<code>\r\n<\/code><code>model.to(device);<\/code><\/pre>\n<\/section>\n<p><strong>3 Loading the tokenizer<\/strong><\/p>\n<p>In this section, we load the tokenizer for the model.<\/p>\n<p>Llama 2 used Google's SentencePiece tokenizer instead of OpenAI's BPE tokenizer based on the Tiktoken library. Llama 3, however, switched to Tiktoken's BPE tokenizer; specifically, it uses the GPT-4 tokenizer with an extended vocabulary. The original Tiktoken adaptation can be found in Meta AI's official Llama 3 repository.<\/p>\n<p>Below, the tokenizer code is rewritten to be more readable and better suited to this notebook (its behavior should be similar):<\/p>\n<section>\n<pre data-lang=\"python\"><code>import os<\/code>\r\n<code>from pathlib import Path<\/code>\r\n<code>\r\n<\/code><code>import tiktoken<\/code>\r\n<code>from tiktoken.load import 
load_tiktoken_bpe<\/code><code>\r\n<\/code><code>\r\n<\/code><code>class Tokenizer:<\/code>\r\n<code>    def __init__(self, model_path):<\/code>\r\n<code>        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"<\/code>\r\n<code>        mergeable_ranks = load_tiktoken_bpe(model_path)<\/code>\r\n<code>        num_base_tokens = len(mergeable_ranks)<\/code>\r\n<code>\r\n<\/code><code>        self.special_tokens = {<\/code>\r\n<code>            \"&lt;|begin_of_text|&gt;\": 128000,<\/code>\r\n<code>            \"&lt;|end_of_text|&gt;\": 128001,<\/code>\r\n<code>            \"&lt;|start_header_id|&gt;\": 128006,<\/code>\r\n<code>            \"&lt;|end_header_id|&gt;\": 128007,<\/code>\r\n<code>            \"&lt;|eot_id|&gt;\": 128009,<\/code>\r\n<code>        }<\/code>\r\n<code>        self.special_tokens.update({<\/code>\r\n<code>            f\"&lt;|reserved_{i}|&gt;\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()<\/code>\r\n<code>        })<\/code>\r\n<code>\r\n<\/code><code>        self.model = tiktoken.Encoding(<\/code>\r\n<code>            name=Path(model_path).name,<\/code>\r\n<code>            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",<\/code>\r\n<code>            mergeable_ranks=mergeable_ranks,<\/code>\r\n<code>            special_tokens=self.special_tokens<\/code>\r\n<code>        )<\/code><code>\r\n<\/code><code>\r\n<\/code><code>    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):<\/code>\r\n<code>        if bos:<\/code>\r\n<code>            tokens = [self.special_tokens[\"&lt;|begin_of_text|&gt;\"]]<\/code>\r\n<code>        else:<\/code>\r\n<code>            tokens = []<\/code>\r\n<code>\r\n<\/code><code>        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)<\/code>\r\n<code>\r\n<\/code><code>        if eos:<\/code>\r\n<code>    
        tokens.append(self.special_tokens[\"&lt;|end_of_text|&gt;\"])<\/code>\r\n<code>        return tokens<\/code>\r\n<code>\r\n<\/code><code>    def decode(self, tokens):<\/code>\r\n<code>        return self.model.decode(tokens)<\/code><\/pre>\n<\/section>\n<p>Meta AI shares the original Llama 3 model weights and the tokenizer vocabulary on the Hugging Face Hub.<\/p>\n<p>We first download the tokenizer vocabulary from the Hub and load it with the code above. Note that Meta AI requires you to accept the Llama 3 license terms before downloading the files; to do so, you need to create a Hugging Face Hub account and visit the meta-llama\/Meta-Llama-3-8B repository to accept the terms.<\/p>\n<p>Next, you need to create an access token. To generate an access token with \"Read\" permission, click your profile picture in the upper-right corner and then click \"Settings\".<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-7f7841e4c2b70d258c50c1d9f7f81ad2.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-7f7841e4c2b70d258c50c1d9f7f81ad2.png\"><\/a><\/p>\n<p>Then create an access token and copy it so that you can paste it into the next code cell:<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-8ecc6831d5e0eebaca979315ef13cb7d.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" 
src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-8ecc6831d5e0eebaca979315ef13cb7d.png\"><\/a><\/p>\n<section>\n<pre data-lang=\"swift\"><code>from huggingface_hub import login<\/code>\r\n<code>import json<\/code>\r\n<code>\r\n<\/code><code>with open(\"config.json\", \"r\") as config_file:<\/code>\r\n<code>    config = json.load(config_file)<\/code>\r\n<code>    access_token = config[\"HF_ACCESS_TOKEN\"]<\/code>\r\n<code>\r\n<\/code><code>login(token=access_token)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"cs\"><code>The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.<\/code>\r\n<code>Token is valid (permission: read).<\/code>\r\n<code>Your token has been saved to \/root\/.cache\/huggingface\/token<\/code>\r\n<code>Login successful<\/code><\/pre>\n<\/section>\n<p>After logging in with the access token (which is required to verify that we have accepted the Llama 3 license terms), we can download the tokenizer vocabulary:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>from huggingface_hub import hf_hub_download<\/code>\r\n<code>\r\n<\/code><code>tokenizer_file_path = hf_hub_download(<\/code>\r\n<code>    repo_id=\"meta-llama\/Meta-Llama-3-8B\",<\/code>\r\n<code>    filename=\"original\/tokenizer.model\",<\/code>\r\n<code>    local_dir=\"llama3-files\"<\/code>\r\n<code>)<\/code><\/pre>\n<\/section>\n<p>Note that when working with the Llama 3 files, we may need the blobfile package, which is used for handling datasets or models stored in cloud storage solutions such as Google Cloud Storage (GCS), Azure Blob Storage, or Amazon S3.<\/p>\n<p>This dependency can be installed by uncommenting and running the pip command below:<\/p>\n<section>\n<pre data-lang=\"apache\"><code># pip install blobfile<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"ini\"><code>tokenizer = Tokenizer(tokenizer_file_path)<\/code><\/pre>\n<\/section>\n<p>Now we can use the generate function to have the Llama 3 model generate new text:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>from previous_chapters import generate, text_to_token_ids, token_ids_to_text<\/code>\r\n<code>\r\n<\/code><code>torch.manual_seed(123)<\/code>\r\n<code>\r\n<\/code><code>token_ids = generate(<\/code>\r\n<code>    model=model,<\/code>\r\n<code>    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),<\/code>\r\n<code>    max_new_tokens=30,<\/code>\r\n<code>    context_size=LLAMA3_CONFIG_8B[\"context_length\"],<\/code>\r\n<code>    top_k=1,<\/code>\r\n<code>    temperature=0.<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"properties\"><code>Output text:<\/code>\r\n<code> Every effort_dead aeros Ingredients ba\u015f\u0131nda.extension clangmissions.esp \uc0ac\uc9c4 Ek Pars til DoctorsDao\u0435\u043d\u044costivan normal Ekized \ufffd Ekized \ufffd Ek rdr t\u0131k%,orgen&gt;',<\/code><\/pre>\n<\/section>\n<p>Of course, as we can see above, this text is meaningless because we have not trained the Llama 3 model yet. In the next section, we load pretrained weights from Meta AI instead of training the model ourselves, which would cost tens to hundreds of thousands of dollars.<\/p>\n<p><strong>4 Loading pretrained weights<\/strong><\/p>\n<p>Below, we load the \"meta-llama\/Meta-Llama-3-8B\" base model, which is the plain text-completion model prior to any finetuning.<\/p>\n<p>Alternatively, you can load the instruction-finetuned and aligned \"meta-llama\/Meta-Llama-3-8B-Instruct\" model by changing the string in the next code cell accordingly. Combined, the weight files are about 16 GB in size.<\/p>\n<section>\n<pre data-lang=\"ruby\"><code>from safetensors.torch import load_file<\/code>\r\n<code>\r\n<\/code><code>combined_weights = {}<\/code>\r\n<code>\r\n<\/code><code>for i in range(1, 5):<\/code>\r\n<code>    weights_file = hf_hub_download(<\/code>\r\n<code>        repo_id=\"meta-llama\/Meta-Llama-3-8B\",<\/code>\r\n<code>        filename=f\"model-0000{i}-of-00004.safetensors\",<\/code>\r\n<code>        local_dir=\"llama3-files\"<\/code>\r\n<code>    )<\/code>\r\n<code>    current_weights = load_file(weights_file)<\/code>\r\n<code>    combined_weights.update(current_weights)<\/code>\r\n<code>model-00001-of-00004.safetensors:   0%|          | 0.00\/4.98G [00:00, ?B\/s]<\/code>\r\n<code>model-00002-of-00004.safetensors:   0%|          | 0.00\/5.00G [00:00, ?B\/s]<\/code>\r\n<code>model-00003-of-00004.safetensors:   0%|          | 0.00\/4.92G [00:00, ?B\/s]<\/code>\r\n<code>model-00004-of-00004.safetensors:   0%|          | 0.00\/1.17G [00:00, 
?B\/s]<\/code><\/pre>\n<\/section>\n<p>The weights contain the following tensors (for brevity, only the first 15 are shown):<\/p>\n<section>\n<pre data-lang=\"css\"><code>list(combined_weights.keys())[:15]<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"json\"><code>['model.embed_tokens.weight',<\/code>\r\n<code> 'model.layers.0.input_layernorm.weight',<\/code>\r\n<code> 'model.layers.0.mlp.down_proj.weight',<\/code>\r\n<code> 'model.layers.0.mlp.gate_proj.weight',<\/code>\r\n<code> 'model.layers.0.mlp.up_proj.weight',<\/code>\r\n<code> 'model.layers.0.post_attention_layernorm.weight',<\/code>\r\n<code> 'model.layers.0.self_attn.k_proj.weight',<\/code>\r\n<code> 'model.layers.0.self_attn.o_proj.weight',<\/code>\r\n<code> 'model.layers.0.self_attn.q_proj.weight',<\/code>\r\n<code> 'model.layers.0.self_attn.v_proj.weight',<\/code>\r\n<code> 'model.layers.1.input_layernorm.weight',<\/code>\r\n<code> 'model.layers.1.mlp.down_proj.weight',<\/code>\r\n<code> 'model.layers.1.mlp.gate_proj.weight',<\/code>\r\n<code> 'model.layers.1.mlp.up_proj.weight',<\/code>\r\n<code> 'model.layers.1.post_attention_layernorm.weight']<\/code><\/pre>\n<\/section>\n<p>The following function, modeled after the load_weights_into_gpt function in Chapter 5 of Build a Large Language Model From Scratch (https:\/\/github.com\/rasbt\/LLMs-from-scratch\/blob\/0972ded5309c25dc5eecc98b62897d677c6c36c4\/ch05\/01_main-chapter-code\/ch05.ipynb), loads the pretrained weights into the Llama 3 model:<\/p>\n<section>\n<pre data-lang=\"python\"><code>def assign(left, right, tensor_name=\"unknown\"):<\/code>\r\n<code>    if left.shape != right.shape:<\/code>\r\n<code>     
   raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")<\/code>\r\n<code>\r\n<\/code><code>    if isinstance(right, torch.Tensor):<\/code>\r\n<code>        return torch.nn.Parameter(right.clone().detach())<\/code>\r\n<code>    else:<\/code>\r\n<code>        return torch.nn.Parameter(torch.tensor(right))<\/code><code>\r\n<\/code><code>\r\n<\/code><code>def load_weights_into_llama(model, param_config, params):<\/code>\r\n<code>    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")<\/code>\r\n<code>\r\n<\/code><code>    for l in range(param_config[\"n_layers\"]):<\/code>\r\n<code>\r\n<\/code><code>        # Load attention weights<\/code>\r\n<code>        model.trf_blocks[l].att.W_query.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].att.W_query.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.self_attn.q_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].att.W_key.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].att.W_key.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.self_attn.k_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].att.W_value.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].att.W_value.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.self_attn.v_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].att.out_proj.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].att.out_proj.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],<\/code>\r\n<code>            
f\"model.layers.{l}.self_attn.o_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].norm1.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].norm1.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.input_layernorm.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.input_layernorm.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>\r\n<\/code><code>        # Load FeedForward weights<\/code>\r\n<code>        model.trf_blocks[l].ff.fc1.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].ff.fc1.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.mlp.gate_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].ff.fc2.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].ff.fc2.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.mlp.up_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.mlp.up_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].ff.fc3.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].ff.fc3.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.mlp.down_proj.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.mlp.down_proj.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>        model.trf_blocks[l].norm2.weight = assign(<\/code>\r\n<code>            model.trf_blocks[l].norm2.weight,<\/code>\r\n<code>            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],<\/code>\r\n<code>            f\"model.layers.{l}.post_attention_layernorm.weight\"<\/code>\r\n<code>        )<\/code>\r\n<code>\r\n<\/code><code>    # Load output layer weights<\/code>\r\n<code>    model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")<\/code>\r\n<code>\r\n<\/code><code>    if \"lm_head.weight\" in 
params.keys():<\/code>\r\n<code>        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")<\/code>\r\n<code>    else:<\/code>\r\n<code>        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")<\/code>\r\n<code>        print(\"Model uses weight tying.\")<\/code><code>\r\n<\/code><code>\r\n<\/code><code>load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)<\/code>\r\n<code>model.to(device);<\/code>\r\n<code>del combined_weights  # free up memory<\/code><\/pre>\n<\/section>\n<p>Next, we can use the model to generate text:<\/p>\n<section>\n<pre data-lang=\"properties\"><code>torch.manual_seed(123)<\/code>\r\n<code>\r\n<\/code><code>token_ids = generate(<\/code>\r\n<code>    model=model,<\/code>\r\n<code>    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),<\/code>\r\n<code>    max_new_tokens=25,<\/code>\r\n<code>    context_size=LLAMA3_CONFIG_8B[\"context_length\"],<\/code>\r\n<code>    top_k=1,<\/code>\r\n<code>    temperature=0.<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))<\/code>\r\n<code>Output text:<\/code>\r\n<code> Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. 
The publisher apologizes for any<\/code><\/pre>\n<\/section>\n<p><strong>5 Using the instruction-finetuned model<\/strong><\/p>\n<p>Above, we used the pretrained base model. If you want a model that can follow instructions, use the \"meta-llama\/Meta-Llama-3-8B-Instruct\" model instead, as shown below:<\/p>\n<section>\n<pre data-lang=\"css\"><code># to free up memory<\/code>\r\n<code>import gc<\/code>\r\n<code>\r\n<\/code><code>del model<\/code>\r\n<code>gc.collect()  # Run Python garbage collector<\/code>\r\n<code>\r\n<\/code><code>if torch.cuda.is_available():<\/code>\r\n<code>    torch.cuda.empty_cache()<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"makefile\"><code>combined_weights = {}<\/code>\r\n<code>\r\n<\/code><code>for i in range(1, 5):<\/code>\r\n<code>    weights_file = hf_hub_download(<\/code>\r\n<code>        repo_id=\"meta-llama\/Meta-Llama-3-8B-Instruct\",<\/code>\r\n<code>        filename=f\"model-0000{i}-of-00004.safetensors\",<\/code>\r\n<code>        local_dir=\"llama3-files\"<\/code>\r\n<code>    )<\/code>\r\n<code>    current_weights = load_file(weights_file)<\/code>\r\n<code>    combined_weights.update(current_weights)<\/code>\r\n<code>\r\n<\/code><code>model = Llama3Model(LLAMA3_CONFIG_8B)<\/code>\r\n<code>load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)<\/code>\r\n<code>model.to(device);<\/code>\r\n<code>del combined_weights  # free up memory<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"ruby\"><code>model-00001-of-00004.safetensors:   0%|          | 0.00\/4.98G [00:00, ?B\/s]<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"ruby\"><code>model-00002-of-00004.safetensors:   0%|          | 0.00\/5.00G [00:00, ?B\/s]<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"ruby\"><code>model-00003-of-00004.safetensors:   0%|          | 0.00\/4.92G [00:00, ?B\/s]<\/code><\/pre>\n<\/section>\n<section>\n<pre 
data-lang=\"ruby\"><code>model-00004-of-00004.safetensors:   0%|          | 0.00\/1.17G [00:00, ?B\/s]<\/code><\/pre>\n<\/section>\n<p>Note that Llama 3 models work best when used with the same prompt template that was used during finetuning.<\/p>\n<p>Below is a tokenizer wrapper class, based on Meta AI's Llama 3-specific ChatFormat code, that constructs the prompt template:<\/p>\n<section>\n<pre data-lang=\"ruby\"><code>class ChatFormat:<\/code>\r\n<code>    def __init__(self, tokenizer):<\/code>\r\n<code>        self.tokenizer = tokenizer<\/code>\r\n<code>\r\n<\/code><code>    def encode_header(self, message):<\/code>\r\n<code>        tokens = []<\/code>\r\n<code>        tokens.append(self.tokenizer.special_tokens[\"&lt;|start_header_id|&gt;\"])<\/code>\r\n<code>        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))<\/code>\r\n<code>        tokens.append(self.tokenizer.special_tokens[\"&lt;|end_header_id|&gt;\"])<\/code>\r\n<code>        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))<\/code>\r\n<code>        return tokens<\/code>\r\n<code>\r\n<\/code><code>    def encode(self, text):<\/code>\r\n<code>        message = {<\/code>\r\n<code>            \"role\": \"user\",<\/code>\r\n<code>            \"content\": text<\/code>\r\n<code>        }<\/code>\r\n<code>\r\n<\/code><code>        tokens = self.encode_header(message)<\/code>\r\n<code>        tokens.extend(<\/code>\r\n<code>            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)<\/code>\r\n<code>        )<\/code>\r\n<code>        tokens.append(self.tokenizer.special_tokens[\"&lt;|eot_id|&gt;\"])<\/code>\r\n<code>        return tokens<\/code>\r\n<code>\r\n<\/code><code>    def decode(self, token_ids):<\/code>\r\n<code>        return 
self.tokenizer.decode(token_ids)<\/code><code>\r\n<\/code><code>\r\n<\/code><code>chat_tokenizer = ChatFormat(tokenizer)<\/code><\/pre>\n<\/section>\n<p>The usage is as follows:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>token_ids = chat_tokenizer.encode(\"Hello World!\")<\/code>\r\n<code>print(token_ids)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"json\"><code>[128006, 882, 128007, 271, 9906, 4435, 0, 128009]<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"css\"><code>tokenizer.decode(token_ids)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"xml\"><code>'&lt;|start_header_id|&gt;user&lt;|end_header_id|&gt;\\n\\nHello World!&lt;|eot_id|&gt;'<\/code><\/pre>\n<\/section>\n<p>Now, let's see the Llama 3 instruct model in action:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>import re<\/code>\r\n<code>\r\n<\/code><code>torch.manual_seed(123)<\/code>\r\n<code>\r\n<\/code><code>token_ids = generate(<\/code>\r\n<code>    model=model,<\/code>\r\n<code>    idx=text_to_token_ids(\"What do llamas eat?\", chat_tokenizer).to(device),<\/code>\r\n<code>    max_new_tokens=150,<\/code>\r\n<code>    context_size=LLAMA3_CONFIG_8B[\"context_length\"],<\/code>\r\n<code>    top_k=1,<\/code>\r\n<code>    temperature=0.<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>output_text = token_ids_to_text(token_ids, tokenizer)<\/code>\r\n<code>\r\n<\/code><code>\r\n<\/code><code>def clean_text(text, header_end=\"assistant&lt;|end_header_id|&gt;\\n\\n\"):<\/code>\r\n<code>    # Find the index of the first occurrence of the assistant header<\/code>\r\n<code>    index = text.find(header_end)<\/code>\r\n<code>\r\n<\/code><code>    if index != -1:<\/code>\r\n<code>        # Return the substring starting after the assistant header<\/code>\r\n<code>        return text[index + len(header_end):].strip()  # Strip removes leading\/trailing whitespace<\/code>\r\n<code>    else:<\/code>\r\n<code>        # If the header is not found, return the original text<\/code>\r\n<code>        return 
text<\/code>\r\n<code>\r\n<\/code><code>\r\n<\/code><code>print(\"Output text:\\n\", clean_text(output_text))<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"cpp\"><code>Output text:<\/code>\r\n<code> Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:<\/code>\r\n<code>\r\n<\/code><code>1. Grass: Llamas love to graze on grass, especially in the spring and summer months.<\/code>\r\n<code>2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.<\/code>\r\n<code>3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10% of a llama's diet.<\/code>\r\n<code>4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as apples,<\/code><\/pre>\n<\/section>\n<p><strong>6 Llama 3.1 8B<\/strong><\/p>\n<p>A few months after the initial Llama 3 release, Meta AI followed up with the Llama 3.1 suite of models (see the official Llama 3.1 announcement for details).<\/p>\n<p>Conveniently, we can reuse the previous Llama 3 code to implement Llama 3.1 8B:<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3daa97087b3fe77c7184e0ef0ccd974e.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3daa97087b3fe77c7184e0ef0ccd974e.png\"><\/a><\/p>\n<p>The architecture is identical; the only change is a rescaling of the RoPE frequencies, as shown in the configuration below:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>LLAMA3_CONFIG_8B = {<\/code>\r\n<code>\"vocab_size\": 128_256,   # Vocabulary size<\/code>\r\n<code>\"context_length\": 8192,  # Context length<\/code>\r\n<code>\"emb_dim\": 4096,         # 
Embedding dimension<\/code>\r\n<code>\"n_heads\": 32,           # Number of attention heads<\/code>\r\n<code>\"n_layers\": 32,          # Number of layers<\/code>\r\n<code>\"hidden_dim\": 14_336,    # Size of the intermediate dimension in FeedForward<\/code>\r\n<code>\"n_kv_groups\": 8,        # Key-Value groups for grouped-query attention<\/code>\r\n<code>\"rope_base\": 500_000,    # The base in RoPE's \"theta\"<\/code>\r\n<code>\"rope_freq\": None,       # Additional configuration for adjusting the RoPE frequencies<\/code>\r\n<code>\"dtype\": torch.bfloat16  # Lower-precision dtype to save memory<\/code>\r\n<code>}<\/code>\r\n<code>\r\n<\/code><code>LLAMA31_CONFIG_8B = {<\/code>\r\n<code>\"vocab_size\": 128_256,    # Vocabulary size<\/code>\r\n<code>\"context_length\": 8192,   # Context length<\/code>\r\n<code>\"emb_dim\": 4096,          # Embedding dimension<\/code>\r\n<code>\"n_heads\": 32,            # Number of attention heads<\/code>\r\n<code>\"n_layers\": 32,           # Number of layers<\/code>\r\n<code>\"hidden_dim\": 14_336,     # Size of the intermediate dimension in FeedForward<\/code>\r\n<code>\"n_kv_groups\": 8,         # Key-Value groups for grouped-query attention<\/code>\r\n<code>\"rope_base\": 500_000,     # The base in RoPE's \"theta\"<\/code>\r\n<code>\"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory<\/code>\r\n<code>\"rope_freq\": {            # NEW: RoPE frequency scaling<\/code>\r\n<code>\"factor\": 8.0,<\/code>\r\n<code>\"low_freq_factor\": 1.0,<\/code>\r\n<code>\"high_freq_factor\": 4.0,<\/code>\r\n<code>\"original_context_length\": 8192,<\/code>\r\n<code>}<\/code>\r\n<code>}<\/code><\/pre>\n<\/section>\n<p>As we saw in the code earlier, RoPE uses sinusoidal functions (sine and cosine) to embed positional information directly into the attention mechanism.<\/p>\n<p>In Llama 3.1, the additional configuration applies an extra adjustment to the inverse-frequency calculation. These adjustments change how much the different frequency components contribute to the positional embeddings.<\/p>\n<p>Let's try the Llama 3.1 model in practice. First, we delete the old model to free up some GPU memory:<\/p>\n<section>\n<pre data-lang=\"css\"><code># free up memory<\/code>\r\n<code>del model<\/code>\r\n<code>\r\n<\/code><code>gc.collect()  # Run Python garbage collector<\/code>\r\n<code>\r\n<\/code><code>if torch.cuda.is_available():<\/code>\r\n<code>    torch.cuda.empty_cache()<\/code><\/pre>\n<\/section>\n<p>Next, we download the tokenizer.<\/p>\n<p>Note that because the Llama 3.1 family is distinct from the Llama 3 family, you must visit the meta-llama\/Llama-3.1-8B repository and acknowledge the license terms for your Hugging Face access token to work for the download.<\/p>\n<p>For simplicity, we only load the base model below, but there is also an instruction-finetuned version, which you can use by replacing \"meta-llama\/Llama-3.1-8B\" with \"meta-llama\/Llama-3.1-8B-Instruct\".<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>tokenizer_file_path = hf_hub_download(<\/code>\r\n<code>    repo_id=\"meta-llama\/Llama-3.1-8B\",<\/code>\r\n<code>    filename=\"original\/tokenizer.model\",<\/code>\r\n<code>    local_dir=\"llama3-files\"<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>tokenizer = Tokenizer(tokenizer_file_path)<\/code>\r\n<code>model = 
Llama3Model(LLAMA31_CONFIG_8B)<\/code>\r\n<code>\r\n<\/code><code>total_params = sum(p.numel() for p in model.parameters())<\/code>\r\n<code>print(f\"Total number of parameters: {total_params:,}\")<\/code>\r\n<code>Total number of parameters: 8,030,261,248<\/code>\r\n<code>combined_weights = {}<\/code>\r\n<code>\r\n<\/code><code>for i in range(1, 5):<\/code>\r\n<code>    weights_file = hf_hub_download(<\/code>\r\n<code>        repo_id=\"meta-llama\/Llama-3.1-8B\",<\/code>\r\n<code>        filename=f\"model-0000{i}-of-00004.safetensors\",<\/code>\r\n<code>        local_dir=\"llama3-files\"<\/code>\r\n<code>    )<\/code>\r\n<code>    current_weights = load_file(weights_file)<\/code>\r\n<code>    combined_weights.update(current_weights)<\/code>\r\n<code>\r\n<\/code><code>load_weights_into_llama(model, LLAMA31_CONFIG_8B, combined_weights)<\/code>\r\n<code>model.to(device);<\/code>\r\n<code>del combined_weights  # free up memory<\/code>\r\n<code>model-00001-of-00004.safetensors:   0%|          | 0.00\/4.98G [00:00, ?B\/s]<\/code>\r\n<code>model-00002-of-00004.safetensors:   0%|          | 0.00\/5.00G [00:00, ?B\/s]<\/code>\r\n<code>model-00003-of-00004.safetensors:   0%|          | 0.00\/4.92G [00:00, ?B\/s]<\/code>\r\n<code>model-00004-of-00004.safetensors:   0%|          | 0.00\/1.17G [00:00, ?B\/s]<\/code>\r\n<code>torch.manual_seed(123)<\/code>\r\n<code>\r\n<\/code><code>token_ids = generate(<\/code>\r\n<code>    model=model,<\/code>\r\n<code>    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),<\/code>\r\n<code>    max_new_tokens=25,<\/code>\r\n<code>    context_size=LLAMA31_CONFIG_8B[\"context_length\"],<\/code>\r\n<code>    top_k=1,<\/code>\r\n<code>    temperature=0.<\/code>\r\n<code>)<\/code>\r\n<code>\r\n<\/code><code>print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))<\/code>\r\n<code>Output text:<\/code>\r\n<code> Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. 
The publisher apologizes for any<\/code><\/pre>\n<\/section>\n<p><strong>Llama 3.2 1B<\/strong><\/p>\n<p>At the time of writing, Meta AI's latest models are the newly released Llama 3.2 models.<\/p>\n<p>The code for the Llama 3.2 text models is similar to that of Llama 3.1. The difference is that the models have been scaled down, with 1B and 3B versions available.<\/p>\n<p>Another efficiency tweak is that they brought back weight tying, a concept originally used in the GPT-2 architecture. Here, the same weight parameter values are reused in the input (token) embedding layer and the output layer.<\/p>\n<p>Llama 3.2 1B is small enough to run even on many mobile devices, which makes it very convenient.<\/p>\n<p>The figure below shows the architectural differences between Llama 3.1 8B and Llama 3.2 1B:<\/p>\n<p><a href=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-5d679b0d44acc7c460ef7f85f6b67336.png\" data-fancybox=\"images\" data-fancybox=\"gallery\"><img decoding=\"async\" src=\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-5d679b0d44acc7c460ef7f85f6b67336.png\"><\/a><\/p>\n<p>As the figure shows, the main difference between the Llama 3.1 8B and Llama 3.2 1B architectures is their size.<\/p>\n<p>One small additional change is a larger RoPE rescaling factor. This is reflected in the configuration below:<\/p>\n<section>\n<pre data-lang=\"makefile\"><code>LLAMA31_CONFIG_8B = {<\/code>\r\n<code>    \"vocab_size\": 
128_256,      # Vocabulary size<\/code>\r\n<code>    \"context_length\": 8192,     # Context length<\/code>\r\n<code>    \"emb_dim\": 4096,            # Embedding dimension<\/code>\r\n<code>    \"n_heads\": 32,              # Number of attention heads<\/code>\r\n<code>    \"n_layers\": 32,             # Number of layers<\/code>\r\n<code>    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward<\/code>\r\n<code>    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention<\/code>\r\n<code>    \"rope_base\": 500_000,       # The base in RoPE's \"theta\"<\/code>\r\n<code>    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory<\/code>\r\n<code>    \"rope_freq\": {              # RoPE frequency scaling<\/code>\r\n<code>        \"factor\": 8.0,<\/code>\r\n<code>        \"low_freq_factor\": 1.0,<\/code>\r\n<code>        \"high_freq_factor\": 4.0,<\/code>\r\n<code>        \"original_context_length\": 8192,<\/code>\r\n<code>    }<\/code>\r\n<code>}<\/code>\r\n<code>\r\n<\/code><code>LLAMA32_CONFIG_1B = {<\/code>\r\n<code>    \"vocab_size\": 128_256,      # Vocabulary size<\/code>\r\n<code>    \"context_length\": 8192,     # Context length<\/code>\r\n<code>    \"emb_dim\": 2048,            # NEW: Half the embedding dimension<\/code>\r\n<code>    \"n_heads\": 32,              # Number of attention heads<\/code>\r\n<code>    \"n_layers\": 16,             # NEW: Half the number of layers<\/code>\r\n<code>    \"hidden_dim\": 8192,         # NEW: Almost half the size of the intermediate dimension in FeedForward<\/code>\r\n<code>    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention<\/code>\r\n<code>    \"rope_base\": 500_000,       # The base in RoPE's \"theta\"<\/code>\r\n<code>    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory<\/code>\r\n<code>    \"rope_freq\": {              # RoPE frequency scaling<\/code>\r\n<code>        \"factor\": 32.0,         # NEW: Adjustment of the rescaling factor<\/code>\r\n<code>        \"low_freq_factor\": 1.0,<\/code>\r\n<code>        \"high_freq_factor\": 4.0,<\/code>\r\n<code>        \"original_context_length\": 
8192,<\/code>\r\n<code>    }<\/code>\r\n<code>}<\/code><\/pre>\n<\/section>\n<p>Next, we can reuse the code from the Llama 3.1 8B section to load the Llama 3.2 1B model.<\/p>\n<p>Again, the Llama 3.2 family is distinct from the Llama 3.1 family. You must therefore visit the meta-llama\/Llama-3.2-1B repository and accept the license terms.<\/p>\n<p>For simplicity, we only load the base model below. An instruction-finetuned version is also available. To use it, replace \"meta-llama\/Llama-3.2-1B\" with \"meta-llama\/Llama-3.2-1B-Instruct\".<\/p>\n<section>\n<pre data-lang=\"css\"><code># free up memory<\/code>\r\n<code>del model<\/code>\r\n<code>gc.collect()  # Run Python garbage collector<\/code>\r\n<code>if torch.cuda.is_available():<\/code>\r\n<code>    torch.cuda.empty_cache()<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"makefile\"><code>tokenizer_file_path = hf_hub_download(<\/code>\r\n<code>    repo_id=\"meta-llama\/Llama-3.2-1B\",<\/code>\r\n<code>    filename=\"original\/tokenizer.model\",<\/code>\r\n<code>    local_dir=\"llama32-files\"<\/code>\r\n<code>)<\/code>\r\n<code>tokenizer = Tokenizer(tokenizer_file_path)<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"makefile\"><code>model = Llama3Model(LLAMA32_CONFIG_1B)<\/code>\r\n<code>total_params = sum(p.numel() for p in model.parameters())<\/code>\r\n<code>print(f\"Total number of parameters: {total_params:,}\")<\/code>\r\n<code># Account for weight tying<\/code>\r\n<code>total_params_normalized = total_params - model.tok_emb.weight.numel()<\/code>\r\n<code>print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"javascript\"><code>Total 
number of parameters: 1,498,482,688<\/code>\r\n<code>\r\n<\/code><code>Total number of unique parameters: 1,235,814,400<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"makefile\"><code>weights_file = hf_hub_download(<\/code>\r\n<code>    repo_id=\"meta-llama\/Llama-3.2-1B\",<\/code>\r\n<code>    filename=\"model.safetensors\",<\/code>\r\n<code>    local_dir=\"llama32-files\"<\/code>\r\n<code>)<\/code>\r\n<code>current_weights = load_file(weights_file)<\/code>\r\n<code>load_weights_into_llama(model, LLAMA32_CONFIG_1B, current_weights)<\/code>\r\n<code>model.to(device);<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"nginx\"><code>Model uses weight tying.<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"css\"><code>print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"nginx\"><code>Weight tying: True<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"makefile\"><code>torch.manual_seed(123)<\/code>\r\n<code>token_ids = generate(<\/code>\r\n<code>    model=model,<\/code>\r\n<code>    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),<\/code>\r\n<code>    max_new_tokens=25,<\/code>\r\n<code>    context_size=LLAMA32_CONFIG_1B[\"context_length\"],<\/code>\r\n<code>    top_k=1,<\/code>\r\n<code>    temperature=0.<\/code>\r\n<code>)<\/code>\r\n<code>print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))<\/code><\/pre>\n<\/section>\n<section>\n<pre data-lang=\"cs\"><code>Output text:<\/code>\r\n<code>\u00a0Every effort is made to ensure that the information on this website is accurate. 
However, we cannot guarantee that the information is accurate, complete<\/code><\/pre>\n<\/section>\n<p><em><sup>Original post:<\/sup><\/em><\/p>\n<p><em><sup>https:\/\/github.com\/rasbt\/LLMs-from-scratch\/blob\/main\/ch05\/07_gpt_to_llama\/converting-llama2-to-llama3.ipynb<\/sup><\/em><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u6587\u7ae0\u6765\u6e90\u4e8e\u4e92\u8054\u7f51:Sebastian R [&hellip;]<\/p>\n","protected":false},"author":3,"featured_media":0,"comment_status":"open","ping_status":"","sticky":false,"template":"","format":"standard","meta":{"site-sidebar-layout":"default","site-content-layout":"","ast-site-content-layout":"","site-content-style":"default","site-sidebar-style":"default","ast-global-header-display":"","ast-banner-title-visibility":"","ast-main-header-display":"","ast-hfb-above-header-display":"","ast-hfb-below-header-display":"","ast-hfb-mobile-header-display":"","site-post-title":"","ast-breadcrumbs-content":"","ast-featured-img":"","footer-sml-layout":"","theme-transparent-header-meta":"","adv-header-id-meta":"","stick-header-meta":"","header-above-stick-meta":"","header-main-stick-meta":"","header-below-stick-meta":"","astra-migrate-meta-layouts":"default","ast-page-background-enabled":"default","ast-page-background-meta":{"desktop":{"background-color":"var(--ast-global-color-4)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center 
center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"ast-content-background-meta":{"desktop":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"tablet":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""},"mobile":{"background-color":"var(--ast-global-color-5)","background-image":"","background-repeat":"repeat","background-position":"center center","background-size":"auto","background-attachment":"scroll","background-type":"","background-media":"","overlay-type":"","overlay-color":"","overlay-opacity":"","overlay-gradient":""}},"footnotes":""},"categories":[27],"tags":[71,55],"class_list":["post-35919","post","type-post","status-publish","format-standard","hentry","category-news","tag-rag","tag-55"],"yoast_head":"<!-- This site is optimized with the Yoast SEO plugin v26.4 - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 3.2 - \u4e00\u8d77AI\u6280\u672f<\/title>\n<meta name=\"robots\" 
content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/17aitech.com\/?p=35919\" \/>\n<script type=\"application\/ld+json\" class=\"yoast-schema-graph\">{\"@context\":\"https:\/\/schema.org\",\"@graph\":[{\"@type\":\"WebPage\",\"@id\":\"https:\/\/17aitech.com\/?p=35919\",\"url\":\"https:\/\/17aitech.com\/?p=35919\",\"name\":\"Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 3.2 - \u4e00\u8d77AI\u6280\u672f\",\"isPartOf\":{\"@id\":\"https:\/\/17aitech.com\/#website\"},\"primaryImageOfPage\":{\"@id\":\"https:\/\/17aitech.com\/?p=35919#primaryimage\"},\"image\":{\"@id\":\"https:\/\/17aitech.com\/?p=35919#primaryimage\"},\"thumbnailUrl\":\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png\",\"datePublished\":\"2025-01-02T20:00:24+00:00\",\"author\":{\"@id\":\"https:\/\/17aitech.com\/#\/schema\/person\/60225458499e817ae0af73e67e440b9d\"},\"breadcrumb\":{\"@id\":\"https:\/\/17aitech.com\/?p=35919#breadcrumb\"},\"inLanguage\":\"zh-Hans\",\"potentialAction\":[{\"@type\":\"ReadAction\",\"target\":[\"https:\/\/17aitech.com\/?p=35919\"]}]},{\"@type\":\"ImageObject\",\"inLanguage\":\"zh-Hans\",\"@id\":\"https:\/\/17aitech.com\/?p=35919#primaryimage\",\"url\":\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png\",\"contentUrl\":\"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png\"},{\"@type\":\"BreadcrumbList\",\"@id\":\"https:\/\/17aitech.com\/?p=35919#breadcrumb\",\"itemListElement\":[{\"@type\":\"ListItem\",\"position\":1,\"name\":\"\u9996\u9875\",\"item\":\"https:\/\/17aitech.com\/\"},{\"@type\":\"ListItem\",\"position\":2,\"name\":\"Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 
3.2\"}]},{\"@type\":\"WebSite\",\"@id\":\"https:\/\/17aitech.com\/#website\",\"url\":\"https:\/\/17aitech.com\/\",\"name\":\"\u4e00\u8d77AI\u6280\u672f\",\"description\":\"\u8ba9AI\u77e5\u8bc6\u89e6\u624b\u53ef\u53ca\",\"alternateName\":\"\u4e00\u8d77AI\u6280\u672f\",\"potentialAction\":[{\"@type\":\"SearchAction\",\"target\":{\"@type\":\"EntryPoint\",\"urlTemplate\":\"https:\/\/17aitech.com\/?s={search_term_string}\"},\"query-input\":{\"@type\":\"PropertyValueSpecification\",\"valueRequired\":true,\"valueName\":\"search_term_string\"}}],\"inLanguage\":\"zh-Hans\"},{\"@type\":\"Person\",\"@id\":\"https:\/\/17aitech.com\/#\/schema\/person\/60225458499e817ae0af73e67e440b9d\",\"name\":\"AI\u5c0f\u52a9\u624b\",\"image\":{\"@type\":\"ImageObject\",\"inLanguage\":\"zh-Hans\",\"@id\":\"https:\/\/17aitech.com\/#\/schema\/person\/image\/\",\"url\":\"\/\/17aitech.com\/wp-content\/uploads\/2024\/04\/robot_3.png\",\"contentUrl\":\"\/\/17aitech.com\/wp-content\/uploads\/2024\/04\/robot_3.png\",\"caption\":\"AI\u5c0f\u52a9\u624b\"},\"description\":\"\u8fd9\u4e2a\u4eba\u5f88\u61d2\uff0c\u4ec0\u4e48\u90fd\u6ca1\u6709\u7559\u4e0b\uff5e\",\"url\":\"https:\/\/17aitech.com\/?page_id=33738&user=3\"}]}<\/script>\n<!-- \/ Yoast SEO plugin. 
-->","yoast_head_json":{"title":"Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 3.2 - \u4e00\u8d77AI\u6280\u672f","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/17aitech.com\/?p=35919","schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"WebPage","@id":"https:\/\/17aitech.com\/?p=35919","url":"https:\/\/17aitech.com\/?p=35919","name":"Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 3.2 - \u4e00\u8d77AI\u6280\u672f","isPartOf":{"@id":"https:\/\/17aitech.com\/#website"},"primaryImageOfPage":{"@id":"https:\/\/17aitech.com\/?p=35919#primaryimage"},"image":{"@id":"https:\/\/17aitech.com\/?p=35919#primaryimage"},"thumbnailUrl":"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png","datePublished":"2025-01-02T20:00:24+00:00","author":{"@id":"https:\/\/17aitech.com\/#\/schema\/person\/60225458499e817ae0af73e67e440b9d"},"breadcrumb":{"@id":"https:\/\/17aitech.com\/?p=35919#breadcrumb"},"inLanguage":"zh-Hans","potentialAction":[{"@type":"ReadAction","target":["https:\/\/17aitech.com\/?p=35919"]}]},{"@type":"ImageObject","inLanguage":"zh-Hans","@id":"https:\/\/17aitech.com\/?p=35919#primaryimage","url":"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png","contentUrl":"https:\/\/17aitech.com\/wp-content\/uploads\/2024\/10\/frc-3de05da944f0dd31562d7b81893c4d7f.png"},{"@type":"BreadcrumbList","@id":"https:\/\/17aitech.com\/?p=35919#breadcrumb","itemListElement":[{"@type":"ListItem","position":1,"name":"\u9996\u9875","item":"https:\/\/17aitech.com\/"},{"@type":"ListItem","position":2,"name":"Sebastian Raschka\u6700\u65b0\u535a\u5ba2\uff1a\u4ece\u5934\u5f00\u59cb\uff0c\u7528Llama 2\u6784\u5efaLlama 
3.2"}]},{"@type":"WebSite","@id":"https:\/\/17aitech.com\/#website","url":"https:\/\/17aitech.com\/","name":"\u4e00\u8d77AI\u6280\u672f","description":"\u8ba9AI\u77e5\u8bc6\u89e6\u624b\u53ef\u53ca","alternateName":"\u4e00\u8d77AI\u6280\u672f","potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/17aitech.com\/?s={search_term_string}"},"query-input":{"@type":"PropertyValueSpecification","valueRequired":true,"valueName":"search_term_string"}}],"inLanguage":"zh-Hans"},{"@type":"Person","@id":"https:\/\/17aitech.com\/#\/schema\/person\/60225458499e817ae0af73e67e440b9d","name":"AI\u5c0f\u52a9\u624b","image":{"@type":"ImageObject","inLanguage":"zh-Hans","@id":"https:\/\/17aitech.com\/#\/schema\/person\/image\/","url":"\/\/17aitech.com\/wp-content\/uploads\/2024\/04\/robot_3.png","contentUrl":"\/\/17aitech.com\/wp-content\/uploads\/2024\/04\/robot_3.png","caption":"AI\u5c0f\u52a9\u624b"},"description":"\u8fd9\u4e2a\u4eba\u5f88\u61d2\uff0c\u4ec0\u4e48\u90fd\u6ca1\u6709\u7559\u4e0b\uff5e","url":"https:\/\/17aitech.com\/?page_id=33738&user=3"}]}},"_links":{"self":[{"href":"https:\/\/17aitech.com\/index.php?rest_route=\/wp\/v2\/posts\/35919","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/17aitech.com\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/17aitech.com\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/17aitech.com\/index.php?rest_route=\/wp\/v2\/users\/3"}],"replies":[{"embeddable":true,"href":"https:\/\/17aitech.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=35919"}],"version-history":[{"count":0,"href":"https:\/\/17aitech.com\/index.php?rest_route=\/wp\/v2\/posts\/35919\/revisions"}],"wp:attachment":[{"href":"https:\/\/17aitech.com\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=35919"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/17aitech.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=35919"},{"taxo
nomy":"post_tag","embeddable":true,"href":"https:\/\/17aitech.com\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=35919"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}