{"id":25918,"date":"2025-05-20T09:18:00","date_gmt":"2025-05-20T01:18:00","guid":{"rendered":"https:\/\/aif.amtbbs.org\/?p=25918"},"modified":"2025-05-20T09:18:00","modified_gmt":"2025-05-20T01:18:00","slug":"transformer-%e6%a8%a1%e5%9e%8b%e7%bb%93%e6%9e%84%e8%af%a6%e8%a7%a3%e5%8f%8a%e4%bb%a3%e7%a0%81%e5%ae%9e%e7%8e%b0","status":"publish","type":"post","link":"https:\/\/aif.amtbbs.org\/index.php\/2025\/05\/20\/25918\/","title":{"rendered":"Transformer \u6a21\u578b\u7ed3\u6784\u8be6\u89e3\u53ca\u4ee3\u7801\u5b9e\u73b0!"},"content":{"rendered":"<div>\n<figure id=\"attachment_25919\" aria-describedby=\"caption-attachment-25919\" style=\"width: 300px\" class=\"wp-caption alignnone\"><img data-dominant-color=\"525b5e\" data-has-transparency=\"false\" style=\"--dominant-color: #525b5e;\" loading=\"lazy\" decoding=\"async\" class=\"not-transparent alignnone size-full wp-image-25920\" src=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/c4d52b46052fa081dc3556de71608e061b9fbd-2-300x167-1.jpg\" width=\"300\" height=\"167\" alt=\"\" srcset=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/c4d52b46052fa081dc3556de71608e061b9fbd-2-300x167-1.jpg 300w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/c4d52b46052fa081dc3556de71608e061b9fbd-2-300x167-1-150x84.jpg 150w\" sizes=\"auto, (max-width: 300px) 100vw, 300px\" \/><figcaption id=\"caption-attachment-25919\" class=\"wp-caption-text\">3D Rendering futuristic robot technology development, artificial intelligence AI, and machine learning concept. 
Global robotic bionic science research for future of human life.<\/figcaption><\/figure>\n<\/div>\n<div><\/div>\n<div class=\"article-desc\">Transformer \u9ed8\u8ba4\u90fd\u662f\u5927\u6a21\u578b\uff0c\u9664\u4e86\u4e00\u4e9b\u7279\u4f8b\uff08\u5982 DistilBERT\uff09\u5916\uff0c\u5b9e\u73b0\u66f4\u597d\u6027\u80fd\u7684\u4e00\u822c\u7b56\u7565\u662f\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u4ee5\u53ca\u9884\u8bad\u7ec3\u7684\u6570\u636e\u91cf\u3002<\/div>\n<div id=\"postspictures\" class=\"article-content\">\n<div id=\"container\" class=\"container am-engine\" data-v-1d7a5742=\"\" data-element=\"root\">\n<h3>\u4e00\u3001Transformer\u7b80\u8981\u53d1\u5c55\u53f2<\/h3>\n<p>\u4ee5\u4e0b\u662fTransformer\u6a21\u578b\u53d1\u5c55\u5386\u53f2\u4e2d\u7684\u5173\u952e\u8282\u70b9\uff1a<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s2.51cto.com\/oss\/202505\/19\/66c30a8834c517f33683165a9247ccf5e7150b.webp\" data-type=\"block\" \/><\/p>\n<p>&nbsp;<\/p>\n<p>Transformer\u67b6\u6784\u4e8e2017\u5e746\u6708\u63a8\u51fa\u3002\u539f\u672c\u7814\u7a76\u7684\u91cd\u70b9\u662f\u7ffb\u8bd1\u4efb\u52a1\u3002\u968f\u540e\u63a8\u51fa\u4e86\u51e0\u4e2a\u6709\u5f71\u54cd\u529b\u7684\u6a21\u578b\uff0c\u5305\u62ec\uff1a<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-RJMPMBSS\">\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-OIPaW1MM\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-IUa8X6AL\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-6PLUTrOS\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-RTFl2YRf\">\n<tr data-id=\"t31e458f-nOZJM3lo\">\n<td data-id=\"t6267798-XDHBDBMB\" data-transient-attributes=\"table-cell-selection\">\u65f6\u95f4<\/td>\n<td data-id=\"t6267798-1iNOO2VV\" data-transient-attributes=\"table-cell-selection\">\u6a21\u578b<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-lCc89Umq\" 
data-transient-attributes=\"table-cell-selection\">\u7b80\u8981\u8bf4\u660e<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-MKUkHIDm\">\n<td data-id=\"t6267798-Yx4bNuHG\" data-transient-attributes=\"table-cell-selection\">2017 \u5e74 6 \u6708<\/td>\n<td data-id=\"t6267798-LmD5Qq99\" data-transient-attributes=\"table-cell-selection\">\u300cTransformer\u300d<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-R3KHNSif\" data-transient-attributes=\"table-cell-selection\">Google \u9996\u6b21\u63d0\u51fa\u57fa\u4e8e Attention \u7684\u6a21\u578b\uff0c\u7528\u4e8e\u673a\u5668\u7ffb\u8bd1\u4efb\u52a1<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-ZBK9TP99\">\n<td data-id=\"t6267798-HVh3Dmjd\" data-transient-attributes=\"table-cell-selection\">2018 \u5e74 6 \u6708<\/td>\n<td data-id=\"t6267798-IyM6Thyw\" data-transient-attributes=\"table-cell-selection\">\u300cGPT\u300d<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-iPcoiEUG\" data-transient-attributes=\"table-cell-selection\">\u7b2c\u4e00\u4e2a\u4f7f\u7528 Transformer \u89e3\u7801\u5668\u6a21\u5757\u8fdb\u884c\u9884\u8bad\u7ec3\u7684\u8bed\u8a00\u6a21\u578b\uff0c\u9002\u7528\u4e8e\u591a\u79cd NLP \u4efb\u52a1<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-Eefhj7qJ\">\n<td data-id=\"t6267798-wvXjUBqD\" data-transient-attributes=\"table-cell-selection\">2018 \u5e74 10 \u6708<\/td>\n<td data-id=\"t6267798-fMtsn9Rh\" data-transient-attributes=\"table-cell-selection\">\u300cBERT\u300d<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-dnJ7Gf7U\" data-transient-attributes=\"table-cell-selection\">\u4f7f\u7528 Transformer \u7f16\u7801\u5668\u6a21\u5757\uff0c\u901a\u8fc7\u63a9\u7801\u8bed\u8a00\u5efa\u6a21\u751f\u6210\u66f4\u5f3a\u5927\u7684\u53e5\u5b50\u8868\u793a<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-IimGrkKb\">\n<td data-id=\"t6267798-NBF2u4DD\" data-transient-attributes=\"table-cell-selection\">2019 \u5e74 2 \u6708<\/td>\n<td data-id=\"t6267798-wEFVGrBv\" data-transient-attributes=\"table-cell-selection\">\u300cGPT-2\u300d<\/td>\n<td 
class=\"table-last-row\" data-id=\"t6267798-7kH2TVMi\" data-transient-attributes=\"table-cell-selection\">\u66f4\u5927\u66f4\u5f3a\u7684 GPT \u7248\u672c\uff0c\u7531\u4e8e\u6f5c\u5728\u98ce\u9669\u672a\u7acb\u5373\u53d1\u5e03\uff0c\u5177\u5907\u51fa\u8272\u7684\u6587\u672c\u751f\u6210\u80fd\u529b<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-nJuDOoOK\">\n<td data-id=\"t6267798-jouz9Opd\" data-transient-attributes=\"table-cell-selection\">2019 \u5e74 10 \u6708<\/td>\n<td data-id=\"t6267798-TYZVZwXT\" data-transient-attributes=\"table-cell-selection\">\u300cDistilBERT\u300d<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-s0rExEIg\" data-transient-attributes=\"table-cell-selection\">BERT \u7684\u8f7b\u91cf\u5316\u7248\u672c\uff0c\u5728\u4fdd\u7559 97% \u6027\u80fd\u7684\u540c\u65f6\uff0c\u901f\u5ea6\u66f4\u5feb\u3001\u5185\u5b58\u5360\u7528\u66f4\u4f4e<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-LJfFpXYa\">\n<td data-id=\"t6267798-drpqOYhE\" data-transient-attributes=\"table-cell-selection\">2019 \u5e74 10 \u6708<\/td>\n<td data-id=\"t6267798-PPcnZy45\" data-transient-attributes=\"table-cell-selection\">\u300cBART\u3001T5\u300d<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-LDG8LBAL\" data-transient-attributes=\"table-cell-selection\">\u4f7f\u7528\u5b8c\u6574\u7684 Encoder-Decoder \u67b6\u6784\uff0c\u5728\u5404\u79cd NLP \u4efb\u52a1\u4e2d\u8868\u73b0\u4f18\u5f02<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-lcFf7Orx\">\n<td class=\"table-last-column\" data-id=\"t6267798-egg9YiAX\" data-transient-attributes=\"table-cell-selection\">2020 \u5e74 5 \u6708<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-ndjGBHTY\" data-transient-attributes=\"table-cell-selection\">\u300cGPT-3\u300d<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-usAuypQs\" 
data-transient-attributes=\"table-cell-selection\">\u8d85\u5927\u89c4\u6a21\u8bed\u8a00\u6a21\u578b\uff0c\u652f\u6301\u201c\u96f6\u6837\u672c\u5b66\u4e60\u201d\uff0c\u65e0\u9700\u5fae\u8c03\u5373\u53ef\u5b8c\u6210\u65b0\u4efb\u52a1<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>\u8fd9\u4e2a\u5217\u8868\u5e76\u4e0d\u5168\u9762\uff0c\u53ea\u662f\u4e3a\u4e86\u7a81\u51fa\u4e00\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684 Transformer \u6a21\u578b\u3002\u5927\u4f53\u4e0a\uff0c\u5b83\u4eec\u53ef\u4ee5\u5206\u4e3a\u4e09\u7c7b\uff1a<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-EPZYDVSC\">\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-VqPFNIoM\" \/>\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-XCZWF0gA\" \/>\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-MLkOY0u7\" \/>\n<col span=\"1\" width=\"164.01\" data-id=\"cd89ecb0-WQNcMLUG\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-UiLWZdXK\">\n<tr data-id=\"t31e458f-OcVDFulm\">\n<td data-id=\"t6267798-n16ciDq3\" data-transient-attributes=\"table-cell-selection\">\u7c7b\u522b<\/td>\n<td data-id=\"t6267798-gmO8Qghp\" data-transient-attributes=\"table-cell-selection\">\u6784\u6210<\/td>\n<td data-id=\"t6267798-oBQdVHcf\" data-transient-attributes=\"table-cell-selection\">\u7279\u70b9<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-Zm28hIx3\" data-transient-attributes=\"table-cell-selection\">\u5178\u578b\u6a21\u578b<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-mzijogs1\">\n<td data-id=\"t6267798-0IXtKtme\" data-transient-attributes=\"table-cell-selection\">\u300cGPT-like\u300d<\/p>\n<p>\uff08\u81ea\u56de\u5f52 Transformer\uff09<\/td>\n<td data-id=\"t6267798-85ZVUSGS\" data-transient-attributes=\"table-cell-selection\">\u53ea\u4f7f\u7528\u89e3\u7801\u5668<\/td>\n<td data-id=\"t6267798-PmozPC6s\" 
data-transient-attributes=\"table-cell-selection\">\u81ea\u56de\u5f52\u65b9\u5f0f\u9884\u6d4b\u4e0b\u4e00\u4e2a\u8bcd\uff0c\u9002\u5408\u6587\u672c\u751f\u6210\u4efb\u52a1<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-iwVbf7Rd\" data-transient-attributes=\"table-cell-selection\">GPT\u3001GPT-2\u3001GPT-3<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-fCqtYc11\">\n<td data-id=\"t6267798-Kn4Dgww9\" data-transient-attributes=\"table-cell-selection\">\u300cBERT-like\u300d<\/p>\n<p>\uff08\u81ea\u52a8\u7f16\u7801 Transformer\uff09<\/td>\n<td data-id=\"t6267798-SvYILzsu\" data-transient-attributes=\"table-cell-selection\">\u53ea\u4f7f\u7528\u7f16\u7801\u5668<\/td>\n<td data-id=\"t6267798-vVJAHM77\" data-transient-attributes=\"table-cell-selection\">\u63a9\u7801\u673a\u5236\u5b66\u4e60\u4e0a\u4e0b\u6587\u8868\u793a\uff0c\u9002\u5408\u7406\u89e3\u7c7b\u4efb\u52a1\u5982\u95ee\u7b54\u3001\u60c5\u611f\u5206\u6790<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-IyZqOymt\" data-transient-attributes=\"table-cell-selection\">BERT\u3001RoBERTa\u3001DistilBERT<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-WWWeOlxM\">\n<td class=\"table-last-column\" data-id=\"t6267798-reP65UTO\" data-transient-attributes=\"table-cell-selection\">\u300cBART\/T5-like\u300d<\/p>\n<p>\uff08\u5e8f\u5217\u5230\u5e8f\u5217 Transformer\uff09<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-YL3vlYwZ\" data-transient-attributes=\"table-cell-selection\">\u7f16\u7801\u5668 + \u89e3\u7801\u5668<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-H1Ryy2pf\" data-transient-attributes=\"table-cell-selection\">\u5b8c\u6574\u7684 encoder-decoder \u67b6\u6784\uff0c\u9002\u5408\u7ffb\u8bd1\u3001\u6458\u8981\u7b49\u751f\u6210+\u7406\u89e3\u7ed3\u5408\u7684\u4efb\u52a1<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-YqFhwRQo\" data-transient-attributes=\"table-cell-selection\">BART\u3001T5<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>Transformer 
\u9ed8\u8ba4\u90fd\u662f\u5927\u6a21\u578b\uff0c\u9664\u4e86\u4e00\u4e9b\u7279\u4f8b\uff08\u5982 DistilBERT\uff09\u5916\uff0c\u5b9e\u73b0\u66f4\u597d\u6027\u80fd\u7684\u4e00\u822c\u7b56\u7565\u662f\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u4ee5\u53ca\u9884\u8bad\u7ec3\u7684\u6570\u636e\u91cf\u3002\u5176\u4e2d\uff0cGPT-2 \u662f\u4f7f\u7528\u300ctransformer \u89e3\u7801\u5668\u6a21\u5757\u300d\u6784\u5efa\u7684\uff0c\u800c BERT \u5219\u662f\u901a\u8fc7\u300ctransformer \u7f16\u7801\u5668\u300d\u6a21\u5757\u6784\u5efa\u7684\u3002<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s5.51cto.com\/oss\/202505\/19\/c1fd639020bc2a7b0bc2015e51ac0c96a1cddc.webp\" data-type=\"block\" \/><\/p>\n<p>&nbsp;<\/p>\n<h3>\u4e8c\u3001Transformer \u6574\u4f53\u67b6\u6784<\/h3>\n<p>\u8bba\u6587\u4e2d\u7ed9\u51fa\u7528\u4e8e\u673a\u5668\u7ffb\u8bd1\u4efb\u52a1\uff08\u82f1\u5fb7\u3001\u82f1\u6cd5\u7ffb\u8bd1\uff09\u7684 Transformer \u6574\u4f53\u67b6\u6784\u5982\u4e0b\u56fe\u6240\u793a\uff1a<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s3.51cto.com\/oss\/202505\/19\/c40fe7e95d7adcfb82a23050e5847be3a70e44.webp\" data-type=\"block\" \/><\/p>\n<p>\u53ef\u4ee5\u770b\u51faTransformer\u67b6\u6784\u7531Encoder\u548cDecoder\u4e24\u4e2a\u90e8\u5206\u7ec4\u6210\uff1a\u5176\u4e2dEncoder\u548cDecoder\u90fd\u662f\u7531N=6\u4e2a\u76f8\u540c\u7684\u5c42\u5806\u53e0\u800c\u6210\u3002Multi-Head Attention \u7ed3\u6784\u662f Transformer \u67b6\u6784\u7684\u6838\u5fc3\u7ed3\u6784\uff0c\u5b83\u7531\u591a\u4e2a\u5e76\u884c\u7684 Self-Attention \u7ec4\u6210\u3002\u5176\u4e2d\uff0c<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-1MITTjCo\">\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-HFNCJar0\" \/>\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-6thc2fYm\" \/>\n<col span=\"1\" width=\"163.993\" data-id=\"cd89ecb0-vtoBUByf\" \/>\n<col span=\"1\" width=\"164.01\" data-id=\"cd89ecb0-U07m5i09\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-pc3pho6k\">\n<tr 
data-id=\"t31e458f-C9w6ygBY\">\n<td data-id=\"t6267798-djDNYWWf\" data-transient-attributes=\"table-cell-selection\">\u90e8\u4ef6<\/td>\n<td data-id=\"t6267798-M5xsIgf5\" data-transient-attributes=\"table-cell-selection\">\u7ed3\u6784<\/td>\n<td data-id=\"t6267798-SlBQFPrD\" data-transient-attributes=\"table-cell-selection\">\u5c42\u6570<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-UJhmhAEh\" data-transient-attributes=\"table-cell-selection\">\u4e3b\u8981\u6a21\u5757<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-oLdKNPiG\">\n<td data-id=\"t6267798-RuCvXNGR\" data-transient-attributes=\"table-cell-selection\">Encoder<\/td>\n<td data-id=\"t6267798-z0bD7ffS\" data-transient-attributes=\"table-cell-selection\">\u7f16\u7801\u5668\u5c42\u5806\u53e0<\/td>\n<td data-id=\"t6267798-evPwthSS\" data-transient-attributes=\"table-cell-selection\">N=6\u5c42<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-Hwi9Zj7N\" data-transient-attributes=\"table-cell-selection\">Self-Attention+Feed Forward<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-QLz82eM4\">\n<td class=\"table-last-column\" data-id=\"t6267798-J4cAzNuo\" data-transient-attributes=\"table-cell-selection\">Decoder<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-Tft1vOjQ\" data-transient-attributes=\"table-cell-selection\">\u89e3\u7801\u5668\u5c42\u5806\u53e0<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-s12eBifq\" data-transient-attributes=\"table-cell-selection\">N=6\u5c42<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-peD0lkH2\" data-transient-attributes=\"table-cell-selection\">Masked Self-Attention+Encoder-Decoder Attention+Feed Forward<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p>Transformer \u67b6\u6784\u66f4\u8be6\u7ec6\u7684\u53ef\u89c6\u5316\u56fe\u5982\u4e0b\u6240\u793a\uff1a<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s7.51cto.com\/oss\/202505\/19\/d882faa859c01d9318a980405228dcd000ea82.webp\" data-type=\"block\" \/><\/p>\n<h4>1. 
\u8f93\u5165\u6a21\u5757<\/h4>\n<p><strong>(1) Tokenizer\u9884\u5904\u7406<\/strong><\/p>\n<p>\u5728\u57fa\u4e8eTransformer\u7684\u5927\u6a21\u578bLLM\u4e2d\uff0c\u8f93\u5165\u901a\u5e38\u4e3a\u5b57\u7b26\u4e32\u6587\u672c\u3002\u7531\u4e8e\u6a21\u578b\u65e0\u6cd5\u76f4\u63a5\u5904\u7406\u81ea\u7136\u8bed\u8a00\uff0c\u56e0\u6b64\u9700\u8981\u501f\u52a9Tokenizer\u5bf9\u8f93\u5165\u8fdb\u884c\u9884\u5904\u7406\u3002\u5177\u4f53\u6d41\u7a0b\u5982\u4e0b\uff1a<\/p>\n<ul data-id=\"u738a58b-cpMqvUsE\">\n<li data-id=\"ld70c578-nweVDiuA\">\u5206\u8bcd\uff08Tokenization\uff09\uff1a\u5c06\u8f93\u5165\u6587\u672c\u6309\u89c4\u5219\u5207\u5206\u4e3a\u4e00\u4e2a\u4e2a\u8bcd\u5143\uff08token\uff09\uff0c\u5982\u5355\u8bcd\u3001\u5b50\u8bcd\u6216\u7279\u6b8a\u7b26\u53f7\u3002<\/li>\n<li data-id=\"ld70c578-cmimb3Aw\">\u8bcd\u8868\u6620\u5c04\uff08Vocabulary Mapping\uff09\uff1a\u6bcf\u4e2a token \u88ab\u6620\u5c04\u5230\u4e00\u4e2a\u552f\u4e00\u7684\u6574\u6570 ID\uff0c\u8be5 ID \u6765\u81ea\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u4f7f\u7528\u7684\u8bcd\u6c47\u8868\u3002<\/li>\n<li data-id=\"ld70c578-Bw9ku278\">\u751f\u6210 input_ids \u5411\u91cf\uff08\u77e9\u9635\uff09\uff1a\u6700\u7ec8\u8f93\u51fa\u662f\u4e00\u4e2a\u7531 token ID \u6784\u6210\u7684\u5411\u91cf\uff08\u6216\u77e9\u9635\uff09\uff0c\u4f5c\u4e3a\u6a21\u578b\u8f93\u5165\u3002<\/li>\n<\/ul>\n<p>\u4ee5\u4e0b\u662f\u4ee5 Hugging Face \u7684 transformers \u5e93\u4e3a\u4f8b\uff0c\u5c55\u793a\u5982\u4f55\u4f7f\u7528 BertTokenizer \u548c BertModel \u5b8c\u6210\u8f93\u5165\u6587\u672c\u7684\u9884\u5904\u7406\u548c\u7f16\u7801\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_0\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">from transformers <span class=\"token keyword\">import<\/span> BertTokenizer<span class=\"token 
punctuation\">,<\/span> BertModel\r\n<span class=\"token keyword\">import<\/span> torch\r\n\r\n# <span class=\"token number\">1.<\/span> \u52a0\u8f7d\u9884\u8bad\u7ec3\u7684 <span class=\"token constant\">BERT<\/span> tokenizer \u548c\u6a21\u578b\r\ntokenizer <span class=\"token operator\">=<\/span> BertTokenizer<span class=\"token punctuation\">.<\/span><span class=\"token function\">from_pretrained<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">'bert-base-multilingual-cased'<\/span><span class=\"token punctuation\">)<\/span>\r\nmodel <span class=\"token operator\">=<\/span> BertModel<span class=\"token punctuation\">.<\/span><span class=\"token function\">from_pretrained<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">'bert-base-multilingual-cased'<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">2.<\/span> \u8f93\u5165\u6587\u672c\r\ntext <span class=\"token operator\">=<\/span> <span class=\"token string\">\"A Titan RTX has 24GB of VRAM\"<\/span>\r\n\r\n# <span class=\"token number\">3.<\/span> \u5206\u8bcd\u5e76\u6620\u5c04\u4e3a token <span class=\"token constant\">ID<\/span> \u5e8f\u5217\r\ninputs <span class=\"token operator\">=<\/span> <span class=\"token function\">tokenizer<\/span><span class=\"token punctuation\">(<\/span>text<span class=\"token punctuation\">,<\/span> return_tensors<span class=\"token operator\">=<\/span><span class=\"token string\">\"pt\"<\/span><span class=\"token punctuation\">,<\/span> truncation<span class=\"token operator\">=<\/span>True<span class=\"token punctuation\">,<\/span> padding<span class=\"token operator\">=<\/span>True<span class=\"token punctuation\">)<\/span>\r\n\r\n# \u8f93\u51fa token IDs\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"Token IDs:\"<\/span><span class=\"token punctuation\">,<\/span> inputs<span class=\"token punctuation\">[<\/span><span 
class=\"token string\">'input_ids'<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">4.<\/span> \u4f20\u5165\u6a21\u578b\uff0c\u83b7\u53d6\u8f93\u51fa\r\noutputs <span class=\"token operator\">=<\/span> <span class=\"token function\">model<\/span><span class=\"token punctuation\">(<\/span><span class=\"token operator\">**<\/span>inputs<span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">5.<\/span> \u83b7\u53d6\u6700\u540e\u4e00\u5c42\u7684\u9690\u85cf\u72b6\u6001\u8868\u793a\r\nlast_hidden_states <span class=\"token operator\">=<\/span> outputs<span class=\"token punctuation\">.<\/span>last_hidden_state\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"Last hidden states shape:\"<\/span><span class=\"token punctuation\">,<\/span> last_hidden_states<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_0\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u539f\u59cb\u8f93\u5165\u6587\u672c &#8220;A Titan RTX has 24GB of VRAM&#8221; \u901a\u8fc7 tokenizer \u5b8c\u6210\u5206\u8bcd\u548c\u8bcd\u8868\u6620\u5c04\u5de5\u4f5c\uff0c\u751f\u6210\u7684\u8f93\u5165 ID \u5217\u8868\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_1\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code 
class=\"language-javascript\"><span class=\"token punctuation\">[<\/span><span class=\"token number\">101<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">138<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">28318<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">56898<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">12674<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10393<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10233<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">32469<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10108<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">74727<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">36535<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">102<\/span><span class=\"token punctuation\">]<\/span><\/code><\/pre>\n<ul id=\"code_id_1\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u5176\u4e2d\uff0c<\/p>\n<ul data-id=\"u738a58b-5FLq1e4X\">\n<li data-id=\"ld70c578-iC7gYs7U\">101 \u8868\u793a [CLS] \u6807\u8bb0\uff1b<\/li>\n<li data-id=\"ld70c578-3nKjVXGH\">102 \u8868\u793a [SEP] \u6807\u8bb0\uff1b<\/li>\n<li data-id=\"ld70c578-A4RGJ7Er\">\u5176\u4f59\u4e3a\u5bf9\u5e94 token \u5728\u8bcd\u8868\u4e2d\u7684\u7d22\u5f15\u3002<\/li>\n<\/ul>\n<p>\u5728\u6240\u6709\u57fa\u4e8e Transformer \u7684 LLM \u4e2d\uff0c\u552f\u4e00\u5fc5\u987b\u7684\u8f93\u5165\u662f input_ids\uff0c\u5b83\u662f\u7531 Tokenizer \u6620\u5c04\u540e\u7684 token \u7d22\u5f15\u7ec4\u6210\u7684\u6574\u6570\u5411\u91cf\uff0c\u4ee3\u8868\u4e86\u8f93\u5165\u6587\u672c\u5728\u8bcd\u8868\u4e2d\u7684\u4f4d\u7f6e\u4fe1\u606f\u3002<\/p>\n<p><strong>(2) 
Embedding \u5c42<\/strong><\/p>\n<p>\u5728\u57fa\u4e8e Transformer \u7684\u5927\u578b\u8bed\u8a00\u6a21\u578b\uff08LLM\uff09\u4e2d\uff0c\u5d4c\u5165\u5c42\uff08Embedding Layer\uff09\u662f\u5c06\u8f93\u5165 token ID \u6620\u5c04\u4e3a\u5411\u91cf\u8868\u793a\u7684\u6838\u5fc3\u7ec4\u4ef6\u3002\u5176\u4f5c\u7528\u662f\u5c06\u79bb\u6563\u7684\u6574\u6570\u7d22\u5f15\u8f6c\u6362\u4e3a\u8fde\u7eed\u3001\u7a20\u5bc6\u7684\u5411\u91cf\u7a7a\u95f4\u8868\u793a\uff0c\u4ece\u800c\u4fbf\u4e8e\u540e\u7eed\u795e\u7ecf\u7f51\u7edc\u8fdb\u884c\u8bed\u4e49\u5efa\u6a21\u3002<\/p>\n<p>\u2705 \u4e07\u7269\u7686\u53ef Embedding\uff1a\u867d\u7136\u6700\u5e38\u89c1\u7684\u662f\u8bcd\u5d4c\u5165\uff08Word Embedding\uff09\uff0c\u4f46\u56fe\u50cf\u3001\u8bed\u97f3\u7b49\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5d4c\u5165\u5c42\u6620\u5c04\u4e3a\u5411\u91cf\u5f62\u5f0f\uff0c\u5b9e\u73b0\u7edf\u4e00\u5efa\u6a21\u3002<\/p>\n<p>\u4f8b\u5982\uff0cmnist \u6570\u636e\u96c6\u4e2d\u7684\u56fe\u7247\uff0c\u53ef\u4ee5\u901a\u8fc7\u5d4c\u5165\u5c42\u6765\u8868\u793a\uff0c\u5982\u4e0b\u56fe\u6240\u793a\uff0c\u6bcf\u4e2a\u70b9\u4ee3\u8868\u4e00\u4e2a\u56fe\u7247(10000*784)\uff0c\u901a\u8fc7\u5d4c\u5165\u5c42\uff0c\u5c06\u56fe\u7247\u7684\u50cf\u7d20\u70b9\u8f6c\u5316\u4e3a\u7a20\u5bc6\u7684\u5411\u91cf\uff0c\u7136\u540e\u901a\u8fc7 t-SNE\/pca \u964d\u7ef4\uff0c\u53ef\u4ee5\u770b\u5230\u56fe\u7247\u7684\u7a7a\u95f4\u5206\u5e03\u3002<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s8.51cto.com\/oss\/202505\/19\/a24b89560b18320b1cf92881230b196f547d69.webp\" data-type=\"block\" \/><\/p>\n<p>LLM \u4e2d\uff0c\u5355\u8bcd token \u9700\u8981\u7ecf\u8fc7 Embedding \u5c42\uff0cEmbedding \u5c42\u7684\u4f5c\u7528\u662f\u5c06\u8f93\u5165\u7684\u79bb\u6563\u5316\u8868\u793a\uff08\u4f8b\u5982 token ids\uff09\u8f6c\u6362\u4e3a\u8fde\u7eed\u7684\u4f4e\u7ef4\u5411\u91cf\u8868\u793a\uff0c\u5176\u7531\u5355\u8bcd Embedding \u548c\u4f4d\u7f6e Embedding \uff08Positional 
Encoding\uff09\u76f8\u52a0\u5f97\u5230\uff0c\u901a\u5e38\u5b9a\u4e49\u4e3a TransformerEmbedding \u5c42\u3002<\/p>\n<p>\u2460 \u5355\u8bcd\u5d4c\u5165\uff08Token Embedding\uff09<\/p>\n<p>\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u8f93\u5165\u6587\u672c\u901a\u5e38\u662f\u4ee5\u7b26\u53f7\u5f62\u5f0f\u5b58\u5728\u7684\u8bcd\u6c47\uff0c\u800c\u8fd9\u4e9b\u79bb\u6563\u7b26\u53f7\u65e0\u6cd5\u76f4\u63a5\u88ab\u795e\u7ecf\u7f51\u7edc\u5904\u7406\u3002\u56e0\u6b64\u9700\u8981\u4e00\u4e2a\u53ef\u5b66\u4e60\u7684\u5d4c\u5165\u77e9\u9635\u5c06\u6bcf\u4e2a token \u8f6c\u6362\u4e3a\u56fa\u5b9a\u7ef4\u5ea6\u7684\u5411\u91cf\u3002<\/p>\n<p>\u5de5\u4f5c\u539f\u7406\uff1a<\/p>\n<p>a. \u8f93\u5165\u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a [batch_size, seq_len] \u7684\u6574\u6570\u5f20\u91cf\uff0c\u8868\u793a\u6bcf\u4e2a token \u5728\u8bcd\u8868\u4e2d\u7684\u7d22\u5f15\uff1b<\/p>\n<p>b. \u8f93\u51fa\u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a [batch_size, seq_len, d_model] \u7684\u4e09\u7ef4\u5f20\u91cf\uff0c\u5176\u4e2d\uff1a<\/p>\n<ul data-id=\"u738a58b-cMKzKODw\">\n<li data-id=\"ld70c578-GWUarzsH\">d_model \u662f\u5d4c\u5165\u7ef4\u5ea6\uff08\u5982 512 \u6216 768\uff09\uff1b<\/li>\n<li data-id=\"ld70c578-qPmDENgM\">\u6bcf\u4e2a token \u5bf9\u5e94\u4e00\u4e2a d_model \u7ef4\u7684\u5411\u91cf\u3002<\/li>\n<\/ul>\n<p>c. 
\u5d4c\u5165\u5c42\u6743\u91cd\u77e9\u9635\u5927\u5c0f\u4e3a [vocab_size, d_model]\uff0c\u53c2\u6570\u91cf\u4e3a vocab_size \u00d7 d_model\uff08\u4f8b\u5982 BERT-base \u7684 30522 \u00d7 768 \u2248 2344 \u4e07\uff09\u3002<\/p>\n<p>\u5728 PyTorch \u4e2d\uff0c\u8bcd\u5d4c\u5165\u5c42\u901a\u5e38\u4f7f\u7528 torch.nn.Embedding \u6a21\u5757\u5b9e\u73b0\uff0c\u5176\u4f5c\u7528\u662f\u5c06 token \u7684\u7d22\u5f15\u8f6c\u6362\u4e3a\u4f4e\u7ef4\u8bed\u4e49\u5411\u91cf\u8868\u793a\u3002<\/p>\n<p>\u2705 \u8f93\u5165\u4e0e\u8f93\u51fa\u8bf4\u660e<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-1qBGTLmp\">\n<col span=\"1\" width=\"327.986\" data-id=\"cd89ecb0-ygB2uOmD\" \/>\n<col span=\"1\" width=\"328.003\" data-id=\"cd89ecb0-cCXbKOfD\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-YdFhB5bV\">\n<tr data-id=\"t31e458f-RsPBvytJ\">\n<td data-id=\"t6267798-mel2SMeB\" data-transient-attributes=\"table-cell-selection\">\u7c7b\u578b<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-dzqYkorz\" data-transient-attributes=\"table-cell-selection\">\u63cf\u8ff0<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-aAcI5lg2\">\n<td data-id=\"t6267798-HNPxegjy\" data-transient-attributes=\"table-cell-selection\">\u8f93\u5165<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-Et8cpp12\" data-transient-attributes=\"table-cell-selection\">\u4e00\u4e2a\u6574\u6570\u5f20\u91cf\uff0c\u8868\u793a\u6bcf\u4e2a token \u5728\u8bcd\u8868\u4e2d\u7684\u7d22\u5f15<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-vC58moSs\">\n<td data-id=\"t6267798-zsqaUbiu\" data-transient-attributes=\"table-cell-selection\">\u8f93\u5165\u5f62\u72b6<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-BRAqb7iG\" data-transient-attributes=\"table-cell-selection\">(batch_size, sequence_length)<br \/>\n\u5176\u4e2d\uff1a<br \/>\n&#8211; batch_size\uff1a\u6279\u6b21\u5927\u5c0f\uff08\u5373\u4e00\u6b21\u5904\u7406\u591a\u5c11\u6761\u6587\u672c\uff09<br \/>\n&#8211; sequence_length\uff1a\u6bcf\u6761\u6587\u672c\u5305\u542b\u7684 token \u6570\u91cf<\/td>\n<\/tr>\n<tr 
data-id=\"t31e458f-xENdIVDf\">\n<td data-id=\"t6267798-JFx5FBE6\" data-transient-attributes=\"table-cell-selection\">\u8f93\u51fa<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-2iAbS16f\" data-transient-attributes=\"table-cell-selection\">\u6bcf\u4e2a token \u88ab\u6620\u5c04\u5230 embedding_dim \u7ef4\u5ea6\u7684\u7a20\u5bc6\u5411\u91cf<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-snEzgPOX\">\n<td class=\"table-last-column\" data-id=\"t6267798-S68PtH95\" data-transient-attributes=\"table-cell-selection\">\u8f93\u51fa\u5f62\u72b6<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-4BBpFS2J\" data-transient-attributes=\"table-cell-selection\">(batch_size, sequence_length, embedding_dim)<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<ul data-id=\"u738a58b-ngUaTEz4\">\n<li data-id=\"ld70c578-h3o6x6Mb\">embedding_dim \u662f\u5d4c\u5165\u5411\u91cf\u7684\u7ef4\u5ea6\uff0c\u4e5f\u79f0\u4e3a\u8bcd\u5411\u91cf\u7ef4\u5ea6\uff1b<\/li>\n<li data-id=\"ld70c578-nPZBTc3I\">\u5b83\u901a\u5e38\u88ab\u8bbe\u7f6e\u4e3a d_model \u6216 h\uff0c\u5373\u540e\u7eed Transformer \u5c42\u4f7f\u7528\u7684\u9690\u85cf\u5c42\u7ef4\u5ea6\uff08\u5982 512 \u6216 768\uff09.<\/li>\n<\/ul>\n<p>\ud83d\udcd0 \u793a\u4f8b\u4ee3\u7801\uff1a\u6784\u5efa Token Embedding \u5c42<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_2\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">from transformers <span class=\"token keyword\">import<\/span> BertTokenizer\r\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn\r\n\r\n## <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> \u4f7f\u7528 <span class=\"token constant\">BERT<\/span> tokenizer 
\u5c06\u6279\u91cf\u8f93\u5165\u7684\u5b57\u7b26\u4e32\u6587\u672c\u5e8f\u5217\u8f6c\u5316\u4e3a input_ids\r\ntokenizer <span class=\"token operator\">=<\/span> BertTokenizer<span class=\"token punctuation\">.<\/span><span class=\"token function\">from_pretrained<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"bert-base-multilingual-cased\"<\/span><span class=\"token punctuation\">)<\/span> \r\nbatch_text <span class=\"token operator\">=<\/span> <span class=\"token punctuation\">[<\/span><span class=\"token string\">\"A Titan RTX has 24GB of VRAM\"<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token string\">\"I have a dog and cat\"<\/span><span class=\"token punctuation\">]<\/span>\r\ninputs <span class=\"token operator\">=<\/span> <span class=\"token function\">tokenizer<\/span><span class=\"token punctuation\">(<\/span>batch_text<span class=\"token punctuation\">,<\/span> return_tensors<span class=\"token operator\">=<\/span><span class=\"token string\">\"pt\"<\/span><span class=\"token punctuation\">,<\/span> truncation<span class=\"token operator\">=<\/span>True<span class=\"token punctuation\">,<\/span> padding<span class=\"token operator\">=<\/span>True<span class=\"token punctuation\">)<\/span>\r\ninput_ids <span class=\"token operator\">=<\/span> inputs<span class=\"token punctuation\">[<\/span><span class=\"token string\">\"input_ids\"<\/span><span class=\"token punctuation\">]<\/span>\r\n\r\n# <span class=\"token number\">2.<\/span> \u521b\u5efa\u4e00\u4e2a nn<span class=\"token punctuation\">.<\/span>Embedding \u5c42\r\nvocab_size <span class=\"token operator\">=<\/span> tokenizer<span class=\"token punctuation\">.<\/span>vocab_size  # \u8bcd\u8868\u5927\u5c0f\u53d6\u51b3\u4e8e\u4f60\u52a0\u8f7d\u7684\u5177\u4f53 tokenizer \u6a21\u578b\r\nembedding_dim <span class=\"token operator\">=<\/span> <span class=\"token number\">512<\/span>  # \u5d4c\u5165\u5411\u91cf\u7684\u7ef4\u5ea6\uff0c\u53c2\u8003 
transformer \u8bba\u6587\u7684\u5927\u5c0f\r\nembedding_layer <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Embedding<\/span><span class=\"token punctuation\">(<\/span>vocab_size<span class=\"token punctuation\">,<\/span> embedding_dim<span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">3.<\/span> \u901a\u8fc7 nn<span class=\"token punctuation\">.<\/span>Embedding \u5c42\uff0c\u5c06\u8f93\u5165\u7684 IDs \u6620\u5c04\u5230\u5d4c\u5165\u5411\u91cf\r\nembedded_output <span class=\"token operator\">=<\/span> <span class=\"token function\">embedding_layer<\/span><span class=\"token punctuation\">(<\/span>input_ids<span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">4.<\/span> \u8f93\u51fa\u5d4c\u5165\u5411\u91cf\u7684\u5f62\u72b6\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"\u5d4c\u5165\u5411\u91cf\u7684\u5f62\u72b6\uff1a\"<\/span><span class=\"token punctuation\">,<\/span> embedded_output<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> sequence_length<span class=\"token punctuation\">,<\/span> embedding_dim<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">Size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">12<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# <span class=\"token number\">5.<\/span> 
\u6253\u5370\u5d4c\u5165\u5411\u91cf\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span>embedded_output<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_2\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u7a0b\u5e8f\u8fd0\u884c\u540e\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\u6240\u793a:<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s7.51cto.com\/oss\/202505\/19\/187e01b74e6f9aea824195d1bdce23ce8acfb9.webp\" data-type=\"block\" \/><\/p>\n<p>&nbsp;<\/p>\n<p>\u2461 \u4f4d\u7f6e\u5d4c\u5165\uff08Positional Encoding\uff09<\/p>\n<p>\u7531\u4e8e Transformer \u4e0d\u4f9d\u8d56\u4e8e RNN \u7684\u987a\u5e8f\u6027\u5efa\u6a21\u65b9\u5f0f\uff0c\u5b83\u5fc5\u987b\u663e\u5f0f\u5730\u5f15\u5165\u4f4d\u7f6e\u4fe1\u606f\uff0c\u4ee5\u4fdd\u7559 token \u5728\u5e8f\u5217\u4e2d\u7684\u4f4d\u7f6e\u7279\u5f81\u3002<\/p>\n<p>\u4e3a\u6b64\uff0cTransformer \u4f7f\u7528\u4e86 Sinusoidal Positional Encoding\uff08\u6b63\u5f26\/\u4f59\u5f26\u4f4d\u7f6e\u7f16\u7801\uff09\uff1a<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s8.51cto.com\/oss\/202505\/19\/37c4b8769d24a96970b9902302aaa9a234f9a1.webp\" data-type=\"block\" \/><\/p>\n<p>\u5176\u4e2d\uff1a<\/p>\n<ul data-id=\"u738a58b-Hzy6fWJk\">\n<li data-id=\"ld70c578-WiAV90fJ\">pos\uff1atoken \u5728\u5e8f\u5217\u4e2d\u7684\u4f4d\u7f6e\uff1b<\/li>\n<li data-id=\"ld70c578-03CxHmY4\">i\uff1a\u7ef4\u5ea6\u7d22\u5f15\uff1b<\/li>\n<li data-id=\"ld70c578-pkVQb0Ce\">d_model\uff1a\u5d4c\u5165\u7ef4\u5ea6\u3002<\/li>\n<\/ul>\n<p>\u2462 TransformerEmbedding \u5c42\u96c6\u6210<\/p>\n<p>transformer 
\u8f93\u5165\u6a21\u5757\u6709\u4e09\u4e2a\u7ec4\u6210\u90e8\u5206\uff1a\u6587\u672c\/\u63d0\u793a\u8bcd\u3001\u5206\u8bcd\u5668\uff08Tokenizer\uff09\u548c\u5d4c\u5165\u5c42\uff08Embeddings\uff09\u3002\u8f93\u5165\u6a21\u5757\u7684\u5de5\u4f5c\u6d41\u7a0b\u548c\u4ee3\u7801\u5b9e\u73b0\u5982\u4e0b\u6240\u793a:<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s7.51cto.com\/oss\/202505\/19\/2173cd242df60f055856250bac89fd7a708195.webp\" data-type=\"block\" \/><\/p>\n<p>&nbsp;<\/p>\n<p>\u77e9\u9635\u7684\u6bcf\u4e00\u5217\u8868\u793a\u4e00\u4e2a token \u7684\u5d4c\u5165\u5411\u91cf\u3002<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_3\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">PositionalEncoding<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n    compute sinusoid encoding<span class=\"token punctuation\">.<\/span>\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> device<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        constructor <span class=\"token keyword\">of<\/span> sinusoid encoding <span class=\"token keyword\">class<\/span>\r\n\r\n        <span class=\"token operator\">:<\/span>param d_model<span class=\"token operator\">:<\/span> 
dimension <span class=\"token keyword\">of<\/span> <span class=\"token literal-property property\">model<\/span>\r\n        <span class=\"token operator\">:<\/span>param max_len<span class=\"token operator\">:<\/span> max sequence length\r\n        <span class=\"token operator\">:<\/span>param device<span class=\"token operator\">:<\/span> hardware device setting\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>PositionalEncoding<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # same size <span class=\"token keyword\">with<\/span> input <span class=\"token function\">matrix<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token keyword\">for<\/span> adding <span class=\"token keyword\">with<\/span> input matrix<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>encoding <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">zeros<\/span><span class=\"token punctuation\">(<\/span>max_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> device<span class=\"token operator\">=<\/span>device<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>encoding<span class=\"token punctuation\">.<\/span>requires_grad <span class=\"token operator\">=<\/span> False  # we don't need to compute gradient\r\n\r\n        pos <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">arange<\/span><span class=\"token punctuation\">(<\/span><span class=\"token 
number\">0<\/span><span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> device<span class=\"token operator\">=<\/span>device<span class=\"token punctuation\">)<\/span>\r\n        pos <span class=\"token operator\">=<\/span> pos<span class=\"token punctuation\">.<\/span><span class=\"token function\">float<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">unsqueeze<\/span><span class=\"token punctuation\">(<\/span>dim<span class=\"token operator\">=<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n        # 1D <span class=\"token operator\">=&gt;<\/span> 2D unsqueeze to represent word's position\r\n\r\n        _2i <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">arange<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> step<span class=\"token operator\">=<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">,<\/span> device<span class=\"token operator\">=<\/span>device<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">float<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        # <span class=\"token string\">'i'<\/span> means index <span class=\"token keyword\">of<\/span> <span class=\"token function\">d_model<\/span> <span class=\"token punctuation\">(<\/span>e<span class=\"token punctuation\">.<\/span>g<span class=\"token punctuation\">.<\/span> embedding size <span class=\"token operator\">=<\/span> <span class=\"token number\">50<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token string\">'i'<\/span> <span 
class=\"token operator\">=<\/span> <span class=\"token punctuation\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span><span class=\"token number\">50<\/span><span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span>\r\n        # <span class=\"token string\">\"step=2\"<\/span> means <span class=\"token string\">'i'<\/span> multiplied <span class=\"token keyword\">with<\/span> <span class=\"token function\">two<\/span> <span class=\"token punctuation\">(<\/span>same <span class=\"token keyword\">with<\/span> <span class=\"token number\">2<\/span> <span class=\"token operator\">*<\/span> i<span class=\"token punctuation\">)<\/span>\r\n\r\n        self<span class=\"token punctuation\">.<\/span>encoding<span class=\"token punctuation\">[<\/span><span class=\"token operator\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">0<\/span><span class=\"token operator\">:<\/span><span class=\"token operator\">:<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">sin<\/span><span class=\"token punctuation\">(<\/span>pos <span class=\"token operator\">\/<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">10000<\/span> <span class=\"token operator\">**<\/span> <span class=\"token punctuation\">(<\/span>_2i <span class=\"token operator\">\/<\/span> d_model<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>encoding<span class=\"token punctuation\">[<\/span><span class=\"token operator\">:<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token operator\">:<\/span><span class=\"token operator\">:<\/span><span 
class=\"token number\">2<\/span><span class=\"token punctuation\">]<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">cos<\/span><span class=\"token punctuation\">(<\/span>pos <span class=\"token operator\">\/<\/span> <span class=\"token punctuation\">(<\/span><span class=\"token number\">10000<\/span> <span class=\"token operator\">**<\/span> <span class=\"token punctuation\">(<\/span>_2i <span class=\"token operator\">\/<\/span> d_model<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n        # compute positional encoding to consider positional information <span class=\"token keyword\">of<\/span> words\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # self<span class=\"token punctuation\">.<\/span>encoding\r\n        # <span class=\"token punctuation\">[<\/span>max_len <span class=\"token operator\">=<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">,<\/span> d_model <span class=\"token operator\">=<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">]<\/span>\r\n\r\n        batch_size<span class=\"token punctuation\">,<\/span> seq_len <span class=\"token operator\">=<\/span> x<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        # <span class=\"token punctuation\">[<\/span>batch_size <span class=\"token operator\">=<\/span> <span class=\"token number\">128<\/span><span class=\"token punctuation\">,<\/span> seq_len <span class=\"token operator\">=<\/span> <span class=\"token number\">30<\/span><span class=\"token 
punctuation\">]<\/span>\r\n\r\n        <span class=\"token keyword\">return<\/span> self<span class=\"token punctuation\">.<\/span>encoding<span class=\"token punctuation\">[<\/span><span class=\"token operator\">:<\/span>seq_len<span class=\"token punctuation\">,<\/span> <span class=\"token operator\">:<\/span><span class=\"token punctuation\">]<\/span>\r\n        # <span class=\"token punctuation\">[<\/span>seq_len <span class=\"token operator\">=<\/span> <span class=\"token number\">30<\/span><span class=\"token punctuation\">,<\/span> d_model <span class=\"token operator\">=<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">]<\/span>\r\n        # it will add <span class=\"token keyword\">with<\/span> <span class=\"token literal-property property\">tok_emb<\/span> <span class=\"token operator\">:<\/span> <span class=\"token punctuation\">[<\/span><span class=\"token number\">128<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">30<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">]<\/span>         \r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">TokenEmbedding<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Embedding<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n    Token Embedding using torch<span class=\"token punctuation\">.<\/span>nn\r\n    they will dense representation <span class=\"token keyword\">of<\/span> word using weighted matrix\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> vocab_size<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><span 
class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">class<\/span> <span class=\"token class-name\">for<\/span> token embedding that included positional information\r\n        <span class=\"token operator\">:<\/span>param vocab_size<span class=\"token operator\">:<\/span> size <span class=\"token keyword\">of<\/span> <span class=\"token literal-property property\">vocabulary<\/span>\r\n        <span class=\"token operator\">:<\/span>param d_model<span class=\"token operator\">:<\/span> dimensions <span class=\"token keyword\">of<\/span> model\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>TokenEmbedding<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>vocab_size<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> padding_idx<span class=\"token operator\">=<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">TransformerEmbedding<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n    token embedding <span class=\"token operator\">+<\/span> positional <span class=\"token function\">encoding<\/span> <span class=\"token punctuation\">(<\/span>sinusoid<span class=\"token punctuation\">)<\/span>\r\n    positional encoding can give positional information to network\r\n    <span class=\"token string\">\"\"<\/span>\"\r\n\r\n    def <span class=\"token 
function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> vocab_size<span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token punctuation\">,<\/span> device<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">class<\/span> <span class=\"token class-name\">for<\/span> word embedding that included positional information\r\n        <span class=\"token operator\">:<\/span>param vocab_size<span class=\"token operator\">:<\/span> size <span class=\"token keyword\">of<\/span> <span class=\"token literal-property property\">vocabulary<\/span>\r\n        <span class=\"token operator\">:<\/span>param d_model<span class=\"token operator\">:<\/span> dimensions <span class=\"token keyword\">of<\/span> model\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>TransformerEmbedding<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>tok_emb <span class=\"token operator\">=<\/span> <span class=\"token function\">TokenEmbedding<\/span><span class=\"token punctuation\">(<\/span>vocab_size<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>\r\n        # self<span class=\"token punctuation\">.<\/span>position_embedding <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Embedding<\/span><span class=\"token 
punctuation\">(<\/span>max_len<span class=\"token punctuation\">,<\/span> embed_size<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>pos_emb <span class=\"token operator\">=<\/span> <span class=\"token function\">PositionalEncoding<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> device<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>drop_out <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>p<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        tok_emb <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">tok_emb<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span>\r\n        pos_emb <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">pos_emb<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token keyword\">return<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">drop_out<\/span><span class=\"token punctuation\">(<\/span>tok_emb <span class=\"token operator\">+<\/span> pos_emb<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_3\" 
class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<li>35.<\/li>\n<li>36.<\/li>\n<li>37.<\/li>\n<li>38.<\/li>\n<li>39.<\/li>\n<li>40.<\/li>\n<li>41.<\/li>\n<li>42.<\/li>\n<li>43.<\/li>\n<li>44.<\/li>\n<li>45.<\/li>\n<li>46.<\/li>\n<li>47.<\/li>\n<li>48.<\/li>\n<li>49.<\/li>\n<li>50.<\/li>\n<li>51.<\/li>\n<li>52.<\/li>\n<li>53.<\/li>\n<li>54.<\/li>\n<li>55.<\/li>\n<li>56.<\/li>\n<li>57.<\/li>\n<li>58.<\/li>\n<li>59.<\/li>\n<li>60.<\/li>\n<li>61.<\/li>\n<li>62.<\/li>\n<li>63.<\/li>\n<li>64.<\/li>\n<li>65.<\/li>\n<li>66.<\/li>\n<li>67.<\/li>\n<li>68.<\/li>\n<li>69.<\/li>\n<li>70.<\/li>\n<li>71.<\/li>\n<li>72.<\/li>\n<li>73.<\/li>\n<li>74.<\/li>\n<li>75.<\/li>\n<li>76.<\/li>\n<li>77.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<h4>2. 
Multi-Head Attention \u7ed3\u6784<\/h4>\n<p>Encoder \u548c Decoder \u7ed3\u6784\u4e2d\u516c\u5171\u7684 layer \u4e4b\u4e00\u662f Multi-Head Attention\uff0c\u5176\u662f\u7531\u591a\u4e2a Self-Attention \u5e76\u884c\u7ec4\u6210\u7684\u3002Encoder block \u53ea\u5305\u542b\u4e00\u4e2a Multi-Head Attention\uff0c\u800c Decoder block \u5305\u542b\u4e24\u4e2a Multi-Head Attention (\u5176\u4e2d\u6709\u4e00\u4e2a\u7528\u5230 Masked)\u3002<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s6.51cto.com\/oss\/202505\/19\/e4a3dee675ec956452f8766e7393ab9b50ec8b.webp\" data-type=\"block\" \/><\/p>\n<p><strong>(1) Self-Attention \u7ed3\u6784<\/strong><\/p>\n<p>Self-Attention \u4e2d\u6587\u7ffb\u8bd1\u4e3a\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff0c\u8bba\u6587\u4e2d\u53eb\u4f5c Scaled Dot-Product Attention\uff0c\u5b83\u662f Transformer \u67b6\u6784\u7684\u6838\u5fc3\uff0c\u4f7f\u5f97\u6bcf\u4e2a token \u80fd\u591f\u5173\u6ce8\u6574\u4e2a\u5e8f\u5217\u4e2d\u7684\u5176\u4ed6 token\uff0c\u4ece\u800c\u5efa\u7acb\u5168\u5c40\u4f9d\u8d56\u5173\u7cfb\u3002\u5176\u7ed3\u6784\u5982\u4e0b\u56fe\u6240\u793a\uff1a<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s4.51cto.com\/oss\/202505\/19\/6245bc99361ef432dc6540298f465532e2fa20.webp\" data-type=\"block\" \/><\/p>\n<p><strong>(2) Self-Attention \u5b9e\u73b0<\/strong>\u2705 \u5728\u672c\u6587\u4e2d\uff0cSelf-Attention \u5c42\u4e0e\u8bba\u6587\u4e2d\u7684 ScaleDotProductAttention \u5c42\u610f\u4e49\u4e00\u81f4\uff0c\u5b9e\u73b0\u65b9\u5f0f\u5b8c\u5168\u76f8\u540c\u3002<\/p>\n<p>\ud83e\uddee \u6570\u5b66\u5b9a\u4e49<\/p>\n<p>Self-Attention \u7684\u8ba1\u7b97\u8fc7\u7a0b\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a<\/p>\n<p><img data-dominant-color=\"f1f1f1\" data-has-transparency=\"false\" style=\"--dominant-color: #f1f1f1;\" loading=\"lazy\" decoding=\"async\" class=\"not-transparent alignnone size-full wp-image-25921\" 
src=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/44deff790edbeb0b596812cbd6409fc151fe16.png\" width=\"342\" height=\"57\" srcset=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/44deff790edbeb0b596812cbd6409fc151fe16.png 342w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/44deff790edbeb0b596812cbd6409fc151fe16-300x50.png 300w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/44deff790edbeb0b596812cbd6409fc151fe16-150x25.png 150w\" sizes=\"auto, (max-width: 342px) 100vw, 342px\" \/><\/p>\n<p>\u5176\u4e2d\uff1a<\/p>\n<ul data-id=\"u738a58b-WsvWf4la\">\n<li data-id=\"ld70c578-LQBSM7EE\"><img decoding=\"async\" src=\"https:\/\/s9.51cto.com\/oss\/202505\/19\/c7df6f110b5ddbb341f211ba03ec49d4593fb1.png\" data-type=\"inline\" \/>\u00a0\uff1aQuery \u5411\u91cf\uff1b<\/li>\n<li data-id=\"ld70c578-JSrgjHzA\"><img decoding=\"async\" src=\"https:\/\/s4.51cto.com\/oss\/202505\/19\/333c79672347458a36000050c760a3cc3f05be.png\" data-type=\"inline\" \/>\uff1aKey \u5411\u91cf\uff1b<\/li>\n<li data-id=\"ld70c578-012ybbyJ\"><img decoding=\"async\" src=\"https:\/\/s4.51cto.com\/oss\/202505\/19\/11084d915754f5444dc964516e98224c7aee54.png\" data-type=\"inline\" \/>\uff1aValue \u5411\u91cf\uff1b<\/li>\n<li data-id=\"ld70c578-l72uLx4R\"><img decoding=\"async\" src=\"https:\/\/s4.51cto.com\/oss\/202505\/19\/c8c100d77bb5a2bf1e06121d0ed41113d80a1a.png\" data-type=\"inline\" \/>\uff1aQuery \u548c Key \u7684\u7ef4\u5ea6\uff1b<\/li>\n<li data-id=\"ld70c578-L1ulSHLr\">Softmax \u5bf9\u6ce8\u610f\u529b\u5206\u6570\u6309\u6700\u540e\u4e00\u4e2a\u7ef4\u5ea6\u5f52\u4e00\u5316\uff1b<\/li>\n<li data-id=\"ld70c578-B9fpmp41\"><img decoding=\"async\" src=\"https:\/\/s9.51cto.com\/oss\/202505\/19\/73dd018617b877a071e033f3959d446343c210.png\" data-type=\"inline\" \/>\uff1a\u7528\u4e8e\u7f29\u653e\u70b9\u79ef\uff0c\u9632\u6b62 softmax 
\u68af\u5ea6\u6d88\u5931\uff1b<\/li>\n<\/ul>\n<p>\u8f93\u5165\u6765\u6e90\uff1a<\/p>\n<ul data-id=\"u738a58b-ZrgKkpPL\">\n<li data-id=\"ld70c578-BjziBcMg\">\u8f93\u5165\u8bcd\u5411\u91cf\u7ecf\u8fc7 Embedding \u5c42\u540e\uff0c\u8fdb\u5165\u4f4d\u7f6e\u7f16\u7801\u5c42\uff1b<\/li>\n<li data-id=\"ld70c578-5m1tkQwl\">\u518d\u901a\u8fc7\u7ebf\u6027\u53d8\u6362\uff08Linear \u5c42\uff09\uff0c\u5206\u522b\u751f\u6210 Query\u3001Key \u548c Value \u5411\u91cf\uff1b<\/li>\n<li data-id=\"ld70c578-Wn9jhhhS\">\u8fd9\u4e09\u4e2a\u5411\u91cf\u7684\u5f62\u72b6\u901a\u5e38\u4e3a [batch_size, seq_len, d_k] \u6216 [seq_len, d_k]\u3002<\/li>\n<\/ul>\n<p>\u8ba1\u7b97\u6b65\u9aa4\u5982\u4e0b\uff1a<\/p>\n<p>\u2460 \u8ba1\u7b97\u6ce8\u610f\u529b\u5206\u6570\u77e9\u9635<\/p>\n<p><img data-dominant-color=\"f3f3f3\" data-has-transparency=\"false\" style=\"--dominant-color: #f3f3f3;\" loading=\"lazy\" decoding=\"async\" class=\"not-transparent alignnone size-full wp-image-25922\" src=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/f94336619d330f21c25832f92faa38396bf201.png\" width=\"155\" height=\"60\" srcset=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/f94336619d330f21c25832f92faa38396bf201.png 155w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/f94336619d330f21c25832f92faa38396bf201-150x58.png 150w\" sizes=\"auto, (max-width: 155px) 100vw, 155px\" \/><\/p>\n<ul data-id=\"u738a58b-kNOQfM1C\">\n<li data-id=\"ld70c578-lkgKFudH\">\u5176\u4e2d\u00a0<img decoding=\"async\" src=\"https:\/\/s8.51cto.com\/oss\/202505\/19\/461853714055bdd9bd9420a79629c4ab13b017.png\" data-type=\"inline\" \/>\u662f Key \u5f20\u91cf\u7684\u8f6c\u7f6e\uff1b<\/li>\n<li data-id=\"ld70c578-xWbLHdJt\">\u70b9\u79ef\u7ed3\u679c\u662f\u4e00\u4e2a [seq_len, seq_len] \u7684\u6ce8\u610f\u529b\u5f97\u5206\u77e9\u9635\uff1b<\/li>\n<li data-id=\"ld70c578-sOJTvnmO\">\u4f7f\u7528 softmax 
\u5f52\u4e00\u5316\uff0c\u5f97\u5230\u6ce8\u610f\u529b\u6743\u91cd\u3002<\/li>\n<\/ul>\n<p>\u2461 \u5e94\u7528\u63a9\u7801\uff08\u53ef\u9009\uff09<\/p>\n<ul data-id=\"u738a58b-fE9TkTQO\">\n<li data-id=\"ld70c578-3xqPeLHw\">\u5728 Decoder \u4e2d\u4f7f\u7528 Masked Self-Attention\uff0c\u9632\u6b62\u672a\u6765\u4fe1\u606f\u6cc4\u9732\uff1b<\/li>\n<li data-id=\"ld70c578-sbFmar6r\">\u82e5\u4f20\u5165 mask\uff0c\u5c06\u5bf9\u5e94\u4f4d\u7f6e\u8bbe\u4e3a\u6781\u5c0f\u503c\uff08\u5982 -1e9\uff09\u4ee5\u6291\u5236\u5176\u5f71\u54cd\u3002<\/li>\n<\/ul>\n<p>\u2462 \u52a0\u6743\u805a\u5408 Value \u5411\u91cf<\/p>\n<ul data-id=\"u738a58b-zMwOJBtq\">\n<li data-id=\"ld70c578-jRtxm2iG\">\u5c06 softmax \u540e\u7684\u6ce8\u610f\u529b\u6743\u91cd\u4e0e Value \u76f8\u4e58\uff0c\u5f97\u5230\u4e0a\u4e0b\u6587\u611f\u77e5\u7684\u8f93\u51fa\u5f20\u91cf\uff1b<\/li>\n<li data-id=\"ld70c578-MEXZO6Sd\">\u8f93\u51fa\u7ef4\u5ea6\u4fdd\u6301\u4e0e\u8f93\u5165\u4e00\u81f4\uff1a[batch_size, seq_len, d_v]\u3002<\/li>\n<\/ul>\n<p>\ud83e\uddf1 \u4ee3\u7801\u5b9e\u73b0\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_4\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">import<\/span> torch\r\n<span class=\"token keyword\">import<\/span> math\r\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn\r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">ScaleDotProductAttention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token 
punctuation\">(<\/span>self<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        \u521d\u59cb\u5316 Self<span class=\"token operator\">-<\/span>Attention \u5c42\uff0c\u4ec5\u5305\u542b\u4e00\u4e2a softmax \u64cd\u4f5c\u3002\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>ScaleDotProductAttention<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>softmax <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Softmax<\/span><span class=\"token punctuation\">(<\/span>dim<span class=\"token operator\">=<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> <span class=\"token constant\">Q<\/span><span class=\"token operator\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> <span class=\"token constant\">K<\/span><span class=\"token operator\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> <span class=\"token constant\">V<\/span><span class=\"token operator\">:<\/span> torch<span class=\"token punctuation\">.<\/span>Tensor<span class=\"token punctuation\">,<\/span> <span class=\"token literal-property property\">mask<\/span><span class=\"token operator\">:<\/span> torch<span class=\"token 
punctuation\">.<\/span>Tensor <span class=\"token operator\">=<\/span> None<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        Self<span class=\"token operator\">-<\/span>Attention \u524d\u5411\u4f20\u64ad\u51fd\u6570\r\n\r\n        <span class=\"token operator\">:<\/span>param <span class=\"token constant\">Q<\/span><span class=\"token operator\">:<\/span> Query \u5411\u91cf\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_k<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token operator\">:<\/span>param <span class=\"token constant\">K<\/span><span class=\"token operator\">:<\/span> Key \u5411\u91cf\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_k<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token operator\">:<\/span>param <span class=\"token constant\">V<\/span><span class=\"token operator\">:<\/span> Value \u5411\u91cf\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_v<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token operator\">:<\/span>param mask<span class=\"token operator\">:<\/span> \u63a9\u7801\u5f20\u91cf\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token operator\">:<\/span><span class=\"token keyword\">return<\/span><span class=\"token operator\">:<\/span> \r\n            <span class=\"token literal-property 
property\">output<\/span><span class=\"token operator\">:<\/span> \u52a0\u6743\u540e\u7684 Value \u5411\u91cf\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_v<span class=\"token punctuation\">]<\/span>\r\n            <span class=\"token literal-property property\">attn_weights<\/span><span class=\"token operator\">:<\/span> \u6ce8\u610f\u529b\u6743\u91cd\u77e9\u9635\uff0c\u5f62\u72b6\u4e3a <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        # <span class=\"token number\">1.<\/span> \u8ba1\u7b97 <span class=\"token constant\">QK<\/span><span class=\"token operator\">^<\/span><span class=\"token constant\">T<\/span> \u5f97\u5230\u6ce8\u610f\u529b\u5206\u6570\r\n        <span class=\"token constant\">K_T<\/span> <span class=\"token operator\">=<\/span> <span class=\"token constant\">K<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">transpose<\/span><span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> d_k<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n        scores <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">matmul<\/span><span class=\"token punctuation\">(<\/span><span class=\"token constant\">Q<\/span><span class=\"token punctuation\">,<\/span> 
<span class=\"token constant\">K_T<\/span><span class=\"token punctuation\">)<\/span> <span class=\"token operator\">\/<\/span> math<span class=\"token punctuation\">.<\/span><span class=\"token function\">sqrt<\/span><span class=\"token punctuation\">(<\/span><span class=\"token constant\">Q<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n\r\n        # <span class=\"token number\">2.<\/span> \u5982\u679c\u6709 mask\uff0c\u5e94\u7528\u63a9\u7801\uff08\u4f8b\u5982 Decoder \u4e2d\u9632\u6b62\u770b\u5230\u672a\u6765\u8bcd\uff09\r\n        <span class=\"token keyword\">if<\/span> mask is not None<span class=\"token operator\">:<\/span>\r\n            scores <span class=\"token operator\">=<\/span> scores<span class=\"token punctuation\">.<\/span><span class=\"token function\">masked_fill<\/span><span class=\"token punctuation\">(<\/span>mask <span class=\"token operator\">==<\/span> <span class=\"token number\">0<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token operator\">-<\/span><span class=\"token number\">1e9<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # <span class=\"token number\">3.<\/span> \u5e94\u7528 softmax \u5f97\u5230\u6ce8\u610f\u529b\u6743\u91cd\r\n        attn_weights <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">softmax<\/span><span class=\"token punctuation\">(<\/span>scores<span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span 
class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n\r\n        # <span class=\"token number\">4.<\/span> \u6743\u91cd \u00d7 Value \u5f97\u5230\u6700\u7ec8\u8f93\u51fa\r\n        output <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">matmul<\/span><span class=\"token punctuation\">(<\/span>attn_weights<span class=\"token punctuation\">,<\/span> <span class=\"token constant\">V<\/span><span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_v<span class=\"token punctuation\">]<\/span>\r\n\r\n        <span class=\"token keyword\">return<\/span> output<span class=\"token punctuation\">,<\/span> attn_weights<\/code><\/pre>\n<ul id=\"code_id_4\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<li>35.<\/li>\n<li>36.<\/li>\n<li>37.<\/li>\n<li>38.<\/li>\n<li>39.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\ud83d\udcd0 \u793a\u4f8b\u8c03\u7528\u4e0e\u8f93\u51fa\u89e3\u6790\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_5\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code 
class=\"language-javascript\"># \u521b\u5efa <span class=\"token constant\">Q<\/span>\u3001<span class=\"token constant\">K<\/span>\u3001<span class=\"token constant\">V<\/span> \u5f20\u91cf\r\n<span class=\"token constant\">Q<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">64<\/span><span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token operator\">=<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> seq_len<span class=\"token operator\">=<\/span><span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> d_k<span class=\"token operator\">=<\/span><span class=\"token number\">64<\/span><span class=\"token punctuation\">]<\/span>\r\n<span class=\"token constant\">K<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">64<\/span><span class=\"token punctuation\">)<\/span>\r\n<span class=\"token constant\">V<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">64<\/span><span 
class=\"token punctuation\">)<\/span>\r\n\r\n# \u521b\u5efa Self<span class=\"token operator\">-<\/span>Attention \u5c42\r\nattention <span class=\"token operator\">=<\/span> <span class=\"token function\">ScaleDotProductAttention<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# \u524d\u5411\u4f20\u64ad\r\noutput<span class=\"token punctuation\">,<\/span> attn_weights <span class=\"token operator\">=<\/span> <span class=\"token function\">attention<\/span><span class=\"token punctuation\">(<\/span><span class=\"token constant\">Q<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token constant\">K<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token constant\">V<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# \u6253\u5370\u8f93\u51fa\u5f62\u72b6\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span>f<span class=\"token string\">\"ScaleDotProductAttention output shape: {output.shape}\"<\/span><span class=\"token punctuation\">)<\/span>      # <span class=\"token punctuation\">[<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">64<\/span><span class=\"token punctuation\">]<\/span>\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span>f<span class=\"token string\">\"attn_weights shape: {attn_weights.shape}\"<\/span><span class=\"token punctuation\">)<\/span>                # <span class=\"token punctuation\">[<\/span><span class=\"token number\">5<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">]<\/span><\/code><\/pre>\n<ul id=\"code_id_5\" 
class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-njcPXETe\">\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-qFhIhdNA\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-DDjdBPTK\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-uGFPL176\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-Io5JanpR\">\n<tr data-id=\"t31e458f-1khFVAaH\">\n<td data-id=\"t6267798-frt4jmUe\" data-transient-attributes=\"table-cell-selection\">\u53d8\u91cf<\/td>\n<td data-id=\"t6267798-XJELV6BD\" data-transient-attributes=\"table-cell-selection\">\u5f62\u72b6<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-hU5jOTVH\" data-transient-attributes=\"table-cell-selection\">\u63cf\u8ff0<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-GTcL9f96\">\n<td data-id=\"t6267798-n4JLL1bj\" data-transient-attributes=\"table-cell-selection\">Q, K, V<\/td>\n<td data-id=\"t6267798-gTcFNLsT\" data-transient-attributes=\"table-cell-selection\">[5, 10, 64]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-AxkCB105\" data-transient-attributes=\"table-cell-selection\">batch=5\uff0c\u5e8f\u5217\u957f\u5ea6=10\uff0c\u5d4c\u5165\u7ef4\u5ea6=64<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-2P0ztg6K\">\n<td data-id=\"t6267798-6o3WC1sJ\" data-transient-attributes=\"table-cell-selection\">scores<\/td>\n<td data-id=\"t6267798-xRdGLjCj\" data-transient-attributes=\"table-cell-selection\">[5, 10, 10]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-h4NEOwwX\" data-transient-attributes=\"table-cell-selection\">\u6ce8\u610f\u529b\u5f97\u5206\u77e9\u9635\uff0c\u53cd\u6620 token \u4e4b\u95f4\u7684\u76f8\u4f3c\u5ea6<\/td>\n<\/tr>\n<tr 
data-id=\"t31e458f-JVM1UNm8\">\n<td data-id=\"t6267798-W59n9CRt\" data-transient-attributes=\"table-cell-selection\">attn_weights<\/td>\n<td data-id=\"t6267798-YEcMcti7\" data-transient-attributes=\"table-cell-selection\">[5, 10, 10]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-31D9vdE5\" data-transient-attributes=\"table-cell-selection\">softmax \u540e\u7684\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u7528\u4e8e\u52a0\u6743\u805a\u5408 Value<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-lRPTPAKQ\">\n<td class=\"table-last-column\" data-id=\"t6267798-0uw2lAVH\" data-transient-attributes=\"table-cell-selection\">output<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-OOVRYVj9\" data-transient-attributes=\"table-cell-selection\">[5, 10, 64]<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-6Pqim6Ch\" data-transient-attributes=\"table-cell-selection\">\u6700\u7ec8\u8f93\u51fa\uff0c\u878d\u5408\u4e86\u4e0a\u4e0b\u6587\u4fe1\u606f\u7684 Value \u52a0\u6743\u8868\u793a<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p><strong>(3) Multi-Head Attention<\/strong><\/p>\n<p>Multi-Head Attention\uff08MHA\uff09\u662f\u5728 Self-Attention \u57fa\u7840\u4e0a\u5f15\u5165\u7684\u4e00\u79cd\u589e\u5f3a\u673a\u5236\u3002\u5176\u6838\u5fc3\u7406\u5ff5\u662f\uff1a\u5c06\u8f93\u5165\u5411\u91cf\u7a7a\u95f4\u5212\u5206\u4e3a\u591a\u4e2a\u5b50\u7a7a\u95f4\uff0c\u5728\u6bcf\u4e2a\u5b50\u7a7a\u95f4\u4e2d\u72ec\u7acb\u8ba1\u7b97 Self-Attention\uff0c\u6700\u540e\u5c06\u591a\u4e2a\u5b50\u7a7a\u95f4\u7684\u8f93\u51fa\u62fc\u63a5\u5728\u4e00\u8d77\u5e76\u8fdb\u884c\u7ebf\u6027\u53d8\u6362\uff0c\u4ece\u800c\u5f97\u5230\u6700\u7ec8\u7684\u8f93\u51fa\u3002<\/p>\n<p>\u5bf9\u4e8e MHA\uff0c\u4e4b\u6240\u4ee5\u9700\u8981\u5bf9 Q\u3001K\u3001V \u8fdb\u884c\u591a\u5934\uff08head\uff09\u5212\u5206\uff0c\u662f\u4e3a\u4e86\u589e\u5f3a\u6a21\u578b\u5bf9\u4e0d\u540c\u4fe1\u606f\u7684\u5173\u6ce8\u3002\u5177\u4f53\u6765\u8bf4\uff0c\u591a\u7ec4 Q\u3001K\u3001V \u5206\u522b\u8ba1\u7b97 
Self-Attention\uff0c\u6bcf\u4e2a\u5934\u81ea\u7136\u5c31\u4f1a\u6709\u72ec\u7acb\u7684 Q\u3001K\u3001V \u53c2\u6570\uff0c\u4ece\u800c\u8ba9\u6a21\u578b\u540c\u65f6\u5173\u6ce8\u591a\u4e2a\u4e0d\u540c\u7684\u4fe1\u606f\uff0c\u8fd9\u6709\u4e9b\u7c7b\u4f3c CNN \u67b6\u6784\u6a21\u578b\u7684\u591a\u901a\u9053\u673a\u5236\u3002<\/p>\n<p>\u4e0b\u56fe\u662f\u8bba\u6587\u4e2d Multi-Head Attention \u7684\u7ed3\u6784\u56fe\u3002<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s5.51cto.com\/oss\/202505\/19\/d87988587a7d8bf2ec56760523b193df08acb9.webp\" data-type=\"block\" \/><\/p>\n<p>\u4ece\u56fe\u4e2d\u53ef\u4ee5\u770b\u51fa\uff0cMHA \u7ed3\u6784\u7684\u8ba1\u7b97\u8fc7\u7a0b\u53ef\u603b\u7ed3\u4e3a\u4e0b\u8ff0\u6b65\u9aa4\uff1a<\/p>\n<ul data-id=\"u738a58b-NDI3rb7U\">\n<li data-id=\"ld70c578-E2RH7KJq\">\u5c06\u8f93\u5165 Q\u3001K\u3001V \u5f20\u91cf\u8fdb\u884c\u7ebf\u6027\u53d8\u6362\uff08Linear \u5c42\uff09\uff0c\u8f93\u51fa\u5f20\u91cf\u5c3a\u5bf8\u4e3a [batch_size, seq_len, d_model]\uff1b<\/li>\n<li data-id=\"ld70c578-C4v7AIbj\">\u5c06\u524d\u9762\u6b65\u9aa4\u8f93\u51fa\u7684\u5f20\u91cf\uff0c\u6309\u7167\u5934\u7684\u6570\u91cf\uff08n_head\uff09\u62c6\u5206\u4e3a n_head \u4e2a\u5b50\u5f20\u91cf\uff0c\u5176\u5c3a\u5bf8\u4e3a [batch_size, n_head, seq_len, d_model\/\/n_head]\uff1b<\/li>\n<li data-id=\"ld70c578-v9T62tX4\">\u6bcf\u4e2a\u5b50\u5f20\u91cf\u5e76\u884c\u8ba1\u7b97\u6ce8\u610f\u529b\u5206\u6570\uff0c\u5373\u6267\u884c dot-product attention \u5c42\uff0c\u8f93\u51fa\u5f20\u91cf\u5c3a\u5bf8\u4e3a [batch_size, n_head, seq_len, d_model\/\/n_head]\u3002\u82e5\u4f20\u5165 mask\uff0c\u5176\u5f62\u72b6\u9700\u80fd\u5e7f\u64ad\u5230\u6ce8\u610f\u529b\u5206\u6570\u7684\u5f62\u72b6 [batch_size, n_head, seq_len, seq_len]\uff08\u4f8b\u5982 [batch_size, 1, seq_len, seq_len]\uff0c\u53ef\u5bf9\u4e09\u7ef4 mask \u8c03\u7528 mask.unsqueeze(1) \u5f97\u5230\uff09\uff1b<\/li>\n<li data-id=\"ld70c578-3O9Jig0U\">\u5c06\u8fd9\u4e9b\u5b50\u5f20\u91cf\u8fdb\u884c\u62fc\u63a5\uff08concat\uff09\uff0c\u5e76\u7ecf\u8fc7\u7ebf\u6027\u53d8\u6362\u5f97\u5230\u6700\u7ec8\u7684\u8f93\u51fa\u5f20\u91cf\uff0c\u5c3a\u5bf8\u4e3a [batch_size, seq_len, d_model]\u3002<\/li>\n<\/ul>\n<p>\ud83d\udcd0 \u6570\u5b66\u8868\u8fbe\u5f0f<\/p>\n<p><img decoding=\"async\" 
src=\"https:\/\/s3.51cto.com\/oss\/202505\/19\/225b9e559fc4e80636007996fbd848ba9c3f37.png\" data-type=\"block\" \/><\/p>\n<p>\u5176\u4e2d\uff1a<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s6.51cto.com\/oss\/202505\/19\/28c02b645ec9d132555814e4ee5eebf7a400bc.png\" data-type=\"block\" \/><\/p>\n<ul data-id=\"u738a58b-9JHFguSJ\">\n<li data-id=\"ld70c578-S5N2SFl7\"><img data-dominant-color=\"ebebeb\" data-has-transparency=\"false\" style=\"--dominant-color: #ebebeb;\" loading=\"lazy\" decoding=\"async\" class=\"not-transparent alignnone size-full wp-image-25923\" src=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/886b88324721c1e040b964eb1ed3bc8175ef7d.png\" width=\"199\" height=\"33\" srcset=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/886b88324721c1e040b964eb1ed3bc8175ef7d.png 199w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/886b88324721c1e040b964eb1ed3bc8175ef7d-150x25.png 150w\" sizes=\"auto, (max-width: 199px) 100vw, 199px\" \/>\uff1a\u7b2c i \u4e2a head \u7684\u53ef\u5b66\u4e60\u53c2\u6570\uff1b<\/li>\n<li data-id=\"ld70c578-fphlYrGC\"><img decoding=\"async\" src=\"https:\/\/s5.51cto.com\/oss\/202505\/19\/b79eb5775e714af3e6f841e4230d5402a06faf.png\" data-type=\"inline\" \/>\uff1a\u6700\u7ec8\u8f93\u51fa\u7684\u7ebf\u6027\u53d8\u6362\u77e9\u9635\uff1b<\/li>\n<li data-id=\"ld70c578-e5CYP0pK\">Concat\u8868\u793a\u5c06\u5404\u4e2a head \u7684\u8f93\u51fa\u62fc\u63a5\u5728\u4e00\u8d77\u3002<\/li>\n<\/ul>\n<p><strong>(4) Multi-Head Attention \u5b9e\u73b0<\/strong><\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_6\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">import<\/span> torch\r\n<span class=\"token 
keyword\">import<\/span> math\r\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn\r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    <span class=\"token string\">\"\"<\/span><span class=\"token string\">\"Multi-Head Attention Layer\"<\/span><span class=\"token string\">\"\"<\/span>\r\n    \r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token literal-property property\">d_model<\/span><span class=\"token operator\">:<\/span> \u6a21\u578b\u5d4c\u5165\u7ef4\u5ea6\uff08\u901a\u5e38\u4e3a <span class=\"token number\">512<\/span> \u6216 <span class=\"token number\">768<\/span>\uff09\uff1b\r\n            <span class=\"token literal-property property\">n_head<\/span><span class=\"token operator\">:<\/span> \u6ce8\u610f\u529b\u5934\u7684\u6570\u91cf\uff08\u5982 <span class=\"token number\">8<\/span>\uff09\uff1b\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>MultiHeadAttention<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span 
class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u521d\u59cb\u5316\u53c2\u6570\r\n        self<span class=\"token punctuation\">.<\/span>n_head <span class=\"token operator\">=<\/span> n_head\r\n        self<span class=\"token punctuation\">.<\/span>attention <span class=\"token operator\">=<\/span> <span class=\"token function\">ScaleDotProductAttention<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>  # \u4f7f\u7528\u524d\u9762\u5b9a\u4e49\u7684 Self<span class=\"token operator\">-<\/span>Attention\r\n        \r\n        # \u7ebf\u6027\u53d8\u6362\u5c42\r\n        self<span class=\"token punctuation\">.<\/span>w_q <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>       # Query \u53d8\u6362\r\n        self<span class=\"token punctuation\">.<\/span>w_k <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>       # Key \u53d8\u6362\r\n        self<span class=\"token punctuation\">.<\/span>w_v <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>       # Value \u53d8\u6362\r\n        self<span class=\"token punctuation\">.<\/span>fc <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token 
punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>         # \u8f93\u51fa\u6295\u5f71\u5c42\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>None<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token literal-property property\">q<\/span><span class=\"token operator\">:<\/span> Query \u5f20\u91cf\uff0c<span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n            <span class=\"token literal-property property\">k<\/span><span class=\"token operator\">:<\/span> Key \u5f20\u91cf\uff0c<span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n            <span class=\"token literal-property property\">v<\/span><span class=\"token operator\">:<\/span> Value \u5f20\u91cf\uff0c<span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n            <span class=\"token literal-property property\">mask<\/span><span class=\"token operator\">:<\/span> \u63a9\u7801\u5f20\u91cf\uff0c<span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token 
punctuation\">,<\/span> seq_len<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        # Step <span class=\"token number\">1<\/span><span class=\"token operator\">:<\/span> \u7ebf\u6027\u53d8\u6362\r\n        q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">w_q<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">w_k<\/span><span class=\"token punctuation\">(<\/span>k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">w_v<\/span><span class=\"token punctuation\">(<\/span>v<span class=\"token punctuation\">)<\/span>\r\n\r\n        # Step <span class=\"token number\">2<\/span><span class=\"token operator\">:<\/span> \u62c6\u5206\u5230\u591a\u4e2a head\r\n        q <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">split<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">)<\/span>   # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n        k <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">split<\/span><span class=\"token punctuation\">(<\/span>k<span class=\"token punctuation\">)<\/span>   # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token 
punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n        v <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">split<\/span><span class=\"token punctuation\">(<\/span>v<span class=\"token punctuation\">)<\/span>   # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n\r\n        # Step <span class=\"token number\">3<\/span><span class=\"token operator\">:<\/span> \u8ba1\u7b97\u6bcf\u4e2a head \u7684 attention\r\n        sa_output<span class=\"token punctuation\">,<\/span> attn_weights <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">attention<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> v<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span>\r\n\r\n        # Step <span class=\"token number\">4<\/span><span class=\"token operator\">:<\/span> \u62fc\u63a5\u6240\u6709 head \u7684\u8f93\u51fa\r\n        mha_output <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">concat<\/span><span class=\"token punctuation\">(<\/span>sa_output<span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n\r\n        # Step <span class=\"token number\">5<\/span><span class=\"token operator\">:<\/span> \u6700\u7ec8\u7ebf\u6027\u53d8\u6362\r\n        mha_output <span class=\"token 
operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">fc<\/span><span class=\"token punctuation\">(<\/span>mha_output<span class=\"token punctuation\">)<\/span>\r\n\r\n        <span class=\"token keyword\">return<\/span> mha_output<span class=\"token punctuation\">,<\/span> attn_weights\r\n\r\n    def <span class=\"token function\">split<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> tensor<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        \u62c6\u5206\u8f93\u5165\u5f20\u91cf\u4e3a\u591a\u4e2a head\r\n\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token literal-property property\">tensor<\/span><span class=\"token operator\">:<\/span> <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token literal-property property\">Returns<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model <span class=\"token operator\">=<\/span> tensor<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        d_tensor <span class=\"token operator\">=<\/span> d_model <span class=\"token 
comment\">\/\/ self.n_head  # \u6bcf\u4e2a head \u7684\u7ef4\u5ea6<\/span>\r\n        \r\n        # reshape <span class=\"token operator\">+<\/span> transpose \u5b9e\u73b0\u62c6\u5206\r\n        tensor <span class=\"token operator\">=<\/span> tensor<span class=\"token punctuation\">.<\/span><span class=\"token function\">view<\/span><span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>n_head<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">)<\/span>\r\n        tensor <span class=\"token operator\">=<\/span> tensor<span class=\"token punctuation\">.<\/span><span class=\"token function\">transpose<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span>  # <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n        \r\n        <span class=\"token keyword\">return<\/span> tensor\r\n\r\n    def <span class=\"token function\">concat<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> sa_output<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        \u62fc\u63a5\u591a\u4e2a head \u7684\u8f93\u51fa\r\n\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token literal-property property\">sa_output<\/span><span class=\"token operator\">:<\/span> <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token 
punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token literal-property property\">Returns<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">]<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        batch_size<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_tensor <span class=\"token operator\">=<\/span> sa_output<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        d_model <span class=\"token operator\">=<\/span> n_head <span class=\"token operator\">*<\/span> d_tensor\r\n        \r\n        # transpose <span class=\"token operator\">+<\/span> reshape \u5b9e\u73b0\u5408\u5e76\r\n        sa_output <span class=\"token operator\">=<\/span> sa_output<span class=\"token punctuation\">.<\/span><span class=\"token function\">transpose<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">contiguous<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">view<\/span><span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token 
punctuation\">)<\/span>\r\n        \r\n        <span class=\"token keyword\">return<\/span> sa_output<\/code><\/pre>\n<ul id=\"code_id_6\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<li>35.<\/li>\n<li>36.<\/li>\n<li>37.<\/li>\n<li>38.<\/li>\n<li>39.<\/li>\n<li>40.<\/li>\n<li>41.<\/li>\n<li>42.<\/li>\n<li>43.<\/li>\n<li>44.<\/li>\n<li>45.<\/li>\n<li>46.<\/li>\n<li>47.<\/li>\n<li>48.<\/li>\n<li>49.<\/li>\n<li>50.<\/li>\n<li>51.<\/li>\n<li>52.<\/li>\n<li>53.<\/li>\n<li>54.<\/li>\n<li>55.<\/li>\n<li>56.<\/li>\n<li>57.<\/li>\n<li>58.<\/li>\n<li>59.<\/li>\n<li>60.<\/li>\n<li>61.<\/li>\n<li>62.<\/li>\n<li>63.<\/li>\n<li>64.<\/li>\n<li>65.<\/li>\n<li>66.<\/li>\n<li>67.<\/li>\n<li>68.<\/li>\n<li>69.<\/li>\n<li>70.<\/li>\n<li>71.<\/li>\n<li>72.<\/li>\n<li>73.<\/li>\n<li>74.<\/li>\n<li>75.<\/li>\n<li>76.<\/li>\n<li>77.<\/li>\n<li>78.<\/li>\n<li>79.<\/li>\n<li>80.<\/li>\n<li>81.<\/li>\n<li>82.<\/li>\n<li>83.<\/li>\n<li>84.<\/li>\n<li>85.<\/li>\n<li>86.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\ud83d\udcca \u793a\u4f8b\u8c03\u7528\u4e0e\u8f93\u51fa\u89e3\u6790\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_7\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"># \u5b9a\u4e49\u53c2\u6570\r\nd_model <span class=\"token operator\">=<\/span> <span class=\"token 
number\">512<\/span>\r\nn_head <span class=\"token operator\">=<\/span> <span class=\"token number\">8<\/span>\r\nseq_len <span class=\"token operator\">=<\/span> <span class=\"token number\">10<\/span>\r\nbatch_size <span class=\"token operator\">=<\/span> <span class=\"token number\">32<\/span>\r\n\r\n# \u521b\u5efa <span class=\"token constant\">Q<\/span>\u3001<span class=\"token constant\">K<\/span>\u3001<span class=\"token constant\">V<\/span> \u5f20\u91cf\r\n<span class=\"token constant\">Q<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>\r\n<span class=\"token constant\">K<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>\r\n<span class=\"token constant\">V<\/span> <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">randn<\/span><span class=\"token punctuation\">(<\/span>batch_size<span class=\"token punctuation\">,<\/span> seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span>\r\n\r\n# \u6784\u5efa <span class=\"token constant\">MHA<\/span> \u5c42\r\nmha_layer <span class=\"token operator\">=<\/span> <span class=\"token function\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token operator\">=<\/span>d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token operator\">=<\/span>n_head<span class=\"token 
punctuation\">)<\/span>\r\n\r\n# \u524d\u5411\u4f20\u64ad\r\noutput<span class=\"token punctuation\">,<\/span> weights <span class=\"token operator\">=<\/span> <span class=\"token function\">mha_layer<\/span><span class=\"token punctuation\">(<\/span><span class=\"token constant\">Q<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token constant\">K<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token constant\">V<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n# \u6253\u5370\u8f93\u51fa\u5f62\u72b6\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"MHA Output Shape:\"<\/span><span class=\"token punctuation\">,<\/span> output<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">)<\/span>      # <span class=\"token punctuation\">[<\/span><span class=\"token number\">32<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">512<\/span><span class=\"token punctuation\">]<\/span>\r\n<span class=\"token function\">print<\/span><span class=\"token punctuation\">(<\/span><span class=\"token string\">\"Attn Weights Shape:\"<\/span><span class=\"token punctuation\">,<\/span> weights<span class=\"token punctuation\">.<\/span>shape<span class=\"token punctuation\">)<\/span>   # <span class=\"token punctuation\">[<\/span><span class=\"token number\">32<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">8<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">10<\/span><span class=\"token punctuation\">]<\/span><\/code><\/pre>\n<ul id=\"code_id_7\" 
class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-a2YOmIYk\">\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-QSOu7fny\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-pU2GErzo\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-CivkFCMg\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-8p5T9C0L\">\n<tr data-id=\"t31e458f-BxyKxxU0\">\n<td data-id=\"t6267798-kDy6Tgdn\" data-transient-attributes=\"table-cell-selection\">\u53d8\u91cf<\/td>\n<td data-id=\"t6267798-nG3Uvn51\" data-transient-attributes=\"table-cell-selection\">\u5f62\u72b6<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-E7h1XcyH\" data-transient-attributes=\"table-cell-selection\">\u63cf\u8ff0<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-9IFAo1cR\">\n<td data-id=\"t6267798-UT4NjTz6\" data-transient-attributes=\"table-cell-selection\">Q, K, V<\/td>\n<td data-id=\"t6267798-Cnhqwuj2\" data-transient-attributes=\"table-cell-selection\">[32, 10, 512]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-YOLCjbFC\" data-transient-attributes=\"table-cell-selection\">\u8f93\u5165\u5f20\u91cf\uff0c\u8868\u793a batch=32\uff0cseq_len=10\uff0cd_model=512<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-sw7Gllij\">\n<td data-id=\"t6267798-Db456eko\" data-transient-attributes=\"table-cell-selection\">q, k, v<\/td>\n<td data-id=\"t6267798-qVKTfk9t\" data-transient-attributes=\"table-cell-selection\">[32, 8, 10, 64]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-Il2xQFBd\" data-transient-attributes=\"table-cell-selection\">\u62c6\u5206\u540e\u7684 
Q\/K\/V\uff0c\u6bcf\u4e2a head 64 \u7ef4<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-sNs2nbUD\">\n<td data-id=\"t6267798-qkvmkShq\" data-transient-attributes=\"table-cell-selection\">sa_output<\/td>\n<td data-id=\"t6267798-MxQL9I8W\" data-transient-attributes=\"table-cell-selection\">[32, 8, 10, 64]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-Usixj3gp\" data-transient-attributes=\"table-cell-selection\">\u6bcf\u4e2a head \u7684 attention \u8f93\u51fa<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-3qEVhU7X\">\n<td data-id=\"t6267798-rxLVtV1i\" data-transient-attributes=\"table-cell-selection\">mha_output<\/td>\n<td data-id=\"t6267798-vHYuCwLE\" data-transient-attributes=\"table-cell-selection\">[32, 10, 512]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-vO120dzp\" data-transient-attributes=\"table-cell-selection\">\u62fc\u63a5\u540e\u7684\u6700\u7ec8\u8f93\u51fa<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-WUc9g2DM\">\n<td class=\"table-last-column\" data-id=\"t6267798-TBrAM3ID\" data-transient-attributes=\"table-cell-selection\">attn_weights<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-ymdktCzZ\" data-transient-attributes=\"table-cell-selection\">[32, 8, 10, 10]<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-D7bzuN6s\" data-transient-attributes=\"table-cell-selection\">\u6bcf\u4e2a head \u7684\u6ce8\u610f\u529b\u6743\u91cd\u77e9\u9635<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<h4>3. 
Encoder\u7ed3\u6784<\/h4>\n<p>Transformer \u4e2d\u7684 Encoder \u662f\u6574\u4e2a\u6a21\u578b\u4e2d\u7528\u4e8e\u7f16\u7801\u8f93\u5165\u5e8f\u5217\u7684\u90e8\u5206\u3002\u5b83\u7531 N=6 \u4e2a\u76f8\u540c\u7684 encoder block \u5806\u53e0\u800c\u6210\u3002\u6bcf\u4e2a encoder block \u4e3b\u8981\u5305\u542b\u4e24\u4e2a\u5b50\u5c42\uff08sub-layers\uff09\uff1a\u591a\u5934\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff08Multi-Head Self-Attention\uff09\u548c\u4f4d\u7f6e\u5168\u8fde\u63a5\u524d\u9988\u7f51\u7edc\uff08Position-wise Feed Forward Network\uff09\u3002<\/p>\n<p>\u8fd9\u4e24\u4e2a\u5b50\u5c42\u7684\u8f93\u51fa\u90fd\u7ecf\u8fc7\u4e86 \u6b8b\u5dee\u8fde\u63a5\uff08Residual Connection\uff09 \u548c \u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\uff0c\u4ee5\u589e\u5f3a\u8bad\u7ec3\u7a33\u5b9a\u6027\u548c\u6a21\u578b\u8868\u8fbe\u80fd\u529b\u3002<\/p>\n<p>\u4e0b\u56fe\u4e2d\u7ea2\u8272\u6846\u9009\u90e8\u5206\u8868\u793a\u4e00\u4e2a\u6807\u51c6\u7684 Encoder Block\uff0c\u5176\u5185\u90e8\u7ed3\u6784\u5982\u4e0b\uff1a<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s7.51cto.com\/oss\/202505\/19\/351b7f5081951821bb6799ebebf44ead526b56.webp\" data-type=\"block\" \/><\/p>\n<p>\u7531\u4ee5\u4e0b\u56db\u4e2a\u5173\u952e\u90e8\u5206\u6784\u6210\u3002<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-CTJFaTdE\">\n<col span=\"1\" width=\"327.986\" data-id=\"cd89ecb0-Bu0JVnVP\" \/>\n<col span=\"1\" width=\"328.003\" data-id=\"cd89ecb0-yqT0FS8n\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-9J17sRlD\">\n<tr data-id=\"t31e458f-Nu1GLzLw\">\n<td data-id=\"t6267798-BAlTTopZ\" data-transient-attributes=\"table-cell-selection\">\u6a21\u5757<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-fTYaxXtB\" data-transient-attributes=\"table-cell-selection\">\u63cf\u8ff0<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-N8TpP8il\">\n<td data-id=\"t6267798-dWrPSHec\" 
data-transient-attributes=\"table-cell-selection\">Multi-Head Attention<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-3kBnWLi1\" data-transient-attributes=\"table-cell-selection\">\u4f7f\u7528\u591a\u4e2a attention head \u5e76\u884c\u63d0\u53d6\u5e8f\u5217\u4e2d\u7684\u4e0d\u540c\u7279\u5f81<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-pgrcRFdy\">\n<td data-id=\"t6267798-vVyDTja1\" data-transient-attributes=\"table-cell-selection\">Add &amp; Norm (1)<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-E3cCDgYS\" data-transient-attributes=\"table-cell-selection\">\u6b8b\u5dee\u8fde\u63a5\uff08Residual Connection\uff09+ \u5c42\u5f52\u4e00\u5316\uff08LayerNorm\uff09<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-1hoEvJZH\">\n<td data-id=\"t6267798-BUJ14kxu\" data-transient-attributes=\"table-cell-selection\">Position-wise FeedForward<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-egRhZZaN\" data-transient-attributes=\"table-cell-selection\">\u4e24\u5c42\u7ebf\u6027\u53d8\u6362 + \u6fc0\u6d3b\u51fd\u6570\uff0c\u5bf9\u6bcf\u4e2a\u8bcd\u72ec\u7acb\u5efa\u6a21<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-QNxf3kty\">\n<td class=\"table-last-column\" data-id=\"t6267798-LJf4PFoQ\" data-transient-attributes=\"table-cell-selection\">Add &amp; Norm (2)<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-0YgFOcBk\" data-transient-attributes=\"table-cell-selection\">\u540c\u6837\u5e94\u7528\u6b8b\u5dee\u8fde\u63a5\u548c LayerNorm<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p><strong>(1) \u6bcf\u4e00\u5c42\u7684\u8ba1\u7b97\u6d41\u7a0b\uff08\u4ee5\u5355\u4e2a encoder block \u4e3a\u4f8b\uff09<\/strong><\/p>\n<p>\u2460 \u591a\u5934\u81ea\u6ce8\u610f\u529b\u673a\u5236\uff08Multi-Head Self-Attention\uff09<\/p>\n<ul data-id=\"u738a58b-kG1F2ygt\">\n<li data-id=\"ld70c578-7ZxQjtHP\">\u8f93\u5165\uff1a\u5d4c\u5165\u540e\u7684\u5f20\u91cf\u00a0<img decoding=\"async\" src=\"https:\/\/s6.51cto.com\/oss\/202505\/19\/871148680683fa1497c98941c2dc7755bf2d62.png\" data-type=\"inline\" 
\/><\/li>\n<li data-id=\"ld70c578-bCfEkeWf\">\u8f93\u51fa\uff1a\u901a\u8fc7\u81ea\u6ce8\u610f\u529b\u52a0\u6743\u540e\u7684\u65b0\u5f20\u91cf sa_output<\/li>\n<\/ul>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_8\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">sa_output<span class=\"token punctuation\">,<\/span> attn_weights <span class=\"token operator\">=<\/span> <span class=\"token function\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">=<\/span>x<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">=<\/span>x<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">=<\/span>x<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>src_mask<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_8\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u5176\u4e2d\uff0c<\/p>\n<ul data-id=\"u738a58b-5MQ4GOlk\">\n<li data-id=\"ld70c578-qkRD2Yvs\">Query\u3001Key \u548c Value \u6765\u81ea\u540c\u4e00\u4e2a\u8f93\u5165\u00a0x\uff1b<\/li>\n<li data-id=\"ld70c578-e2UrXwsZ\">\u53ef\u9009 mask \u901a\u5e38\u7528\u4e8e\u5c4f\u853d padding token \u6216\u63a7\u5236\u4f4d\u7f6e\u611f\u77e5\u8303\u56f4\u3002<\/li>\n<\/ul>\n<p>\u2461\u00a0\u6b8b\u5dee\u8fde\u63a5 + \u5c42\u5f52\u4e00\u5316\uff08Sublayer 1\uff09<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_9\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> x <span class=\"token 
operator\">+<\/span> <span class=\"token function\">dropout<\/span><span class=\"token punctuation\">(<\/span>sa_output<span class=\"token punctuation\">)<\/span>\r\nx <span class=\"token operator\">=<\/span> <span class=\"token function\">layer_norm<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_9\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<ul data-id=\"u738a58b-ar7S69Xd\">\n<li data-id=\"ld70c578-PQsj4e3Y\">\u5e94\u7528\u6b8b\u5dee\u6620\u5c04\uff0c\u7f13\u89e3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff1b<\/li>\n<li data-id=\"ld70c578-5rc7jDBE\">\u4f7f\u7528 LayerNorm \u5bf9\u6bcf\u4e2a token \u7684\u5411\u91cf\u8fdb\u884c\u6807\u51c6\u5316\u5904\u7406\uff1b<\/li>\n<li data-id=\"ld70c578-Y4iIGg5h\">\u6574\u4f53\u76ee\u6807\uff1a\u63d0\u5347\u6a21\u578b\u8868\u8fbe\u80fd\u529b\u4e0e\u8bad\u7ec3\u7a33\u5b9a\u6027\u3002<\/li>\n<\/ul>\n<p>\u2462 \u4f4d\u7f6e\u5168\u8fde\u63a5\u524d\u9988\u7f51\u7edc\uff08Position-wise FeedForward\uff09<\/p>\n<p>\u5b9a\u4e49\u4e3a\uff1a<\/p>\n<p><img data-dominant-color=\"ececec\" data-has-transparency=\"false\" style=\"--dominant-color: #ececec;\" loading=\"lazy\" decoding=\"async\" class=\"not-transparent alignnone size-full wp-image-25924\" src=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/4862f4953a576c49aa1453afd02d4e42224ed0.png\" width=\"288\" height=\"34\" srcset=\"https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/4862f4953a576c49aa1453afd02d4e42224ed0.png 288w, https:\/\/aiforumimage.oss-cn-shanghai.aliyuncs.com\/wp-content\/uploads\/2025\/05\/4862f4953a576c49aa1453afd02d4e42224ed0-150x18.png 150w\" sizes=\"auto, (max-width: 288px) 100vw, 288px\" \/><\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" 
data-clipboard-target=\"#code_id_10\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Sequential<\/span><span class=\"token punctuation\">(<\/span>\r\n    nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n    nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">ReLU<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n    nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_ff<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n    nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_10\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<ul data-id=\"u738a58b-ILZzGXCY\">\n<li data-id=\"ld70c578-3wUlhka6\">d_model\uff1a\u6a21\u578b\u9690\u5c42\u7ef4\u5ea6\uff08\u5982 512\uff09\uff1b<\/li>\n<li data-id=\"ld70c578-S5kLW4M8\">d_ff\uff1aFeedForward \u7f51\u7edc\u4e2d\u95f4\u7ef4\u5ea6\uff08\u5982 2048\uff09\uff1b<\/li>\n<li data-id=\"ld70c578-Y4o1PvB9\">ReLU 
\u5f15\u5165\u975e\u7ebf\u6027\uff0c\u589e\u5f3a\u8bed\u4e49\u8868\u8fbe\u80fd\u529b\u3002<\/li>\n<\/ul>\n<p>\u2463\u00a0\u518d\u6b21\u6b8b\u5dee\u8fde\u63a5 + \u5c42\u5f52\u4e00\u5316\uff08Sublayer 2\uff09<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_11\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> x <span class=\"token operator\">+<\/span> <span class=\"token function\">dropout<\/span><span class=\"token punctuation\">(<\/span>ffn_output<span class=\"token punctuation\">)<\/span>\r\nx <span class=\"token operator\">=<\/span> <span class=\"token function\">layer_norm<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_11\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<ul data-id=\"u738a58b-SLmdwZeU\">\n<li data-id=\"ld70c578-oHbyXyHc\">\u4fdd\u8bc1\u6a21\u578b\u5728\u7ecf\u8fc7\u590d\u6742\u53d8\u6362\u540e\u4ecd\u80fd\u4fdd\u7559\u539f\u59cb\u4fe1\u606f\uff1b<\/li>\n<li data-id=\"ld70c578-GYGKEedz\">\u8fbe\u6210\u5bf9\u4e0a\u4e0b\u6587\u611f\u77e5\u8868\u793a\u7684\u7a33\u5b9a\u5b66\u4e60\u3002<\/li>\n<\/ul>\n<p><strong>(2) \u7ef4\u5ea6\u53d8\u5316\u8bf4\u660e\uff08\u8f93\u5165\u8f93\u51fa\u4fdd\u6301\u4e00\u81f4\uff09<\/strong><\/p>\n<p>\u65e0\u8bba\u7ecf\u8fc7\u591a\u5c11\u5c42 Encoder block\uff0c\u6bcf\u4e2a block \u7684\u8f93\u5165\u4e0e\u8f93\u51fa\u5f62\u72b6\u59cb\u7ec8\u4e00\u81f4\uff1a<\/p>\n<table class=\"data-table\" data-transient-attributes=\"class\" data-width=\"655.99px\">\n<colgroup data-id=\"c7104f7d-TGcm2eXi\">\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-EeBEXb8O\" \/>\n<col span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-zdiPl4rq\" \/>\n<col 
span=\"1\" width=\"218.663\" data-id=\"cd89ecb0-vX3rIdl7\" \/><\/colgroup>\n<tbody data-id=\"t6d5e859-YOPWC5j3\">\n<tr data-id=\"t31e458f-aIhVymI3\">\n<td data-id=\"t6267798-JDvMxFtt\" data-transient-attributes=\"table-cell-selection\">\u5f20\u91cf<\/td>\n<td data-id=\"t6267798-oACZFmCv\" data-transient-attributes=\"table-cell-selection\">\u5f62\u72b6<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-2hbxyzpJ\" data-transient-attributes=\"table-cell-selection\">\u63cf\u8ff0<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-MBUD6Nij\">\n<td data-id=\"t6267798-MXfyDkOG\" data-transient-attributes=\"table-cell-selection\">\u8f93\u5165<\/td>\n<td data-id=\"t6267798-OT43XtyY\" data-transient-attributes=\"table-cell-selection\">[batch_size, seq_len, d_model]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-OC6MVAJa\" data-transient-attributes=\"table-cell-selection\">\u6279\u6b21\u5927\u5c0f \u00d7 \u5e8f\u5217\u957f\u5ea6 \u00d7 \u6a21\u578b\u7ef4\u5ea6<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-OUfybGHe\">\n<td data-id=\"t6267798-3nC4IFgj\" data-transient-attributes=\"table-cell-selection\">MHA \u8f93\u51fa<\/td>\n<td data-id=\"t6267798-NBjCNeYH\" data-transient-attributes=\"table-cell-selection\">[batch_size, seq_len, d_model]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-FcccWWiK\" data-transient-attributes=\"table-cell-selection\">\u6ce8\u610f\u529b\u52a0\u6743\u540e\u7684\u8f93\u51fa<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-x1fLakpB\">\n<td data-id=\"t6267798-PdDxvYwt\" data-transient-attributes=\"table-cell-selection\">FFN \u8f93\u51fa<\/td>\n<td data-id=\"t6267798-4pkvI9lF\" data-transient-attributes=\"table-cell-selection\">[batch_size, seq_len, d_model]<\/td>\n<td class=\"table-last-row\" data-id=\"t6267798-YdbzEgjm\" data-transient-attributes=\"table-cell-selection\">\u6bcf\u4e2a Token \u7684\u524d\u9988\u7f51\u7edc\u8f93\u51fa<\/td>\n<\/tr>\n<tr data-id=\"t31e458f-JPaOzl1r\">\n<td class=\"table-last-column\" data-id=\"t6267798-GartoWor\" 
data-transient-attributes=\"table-cell-selection\">\u6700\u7ec8\u8f93\u51fa<\/td>\n<td class=\"table-last-column\" data-id=\"t6267798-2S67xHqD\" data-transient-attributes=\"table-cell-selection\">[batch_size, seq_len, d_model]<\/td>\n<td class=\"table-last-column table-last-row\" data-id=\"t6267798-tPTNs6EJ\" data-transient-attributes=\"table-cell-selection\">\u7ecf\u8fc7\u4e24\u6b21 Sublayer \u540e\u4ecd\u7136\u4fdd\u6301\u76f8\u540c\u7ef4\u5ea6<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<p><strong>(3) PyTorch \u6a21\u5757\u5c01\u88c5\u793a\u4f8b<\/strong><\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_12\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">import<\/span> torch\r\n<span class=\"token keyword\">import<\/span> torch<span class=\"token punctuation\">.<\/span>nn <span class=\"token keyword\">as<\/span> nn\r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">EncoderBlock<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token operator\">=<\/span><span class=\"token number\">0.1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span 
class=\"token literal-property property\">d_model<\/span><span class=\"token operator\">:<\/span> \u5d4c\u5165\u7ef4\u5ea6\uff08\u4f8b\u5982 <span class=\"token number\">512<\/span>\uff09\r\n            <span class=\"token literal-property property\">n_head<\/span><span class=\"token operator\">:<\/span> \u591a\u5934\u6570\u91cf\uff08\u901a\u5e38\u8bbe\u4e3a <span class=\"token number\">8<\/span>\uff09\r\n            <span class=\"token literal-property property\">d_ff<\/span><span class=\"token operator\">:<\/span> Feed Forward \u7f51\u7edc\u4e2d\u95f4\u7ef4\u5ea6\uff08\u901a\u5e38\u4e3a <span class=\"token number\">2048<\/span>\uff09\r\n            <span class=\"token literal-property property\">drop_prob<\/span><span class=\"token operator\">:<\/span> Dropout \u6982\u7387\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>EncoderBlock<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>attention <span class=\"token operator\">=<\/span> <span class=\"token function\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>norm1 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">LayerNorm<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>ffn <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token 
function\">Sequential<\/span><span class=\"token punctuation\">(<\/span>\r\n            nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n            nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">ReLU<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n            nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_ff<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span>\r\n            nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>norm2 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">LayerNorm<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>dropout <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token operator\">=<\/span>None<span class=\"token 
punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # Step <span class=\"token number\">1<\/span><span class=\"token operator\">:<\/span> Multi<span class=\"token operator\">-<\/span>Head Self<span class=\"token operator\">-<\/span>Attention\r\n        sa_output<span class=\"token punctuation\">,<\/span> _ <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">attention<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span>\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">norm1<\/span><span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">+<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">dropout<\/span><span class=\"token punctuation\">(<\/span>sa_output<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # Step <span class=\"token number\">2<\/span><span class=\"token operator\">:<\/span> Position<span class=\"token operator\">-<\/span>wise FeedForward\r\n        ffn_output <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ffn<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span>\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">norm2<\/span><span class=\"token punctuation\">(<\/span>x <span class=\"token operator\">+<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">dropout<\/span><span class=\"token punctuation\">(<\/span>ffn_output<span class=\"token punctuation\">)<\/span><span 
class=\"token punctuation\">)<\/span>\r\n\r\n        <span class=\"token keyword\">return<\/span> x<\/code><\/pre>\n<ul id=\"code_id_12\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p><strong>(4) \u5c01\u88c5\u6574\u4e2a Encoder \u6a21\u5757<\/strong><\/p>\n<p>\u6709\u4e86 EncoderBlock \u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u5b83 \u91cd\u590d N \u6b21 \u6784\u5efa\u5b8c\u6574\u7684 Encoder\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_13\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">TransformerEncoder<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> num_layers<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token operator\">=<\/span><span class=\"token number\">0.1<\/span><span 
class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token literal-property property\">Args<\/span><span class=\"token operator\">:<\/span>\r\n            <span class=\"token literal-property property\">num_layers<\/span><span class=\"token operator\">:<\/span> encoder block \u5806\u53e0\u5c42\u6570\uff08\u539f\u8bba\u6587\u4e3a <span class=\"token number\">6<\/span>\uff09\r\n            <span class=\"token literal-property property\">d_model<\/span><span class=\"token operator\">:<\/span> \u6a21\u578b\u7ef4\u5ea6\uff08\u5982 <span class=\"token number\">512<\/span>\uff09\r\n            <span class=\"token literal-property property\">n_head<\/span><span class=\"token operator\">:<\/span> \u6ce8\u610f\u529b\u5934\u6570\uff08\u5982 <span class=\"token number\">8<\/span>\uff09\r\n            <span class=\"token literal-property property\">d_ff<\/span><span class=\"token operator\">:<\/span> FeedForward \u7f51\u7edc\u7ef4\u5ea6\uff08\u5982 <span class=\"token number\">2048<\/span>\uff09\r\n        <span class=\"token string\">\"\"<\/span>\"\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>TransformerEncoder<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>blocks <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">ModuleList<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>\r\n            <span class=\"token function\">EncoderBlock<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> n_head<span 
class=\"token punctuation\">,<\/span> d_ff<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token punctuation\">)<\/span>\r\n            <span class=\"token keyword\">for<\/span> _ <span class=\"token keyword\">in<\/span> <span class=\"token function\">range<\/span><span class=\"token punctuation\">(<\/span>num_layers<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> x<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>None<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token keyword\">for<\/span> block <span class=\"token keyword\">in<\/span> self<span class=\"token punctuation\">.<\/span>blocks<span class=\"token operator\">:<\/span>\r\n            x <span class=\"token operator\">=<\/span> <span class=\"token function\">block<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">,<\/span> mask<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token keyword\">return<\/span> x<\/code><\/pre>\n<ul id=\"code_id_13\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\ud83d\udcc8 \u793a\u4f8b\u8c03\u7528<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_14\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-plain\" 
tabindex=\"0\"><code class=\"language-plain\"># \u521b\u5efa\u8f93\u5165\u5f20\u91cf\r\nx = torch.randn(32, 20, 512)  # [batch_size, seq_len, d_model] = [32, 20, 512]\r\n\r\n# \u6784\u5efa Encoder\r\nencoder = TransformerEncoder(num_layers=6, d_model=512, n_head=8, d_ff=2048)\r\noutput = encoder(x)\r\n\r\nprint(\"Encoder \u8f93\u51fa\u5f62\u72b6:\", output.shape)  # [32, 20, 512]<\/code><\/pre>\n<ul id=\"code_id_14\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<h4>4. Decoder\u7ed3\u6784<\/h4>\n<p>Decoder\u662fTransformer\u67b6\u6784\u4e2d\u7528\u4e8e\u751f\u6210\u8f93\u51fa\u5e8f\u5217\u7684\u90e8\u5206\u3002\u4e0eEncoder\u7c7b\u4f3c\uff0c\u5b83\u7531N=6\u4e2a\u76f8\u540c\u7684Decoder block\u5806\u53e0\u800c\u6210\uff0c\u4f46\u7ed3\u6784\u66f4\u4e3a\u590d\u6742\u3002<\/p>\n<p><strong>(1) Decoder Block \u7684\u6838\u5fc3\u7ec4\u4ef6<\/strong><\/p>\n<p>\u4e00\u4e2a\u6807\u51c6\u7684Decoder block\u5305\u542b\u4e09\u4e2a\u4e3b\u8981\u5b50\u5c42\uff1a<\/p>\n<ul data-id=\"u738a58b-Ju3Cf646\">\n<li data-id=\"ld70c578-fRsk7K5d\">Masked Multi-Head Self-Attention<\/li>\n<li data-id=\"ld70c578-6NcjIugw\">Encoder-Decoder Attention<\/li>\n<li data-id=\"ld70c578-D33Pb8tb\">Position-wise Feed Forward Network<\/li>\n<\/ul>\n<p>\u6bcf\u4e2a\u5b50\u5c42\u540e\u9762\u90fd\u8ddf\u968f\u6b8b\u5dee\u8fde\u63a5\uff08Residual Connection\uff09\u548c\u5c42\u5f52\u4e00\u5316\uff08Layer Normalization\uff09\u3002<\/p>\n<p>\u5982\u4e0b\u56fe\u53f3\u4fa7\u7ea2\u6846\u8868\u793a\u4e00\u4e2a\u6807\u51c6\u7684 Decoder Block\u3002<\/p>\n<p>&nbsp;<\/p>\n<p><img decoding=\"async\" src=\"https:\/\/s2.51cto.com\/oss\/202505\/19\/61420c471428ff89cd7919c6b20a57625063f5.webp\" data-type=\"block\" \/><\/p>\n<p>&nbsp;<\/p>\n<p>\u2460\u300cMasked Multi-Head 
Self-Attention\u300d<\/p>\n<p>\u8fd9\u662fDecoder\u7684\u7b2c\u4e00\u4e2a\u6ce8\u610f\u529b\u673a\u5236\uff0c\u7528\u4e8e\u5904\u7406\u76ee\u6807\u8bed\u8a00\u7684\u8f93\u5165\u5e8f\u5217\uff08\u5373\u89e3\u7801\u5668\u81ea\u8eab\u7684\u8f93\u5165\uff09\u3002<\/p>\n<ul data-id=\"u738a58b-dUZPv9CO\">\n<li data-id=\"ld70c578-nDBDjwHf\">\u4f7f\u7528 masking \u6280\u672f \u9632\u6b62\u5728\u9884\u6d4b\u5f53\u524d\u8bcd\u65f6\u770b\u5230\u672a\u6765\u7684\u8bcd\uff0c\u4fdd\u6301\u56e0\u679c\u5173\u7cfb\u3002<\/li>\n<li data-id=\"ld70c578-Um8LXDNr\">\u5b9e\u73b0\u65b9\u5f0f\uff1a\u901a\u8fc7 trg_mask \u5c4f\u853d\u672a\u6765\u4fe1\u606f\u3002<\/li>\n<\/ul>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_15\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">mha1<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>trg_mask<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_15\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u2461\u300cEncoder-Decoder Attention\u300d<\/p>\n<p>\u8fd9\u662f Decoder \u7684\u7b2c\u4e8c\u4e2a\u6ce8\u610f\u529b\u673a\u5236\uff0c\u7528\u4e8e\u5c06 Encoder \u7684\u8f93\u51fa\u4fe1\u606f\u878d\u5408\u5230 Decoder \u4e2d\u3002<\/p>\n<ul data-id=\"u738a58b-VJarNUdM\">\n<li data-id=\"ld70c578-eGeHZrRd\">Query (Q) 
\u6765\u81ea\u4e0a\u4e00\u5c42 Decoder \u7684\u8f93\u51fa\uff1b<\/li>\n<li data-id=\"ld70c578-Pk3ARrg5\">Key (K) \u548c Value (V) \u6765\u81ea Encoder \u7684\u8f93\u51fa\uff1b<\/li>\n<li data-id=\"ld70c578-WYKIfrJC\">\u8fd9\u6837 Decoder \u5728\u751f\u6210\u6bcf\u4e2a\u8bcd\u65f6\u90fd\u80fd\u5173\u6ce8\u5230\u6574\u4e2a\u8f93\u5165\u53e5\u5b50\u7684\u4fe1\u606f\u3002<\/li>\n<\/ul>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_16\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">mha2<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">=<\/span>x<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">=<\/span>enc_out<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">=<\/span>enc_out<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>src_mask<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_16\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u2462 \u300cPosition-wise Feed Forward Network\u300d<\/p>\n<p>\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u4e24\u5c42\u5168\u8fde\u63a5\u7f51\u7edc\uff0c\u5bf9\u6bcf\u4e2a\u4f4d\u7f6e\u7684\u5411\u91cf\u8fdb\u884c\u975e\u7ebf\u6027\u53d8\u6362\u3002<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_17\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> self<span class=\"token 
punctuation\">.<\/span><span class=\"token function\">ffn<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_17\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>\u2463\u300c\u6b8b\u5dee\u8fde\u63a5 + \u5c42\u5f52\u4e00\u5316\uff08Add &amp; Norm\uff09\u300d<\/p>\n<p>\u6bcf\u4e2a\u5b50\u5c42\u90fd\u5e94\u7528\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_18\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\">x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ln<\/span><span class=\"token punctuation\">(<\/span>x_residual <span class=\"token operator\">+<\/span> <span class=\"token function\">dropout<\/span><span class=\"token punctuation\">(<\/span>sublayer_output<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><\/code><\/pre>\n<ul id=\"code_id_18\" class=\"pre-numbering\">\n<li>1.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<ul data-id=\"u738a58b-5TuzAQon\">\n<li data-id=\"ld70c578-GVtVBzLq\">\u63d0\u5347\u8bad\u7ec3\u7a33\u5b9a\u6027\uff1b<\/li>\n<li data-id=\"ld70c578-Tdxh8dnJ\">\u7f13\u89e3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002<\/li>\n<\/ul>\n<p>\u2464\u300cDecoder\u7684\u5b8c\u6574\u5b9e\u73b0\u300d<\/p>\n<p>DecoderLayer\u7c7b\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_19\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">class<\/span> 
<span class=\"token class-name\">DecoderLayer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> ffn_hidden<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span>DecoderLayer<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        # \u7b2c\u4e00\u4e2a Multi<span class=\"token operator\">-<\/span>Head Attention<span class=\"token operator\">:<\/span> Masked Self<span class=\"token operator\">-<\/span>Attention\r\n        self<span class=\"token punctuation\">.<\/span>mha1 <span class=\"token operator\">=<\/span> <span class=\"token function\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>ln1 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">LayerNorm<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>dropout1 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span 
class=\"token punctuation\">(<\/span>p<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u7b2c\u4e8c\u4e2a Multi<span class=\"token operator\">-<\/span>Head Attention<span class=\"token operator\">:<\/span> Encoder<span class=\"token operator\">-<\/span>Decoder Attention\r\n        self<span class=\"token punctuation\">.<\/span>mha2 <span class=\"token operator\">=<\/span> <span class=\"token function\">MultiHeadAttention<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>ln2 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">LayerNorm<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>dropout2 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>p<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u524d\u9988\u7f51\u7edc\r\n        self<span class=\"token punctuation\">.<\/span>ffn <span class=\"token operator\">=<\/span> <span class=\"token function\">PositionwiseFeedForward<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> ffn_hidden<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>ln3 <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">LayerNorm<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">)<\/span>\r\n        self<span class=\"token punctuation\">.<\/span>dropout3 <span class=\"token 
operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Dropout<\/span><span class=\"token punctuation\">(<\/span>p<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> dec_out<span class=\"token punctuation\">,<\/span> enc_out<span class=\"token punctuation\">,<\/span> trg_mask<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        x_residual1 <span class=\"token operator\">=<\/span> dec_out\r\n        # Step <span class=\"token number\">1<\/span><span class=\"token operator\">:<\/span> Masked Self<span class=\"token operator\">-<\/span>Attention\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">mha1<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">=<\/span>dec_out<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>trg_mask<span class=\"token punctuation\">)<\/span>\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ln1<\/span><span class=\"token punctuation\">(<\/span>x_residual1 <span class=\"token operator\">+<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">dropout1<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        <span class=\"token keyword\">if<\/span> enc_out is not None<span class=\"token 
operator\">:<\/span>\r\n            # Step <span class=\"token number\">2<\/span><span class=\"token operator\">:<\/span> Encoder<span class=\"token operator\">-<\/span>Decoder Attention\r\n            x_residual2 <span class=\"token operator\">=<\/span> x\r\n            x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">mha2<\/span><span class=\"token punctuation\">(<\/span>q<span class=\"token operator\">=<\/span>x<span class=\"token punctuation\">,<\/span> k<span class=\"token operator\">=<\/span>enc_out<span class=\"token punctuation\">,<\/span> v<span class=\"token operator\">=<\/span>enc_out<span class=\"token punctuation\">,<\/span> mask<span class=\"token operator\">=<\/span>src_mask<span class=\"token punctuation\">)<\/span>\r\n            x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ln2<\/span><span class=\"token punctuation\">(<\/span>x_residual2 <span class=\"token operator\">+<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">dropout2<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # Step <span class=\"token number\">3<\/span><span class=\"token operator\">:<\/span> Position<span class=\"token operator\">-<\/span>wise Feed Forward\r\n        x_residual3 <span class=\"token operator\">=<\/span> x\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ffn<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span>\r\n        x <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">ln3<\/span><span class=\"token punctuation\">(<\/span>x_residual3 <span class=\"token operator\">+<\/span> self<span 
class=\"token punctuation\">.<\/span><span class=\"token function\">dropout3<\/span><span class=\"token punctuation\">(<\/span>x<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        <span class=\"token keyword\">return<\/span> x<\/code><\/pre>\n<ul id=\"code_id_19\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<li>35.<\/li>\n<li>36.<\/li>\n<\/ul>\n<div class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<p>Decoder \u7c7b\uff1a<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_20\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Decoder<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> dec_voc_size<span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token punctuation\">,<\/span> ffn_hidden<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> n_layers<span class=\"token 
punctuation\">,<\/span> drop_prob<span class=\"token punctuation\">,<\/span> device<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        # \u8f93\u5165\u5d4c\u5165 <span class=\"token operator\">+<\/span> \u4f4d\u7f6e\u7f16\u7801\r\n        self<span class=\"token punctuation\">.<\/span>emb <span class=\"token operator\">=<\/span> <span class=\"token function\">TransformerEmbedding<\/span><span class=\"token punctuation\">(<\/span>\r\n            d_model<span class=\"token operator\">=<\/span>d_model<span class=\"token punctuation\">,<\/span>\r\n            drop_prob<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">,<\/span>\r\n            max_len<span class=\"token operator\">=<\/span>max_len<span class=\"token punctuation\">,<\/span>\r\n            vocab_size<span class=\"token operator\">=<\/span>dec_voc_size<span class=\"token punctuation\">,<\/span>\r\n            device<span class=\"token operator\">=<\/span>device\r\n        <span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u5806\u53e0\u591a\u4e2a Decoder Layer\r\n        self<span class=\"token punctuation\">.<\/span>layers <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">ModuleList<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">[<\/span>\r\n            <span class=\"token function\">DecoderLayer<\/span><span class=\"token punctuation\">(<\/span>\r\n                d_model<span class=\"token operator\">=<\/span>d_model<span class=\"token punctuation\">,<\/span>\r\n                ffn_hidden<span class=\"token 
operator\">=<\/span>ffn_hidden<span class=\"token punctuation\">,<\/span>\r\n                n_head<span class=\"token operator\">=<\/span>n_head<span class=\"token punctuation\">,<\/span>\r\n                drop_prob<span class=\"token operator\">=<\/span>drop_prob\r\n            <span class=\"token punctuation\">)<\/span> <span class=\"token keyword\">for<\/span> _ <span class=\"token keyword\">in<\/span> <span class=\"token function\">range<\/span><span class=\"token punctuation\">(<\/span>n_layers<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token punctuation\">]<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u6700\u7ec8\u8f93\u51fa\u5c42\uff1a\u6620\u5c04\u5230\u76ee\u6807\u8bcd\u6c47\u8868\u5927\u5c0f\r\n        self<span class=\"token punctuation\">.<\/span>linear <span class=\"token operator\">=<\/span> nn<span class=\"token punctuation\">.<\/span><span class=\"token function\">Linear<\/span><span class=\"token punctuation\">(<\/span>d_model<span class=\"token punctuation\">,<\/span> dec_voc_size<span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> trg<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> trg_mask<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # trg<span class=\"token operator\">:<\/span> \u76ee\u6807\u5e8f\u5217 <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> trg_seq_len<span class=\"token punctuation\">]<\/span>\r\n        # src<span class=\"token operator\">:<\/span> Encoder \u8f93\u51fa <span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> src_seq_len<span class=\"token punctuation\">,<\/span> d_model<span class=\"token 
punctuation\">]<\/span>\r\n        trg <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">emb<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">)<\/span>\r\n\r\n        <span class=\"token keyword\">for<\/span> layer <span class=\"token keyword\">in<\/span> self<span class=\"token punctuation\">.<\/span>layers<span class=\"token operator\">:<\/span>\r\n            trg <span class=\"token operator\">=<\/span> <span class=\"token function\">layer<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> trg_mask<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u8f93\u51fa\uff1a<span class=\"token punctuation\">[<\/span>batch_size<span class=\"token punctuation\">,<\/span> trg_seq_len<span class=\"token punctuation\">,<\/span> dec_voc_size<span class=\"token punctuation\">]<\/span>\r\n        output <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">linear<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">)<\/span>\r\n        <span class=\"token keyword\">return<\/span> output<\/code><\/pre>\n<ul id=\"code_id_20\" class=\"pre-numbering\">\n<li>1.<\/li>\n<li>2.<\/li>\n<li>3.<\/li>\n<li>4.<\/li>\n<li>5.<\/li>\n<li>6.<\/li>\n<li>7.<\/li>\n<li>8.<\/li>\n<li>9.<\/li>\n<li>10.<\/li>\n<li>11.<\/li>\n<li>12.<\/li>\n<li>13.<\/li>\n<li>14.<\/li>\n<li>15.<\/li>\n<li>16.<\/li>\n<li>17.<\/li>\n<li>18.<\/li>\n<li>19.<\/li>\n<li>20.<\/li>\n<li>21.<\/li>\n<li>22.<\/li>\n<li>23.<\/li>\n<li>24.<\/li>\n<li>25.<\/li>\n<li>26.<\/li>\n<li>27.<\/li>\n<li>28.<\/li>\n<li>29.<\/li>\n<li>30.<\/li>\n<li>31.<\/li>\n<li>32.<\/li>\n<li>33.<\/li>\n<li>34.<\/li>\n<li>35.<\/li>\n<li>36.<\/li>\n<\/ul>\n<div 
class=\"toolbar\"><\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<h3>\u4e09\u3001Transformer\u5b9e\u73b0<\/h3>\n<p>\u57fa\u4e8e\u524d\u9762\u5b9e\u73b0\u7684 Encoder \u548c Decoder \u7ec4\u4ef6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5b9e\u73b0 Transformer \u6a21\u578b\u7684\u5b8c\u6574\u4ee3\u7801\uff0c\u5982\u4e0b\u6240\u793a:<\/p>\n<div>\n<div class=\"hljs-cto\">\n<div class=\"hljs-cto\"><button class=\"copy_btn disable\" data-clipboard-target=\"#code_id_21\">\u590d\u5236<\/button><\/p>\n<div class=\"code-toolbar\">\n<pre class=\"has-pre-numbering language-javascript\" tabindex=\"0\"><code class=\"language-javascript\"><span class=\"token keyword\">import<\/span> torch\r\nfrom torch <span class=\"token keyword\">import<\/span> nn\r\n\r\n<span class=\"token keyword\">class<\/span> <span class=\"token class-name\">Transformer<\/span><span class=\"token punctuation\">(<\/span>nn<span class=\"token punctuation\">.<\/span>Module<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n    def <span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> src_pad_idx<span class=\"token punctuation\">,<\/span> trg_pad_idx<span class=\"token punctuation\">,<\/span> trg_sos_idx<span class=\"token punctuation\">,<\/span> enc_voc_size<span class=\"token punctuation\">,<\/span> dec_voc_size<span class=\"token punctuation\">,<\/span> \r\n                 d_model<span class=\"token punctuation\">,<\/span> n_head<span class=\"token punctuation\">,<\/span> max_len<span class=\"token punctuation\">,<\/span> ffn_hidden<span class=\"token punctuation\">,<\/span> n_layers<span class=\"token punctuation\">,<\/span> drop_prob<span class=\"token punctuation\">,<\/span> device<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        <span class=\"token keyword\">super<\/span><span class=\"token punctuation\">(<\/span><span class=\"token 
punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">__init__<\/span><span class=\"token punctuation\">(<\/span><span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u4fdd\u5b58\u7279\u6b8a\u6807\u8bb0\u7684\u7d22\u5f15\r\n        self<span class=\"token punctuation\">.<\/span>src_pad_idx <span class=\"token operator\">=<\/span> src_pad_idx  # \u6e90\u8bed\u8a00\u586b\u5145\u7d22\u5f15\r\n        self<span class=\"token punctuation\">.<\/span>trg_pad_idx <span class=\"token operator\">=<\/span> trg_pad_idx  # \u76ee\u6807\u8bed\u8a00\u586b\u5145\u7d22\u5f15\r\n        self<span class=\"token punctuation\">.<\/span>trg_sos_idx <span class=\"token operator\">=<\/span> trg_sos_idx  # \u76ee\u6807\u8bed\u8a00\u8d77\u59cb\u7b26\u53f7\u7d22\u5f15\r\n        self<span class=\"token punctuation\">.<\/span>device <span class=\"token operator\">=<\/span> device  # \u8bbe\u5907\u4fe1\u606f\r\n        \r\n        # \u6784\u5efa Encoder\r\n        self<span class=\"token punctuation\">.<\/span>encoder <span class=\"token operator\">=<\/span> <span class=\"token function\">Encoder<\/span><span class=\"token punctuation\">(<\/span>\r\n            d_model<span class=\"token operator\">=<\/span>d_model<span class=\"token punctuation\">,<\/span>  # \u6a21\u578b\u7ef4\u5ea6\r\n            n_head<span class=\"token operator\">=<\/span>n_head<span class=\"token punctuation\">,<\/span>  # \u6ce8\u610f\u529b\u5934\u6570\u91cf\r\n            max_len<span class=\"token operator\">=<\/span>max_len<span class=\"token punctuation\">,<\/span>  # \u6700\u5927\u5e8f\u5217\u957f\u5ea6\r\n            ffn_hidden<span class=\"token operator\">=<\/span>ffn_hidden<span class=\"token punctuation\">,<\/span>  # \u524d\u9988\u7f51\u7edc\u9690\u85cf\u5c42\u7ef4\u5ea6\r\n            enc_voc_size<span class=\"token operator\">=<\/span>enc_voc_size<span class=\"token punctuation\">,<\/span>  # \u6e90\u8bed\u8a00\u8bcd\u6c47\u8868\u5927\u5c0f\r\n     
       drop_prob<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">,<\/span>  # Dropout \u6982\u7387\r\n            n_layers<span class=\"token operator\">=<\/span>n_layers<span class=\"token punctuation\">,<\/span>  # Encoder \u5c42\u6570\r\n            device<span class=\"token operator\">=<\/span>device  # \u8bbe\u5907\u4fe1\u606f\r\n        <span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u6784\u5efa Decoder\r\n        self<span class=\"token punctuation\">.<\/span>decoder <span class=\"token operator\">=<\/span> <span class=\"token function\">Decoder<\/span><span class=\"token punctuation\">(<\/span>\r\n            d_model<span class=\"token operator\">=<\/span>d_model<span class=\"token punctuation\">,<\/span>  # \u6a21\u578b\u7ef4\u5ea6\r\n            n_head<span class=\"token operator\">=<\/span>n_head<span class=\"token punctuation\">,<\/span>  # \u6ce8\u610f\u529b\u5934\u6570\u91cf\r\n            max_len<span class=\"token operator\">=<\/span>max_len<span class=\"token punctuation\">,<\/span>  # \u6700\u5927\u5e8f\u5217\u957f\u5ea6\r\n            ffn_hidden<span class=\"token operator\">=<\/span>ffn_hidden<span class=\"token punctuation\">,<\/span>  # \u524d\u9988\u7f51\u7edc\u9690\u85cf\u5c42\u7ef4\u5ea6\r\n            dec_voc_size<span class=\"token operator\">=<\/span>dec_voc_size<span class=\"token punctuation\">,<\/span>  # \u76ee\u6807\u8bed\u8a00\u8bcd\u6c47\u8868\u5927\u5c0f\r\n            drop_prob<span class=\"token operator\">=<\/span>drop_prob<span class=\"token punctuation\">,<\/span>  # Dropout \u6982\u7387\r\n            n_layers<span class=\"token operator\">=<\/span>n_layers<span class=\"token punctuation\">,<\/span>  # Decoder \u5c42\u6570\r\n            device<span class=\"token operator\">=<\/span>device  # \u8bbe\u5907\u4fe1\u606f\r\n        <span class=\"token punctuation\">)<\/span>\r\n\r\n    def <span class=\"token function\">forward<\/span><span class=\"token 
punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> trg<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # \u521b\u5efa\u6e90\u5e8f\u5217\u7684 padding mask\r\n        src_mask <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">make_pad_mask<\/span><span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>src_pad_idx<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>src_pad_idx<span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u521b\u5efa\u76ee\u6807\u5e8f\u5217\u5230\u6e90\u5e8f\u5217\u7684 mask\r\n        src_trg_mask <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">make_pad_mask<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">,<\/span> src<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>trg_pad_idx<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>src_pad_idx<span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u521b\u5efa\u76ee\u6807\u5e8f\u5217\u7684 padding mask \u548c\u56e0\u679cmask\u7684\u7ec4\u5408\r\n        trg_mask <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">make_pad_mask<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">,<\/span> trg<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>trg_pad_idx<span class=\"token punctuation\">,<\/span> self<span class=\"token punctuation\">.<\/span>trg_pad_idx<span class=\"token punctuation\">)<\/span> <span 
class=\"token operator\">*<\/span> \\\r\n                   self<span class=\"token punctuation\">.<\/span><span class=\"token function\">make_no_peak_mask<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">,<\/span> trg<span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u7f16\u7801\u5668\u524d\u5411\u4f20\u64ad\r\n        enc_src <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">encoder<\/span><span class=\"token punctuation\">(<\/span>src<span class=\"token punctuation\">,<\/span> src_mask<span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u89e3\u7801\u5668\u524d\u5411\u4f20\u64ad\r\n        output <span class=\"token operator\">=<\/span> self<span class=\"token punctuation\">.<\/span><span class=\"token function\">decoder<\/span><span class=\"token punctuation\">(<\/span>trg<span class=\"token punctuation\">,<\/span> enc_src<span class=\"token punctuation\">,<\/span> trg_mask<span class=\"token punctuation\">,<\/span> src_trg_mask<span class=\"token punctuation\">)<\/span>\r\n        \r\n        <span class=\"token keyword\">return<\/span> output\r\n\r\n    def <span class=\"token function\">make_pad_mask<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">,<\/span> q_pad_idx<span class=\"token punctuation\">,<\/span> k_pad_idx<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # \u83b7\u53d6\u8f93\u5165\u5e8f\u5217\u957f\u5ea6\r\n        len_q<span class=\"token punctuation\">,<\/span> len_k <span class=\"token operator\">=<\/span> q<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token 
punctuation\">,<\/span> k<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u521b\u5efa\u9488\u5bf9 key \u7684 mask\r\n        # batch_size x <span class=\"token number\">1<\/span> x <span class=\"token number\">1<\/span> x len_k\r\n        k_mask <span class=\"token operator\">=<\/span> k<span class=\"token punctuation\">.<\/span><span class=\"token function\">ne<\/span><span class=\"token punctuation\">(<\/span>k_pad_idx<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">unsqueeze<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">unsqueeze<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\">)<\/span>\r\n        # batch_size x <span class=\"token number\">1<\/span> x len_q x len_k\r\n        k_mask <span class=\"token operator\">=<\/span> k_mask<span class=\"token punctuation\">.<\/span><span class=\"token function\">repeat<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> len_q<span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n\r\n        # \u521b\u5efa\u9488\u5bf9 query \u7684 mask\r\n        # batch_size x <span class=\"token number\">1<\/span> x len_q x <span class=\"token number\">1<\/span>\r\n        q_mask <span class=\"token operator\">=<\/span> q<span class=\"token punctuation\">.<\/span><span class=\"token function\">ne<\/span><span class=\"token 
punctuation\">(<\/span>q_pad_idx<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">unsqueeze<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">unsqueeze<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">3<\/span><span class=\"token punctuation\">)<\/span>\r\n        # batch_size x <span class=\"token number\">1<\/span> x len_q x len_k\r\n        q_mask <span class=\"token operator\">=<\/span> q_mask<span class=\"token punctuation\">.<\/span><span class=\"token function\">repeat<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> <span class=\"token number\">1<\/span><span class=\"token punctuation\">,<\/span> len_k<span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u7ec4\u5408\u4e24\u4e2a mask\r\n        mask <span class=\"token operator\">=<\/span> k_mask <span class=\"token operator\">&amp;<\/span> q_mask\r\n        \r\n        <span class=\"token keyword\">return<\/span> mask\r\n\r\n    def <span class=\"token function\">make_no_peak_mask<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">,<\/span> q<span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">)<\/span><span class=\"token operator\">:<\/span>\r\n        # \u521b\u5efa\u56e0\u679cmask\uff0c\u9632\u6b62\u89e3\u7801\u5668\u770b\u5230\u672a\u6765\u4fe1\u606f\r\n        len_q<span class=\"token punctuation\">,<\/span> len_k <span class=\"token operator\">=<\/span> q<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token 
number\">1<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">,<\/span> k<span class=\"token punctuation\">.<\/span><span class=\"token function\">size<\/span><span class=\"token punctuation\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\">)<\/span>\r\n        \r\n        # \u521b\u5efa\u4e0b\u4e09\u89d2\u77e9\u9635\uff0c\u4fdd\u8bc1\u89e3\u7801\u5668\u53ea\u80fd\u5173\u6ce8\u5f53\u524d\u8bcd\u53ca\u4e4b\u524d\u7684\u8bcd\r\n        mask <span class=\"token operator\">=<\/span> torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">tril<\/span><span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span><span class=\"token function\">ones<\/span><span class=\"token punctuation\">(<\/span>len_q<span class=\"token punctuation\">,<\/span> len_k<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">type<\/span><span class=\"token punctuation\">(<\/span>torch<span class=\"token punctuation\">.<\/span>BoolTensor<span class=\"token punctuation\">)<\/span><span class=\"token punctuation\">.<\/span><span class=\"token function\">to<\/span><span class=\"token punctuation\">(<\/span>self<span class=\"token punctuation\">.<\/span>device<span class=\"token punctuation\">)<\/span>\r\n        \r\n        <span class=\"token keyword\">return<\/span> mask<\/code>\r\n\r\n\r\n\u6587\u7ae0\u6765\u81ea\uff1a51CTO\r\n\r\n\r\n<\/pre>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<div class=\"pvc_clear\"><\/div>\n<p id=\"pvc_stats_25918\" class=\"pvc_stats total_only  \" data-element-id=\"25918\" style=\"\"><i class=\"pvc-stats-icon medium\" aria-hidden=\"true\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" version=\"1.0\" viewBox=\"0 0 502 315\" preserveAspectRatio=\"xMidYMid meet\"><g transform=\"translate(0,332) scale(0.1,-0.1)\" fill=\"\" 
stroke=\"none\"><path d=\"M2394 3279 l-29 -30 -3 -207 c-2 -182 0 -211 15 -242 39 -76 157 -76 196 0 15 31 17 60 15 243 l-3 209 -33 29 c-26 23 -41 29 -80 29 -41 0 -53 -5 -78 -31z\"\/><path d=\"M3085 3251 c-45 -19 -58 -50 -96 -229 -47 -217 -49 -260 -13 -295 52 -53 146 -42 177 20 16 31 87 366 87 410 0 70 -86 122 -155 94z\"\/><path d=\"M1751 3234 c-13 -9 -29 -31 -37 -50 -12 -29 -10 -49 21 -204 19 -94 39 -189 45 -210 14 -50 54 -80 110 -80 34 0 48 6 76 34 21 21 34 44 34 59 0 14 -18 113 -40 219 -37 178 -43 195 -70 221 -36 32 -101 37 -139 11z\"\/><path d=\"M1163 3073 c-36 -7 -73 -59 -73 -102 0 -56 133 -378 171 -413 34 -32 83 -37 129 -13 70 36 67 87 -16 290 -86 209 -89 214 -129 231 -35 14 -42 15 -82 7z\"\/><path d=\"M3689 3066 c-15 -9 -33 -30 -42 -48 -48 -103 -147 -355 -147 -375 0 -98 131 -148 192 -74 13 15 57 108 97 206 80 196 84 226 37 273 -30 30 -99 39 -137 18z\"\/><path d=\"M583 2784 c-38 -19 -67 -74 -58 -113 9 -42 211 -354 242 -373 16 -10 45 -18 66 -18 51 0 107 52 107 100 0 39 -1 41 -124 234 -80 126 -108 162 -133 173 -41 17 -61 16 -100 -3z\"\/><path d=\"M4250 2784 c-14 -9 -74 -91 -133 -183 -95 -150 -107 -173 -107 -213 0 -55 33 -94 87 -104 67 -13 90 8 211 198 130 202 137 225 78 284 -27 27 -42 34 -72 34 -22 0 -50 -8 -64 -16z\"\/><path d=\"M2275 2693 c-553 -48 -1095 -270 -1585 -649 -135 -104 -459 -423 -483 -476 -23 -49 -22 -139 2 -186 73 -142 361 -457 571 -626 285 -228 642 -407 990 -497 242 -63 336 -73 660 -74 310 0 370 5 595 52 535 111 1045 392 1455 803 122 121 250 273 275 326 19 41 19 137 0 174 -41 79 -309 363 -465 492 -447 370 -946 591 -1479 653 -113 14 -422 18 -536 8z m395 -428 c171 -34 330 -124 456 -258 112 -119 167 -219 211 -378 27 -96 24 -300 -5 -401 -72 -255 -236 -447 -474 -557 -132 -62 -201 -76 -368 -76 -167 0 -236 14 -368 76 -213 98 -373 271 -451 485 -162 444 86 934 547 1084 153 49 292 57 452 25z m909 -232 c222 -123 408 -262 593 -441 76 -74 138 -139 138 -144 0 -16 -233 -242 -330 -319 -155 -123 -309 -223 -461 -299 l-81 -41 32 46 c18 26 49 83 70 128 143 306 141 649 
-6 957 -25 52 -61 116 -79 142 l-34 47 45 -20 c26 -10 76 -36 113 -56z m-2057 25 c-40 -58 -105 -190 -130 -263 -110 -324 -59 -707 132 -981 25 -35 42 -64 37 -64 -19 0 -241 119 -326 174 -188 122 -406 314 -532 468 l-58 71 108 103 c185 178 428 349 672 473 66 33 121 60 123 61 2 0 -10 -19 -26 -42z\"\/><path d=\"M2375 1950 c-198 -44 -350 -190 -395 -379 -18 -76 -8 -221 19 -290 114 -284 457 -406 731 -260 98 52 188 154 231 260 27 69 37 214 19 290 -38 163 -166 304 -326 360 -67 23 -215 33 -279 19z\"\/><\/g><\/svg><\/i> <img loading=\"lazy\" decoding=\"async\" width=\"16\" height=\"16\" alt=\"Loading\" src=\"https:\/\/aif.amtbbs.org\/wp-content\/plugins\/page-views-count\/ajax-loader-2x.gif\" border=0 \/><\/p>\n<div class=\"pvc_clear\"><\/div>\n","protected":false},"excerpt":{"rendered":"<p>Transformer \u9ed8\u8ba4\u90fd\u662f\u5927\u6a21\u578b\uff0c\u9664\u4e86\u4e00\u4e9b\u7279\u4f8b\uff08\u5982 DistilBERT\uff09\u5916\uff0c\u5b9e\u73b0\u66f4\u597d\u6027\u80fd\u7684\u4e00\u822c\u7b56\u7565\u662f\u589e [&hellip;]<\/p>\n<div class=\"pvc_clear\"><\/div>\n<p id=\"pvc_stats_25918\" class=\"pvc_stats total_only  \" data-element-id=\"25918\" style=\"\"><i class=\"pvc-stats-icon medium\" aria-hidden=\"true\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" version=\"1.0\" viewBox=\"0 0 502 315\" preserveAspectRatio=\"xMidYMid meet\"><g transform=\"translate(0,332) scale(0.1,-0.1)\" fill=\"\" stroke=\"none\"><path d=\"M2394 3279 l-29 -30 -3 -207 c-2 -182 0 -211 15 -242 39 -76 157 -76 196 0 15 31 17 60 15 243 l-3 209 -33 29 c-26 23 -41 29 -80 29 -41 0 -53 -5 -78 -31z\"\/><path d=\"M3085 3251 c-45 -19 -58 -50 -96 -229 -47 -217 -49 -260 -13 -295 52 -53 146 -42 177 20 16 31 87 366 87 410 0 70 -86 122 -155 94z\"\/><path d=\"M1751 3234 c-13 -9 -29 -31 -37 -50 -12 -29 -10 -49 21 -204 19 -94 39 -189 45 -210 14 -50 54 -80 110 -80 34 0 48 6 76 34 21 21 34 44 34 59 0 14 -18 113 -40 219 -37 178 -43 195 -70 221 -36 32 -101 37 -139 11z\"\/><path d=\"M1163 3073 c-36 -7 -73 -59 -73 -102 0 -56 133 -378 171 
-413 34 -32 83 -37 129 -13 70 36 67 87 -16 290 -86 209 -89 214 -129 231 -35 14 -42 15 -82 7z\"\/><path d=\"M3689 3066 c-15 -9 -33 -30 -42 -48 -48 -103 -147 -355 -147 -375 0 -98 131 -148 192 -74 13 15 57 108 97 206 80 196 84 226 37 273 -30 30 -99 39 -137 18z\"\/><path d=\"M583 2784 c-38 -19 -67 -74 -58 -113 9 -42 211 -354 242 -373 16 -10 45 -18 66 -18 51 0 107 52 107 100 0 39 -1 41 -124 234 -80 126 -108 162 -133 173 -41 17 -61 16 -100 -3z\"\/><path d=\"M4250 2784 c-14 -9 -74 -91 -133 -183 -95 -150 -107 -173 -107 -213 0 -55 33 -94 87 -104 67 -13 90 8 211 198 130 202 137 225 78 284 -27 27 -42 34 -72 34 -22 0 -50 -8 -64 -16z\"\/><path d=\"M2275 2693 c-553 -48 -1095 -270 -1585 -649 -135 -104 -459 -423 -483 -476 -23 -49 -22 -139 2 -186 73 -142 361 -457 571 -626 285 -228 642 -407 990 -497 242 -63 336 -73 660 -74 310 0 370 5 595 52 535 111 1045 392 1455 803 122 121 250 273 275 326 19 41 19 137 0 174 -41 79 -309 363 -465 492 -447 370 -946 591 -1479 653 -113 14 -422 18 -536 8z m395 -428 c171 -34 330 -124 456 -258 112 -119 167 -219 211 -378 27 -96 24 -300 -5 -401 -72 -255 -236 -447 -474 -557 -132 -62 -201 -76 -368 -76 -167 0 -236 14 -368 76 -213 98 -373 271 -451 485 -162 444 86 934 547 1084 153 49 292 57 452 25z m909 -232 c222 -123 408 -262 593 -441 76 -74 138 -139 138 -144 0 -16 -233 -242 -330 -319 -155 -123 -309 -223 -461 -299 l-81 -41 32 46 c18 26 49 83 70 128 143 306 141 649 -6 957 -25 52 -61 116 -79 142 l-34 47 45 -20 c26 -10 76 -36 113 -56z m-2057 25 c-40 -58 -105 -190 -130 -263 -110 -324 -59 -707 132 -981 25 -35 42 -64 37 -64 -19 0 -241 119 -326 174 -188 122 -406 314 -532 468 l-58 71 108 103 c185 178 428 349 672 473 66 33 121 60 123 61 2 0 -10 -19 -26 -42z\"\/><path d=\"M2375 1950 c-198 -44 -350 -190 -395 -379 -18 -76 -8 -221 19 -290 114 -284 457 -406 731 -260 98 52 188 154 231 260 27 69 37 214 19 290 -38 163 -166 304 -326 360 -67 23 -215 33 -279 19z\"\/><\/g><\/svg><\/i> <img loading=\"lazy\" decoding=\"async\" width=\"16\" height=\"16\" alt=\"Loading\" 
src=\"https:\/\/aif.amtbbs.org\/wp-content\/plugins\/page-views-count\/ajax-loader-2x.gif\" border=0 \/><\/p>\n<div class=\"pvc_clear\"><\/div>\n","protected":false},"author":56,"featured_media":25920,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[3,23,20,80],"tags":[124,667,937],"class_list":["post-25918","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-ai","category-23","category-20","category-80","tag-ai","tag-667","tag-937"],"_links":{"self":[{"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/posts\/25918","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/users\/56"}],"replies":[{"embeddable":true,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/comments?post=25918"}],"version-history":[{"count":1,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/posts\/25918\/revisions"}],"predecessor-version":[{"id":25925,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/posts\/25918\/revisions\/25925"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/media\/25920"}],"wp:attachment":[{"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/media?parent=25918"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/categories?post=25918"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/aif.amtbbs.org\/index.php\/wp-json\/wp\/v2\/tags?post=25918"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}