内容目录

前言

在上一章【课程总结】day20:Transformer源码深入理解-训练过程总结中,我们对Transformer的训练过程进行了详细的分析,本章将介绍Transformer的预测过程。

预测流程

  • 预测大致为如下步骤:
  • 第一部分:准备数据,通过调用 tokenizersplit_input() 进行分词、encoder_input() 进行编码,将输入内容最终转为张量。
  • 第二部分:调用greed_decoder()进行预测,返回预测结果。
  • 第三部分:将预测的结果进行处理,使用 get_real_output() 方法将预测的输出进行处理,得到真实输出序列。

代码分析理解

预测过程

    def infer(self, sentence="Am I wrong?"):
        """
            预测过程
        """
        print("原文:", sentence)
        sentence = self.tokenizer.split_input(sentence=sentence)
        print("分词:", sentence)
        sentence = self.tokenizer.encode_input(input_sentence=sentence, input_sentence_len=len(sentence))
        print("编码:", sentence)
        src = torch.LongTensor([sentence]).to(device=self.device)
        print("张量:", src)
        # src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
        max_len = src.size(1)
        src_mask = torch.ones(1, 1, max_len)
        # 模型推理
        self.model.eval()
        with torch.no_grad():
            y_pred = greedy_decode(self.model,
                                   src,
                                   src_mask,
                                   max_len=self.tokenizer.output_max_len,
                                   start_symbol=self.tokenizer.output_word2idx.get("<SOS>"))
        # 去除 启动信号 <SOS>
        y_pred = y_pred[:, 1:]
        raw_results, final_results = get_real_output(y_pred.cpu(), self.tokenizer)
        print("原始预测:", raw_results[0])
        print("最终预测:", final_results[0])

代码理解:

  • 在Translation类中,封装infer()函数,该函数用于预测。
  • 在预测的时候,设置模型为eval模式,即关闭dropout;
  • 在无梯度模式下,调用模型推理函数greed_decode()进行预测。

greedy_decode() 函数


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # 获取中间表达
    memory = model.encode(src, src_mask)
    # 启动信号
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    # 自回归式生成
    for _ in range(max_len - 1):
        # 获取结果
        out = model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)
        )
        # print(out)
        # 取出最后一步的结果
        prob = model.generator(out[:, -1])
        # 获取概率最大的值
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        # 拼接起来,准备生成下一个词
        ys = torch.cat(
            [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1
        )
    return ys

假设示例如下:

  • 上文输入:Am I wrong?

初始输入

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):"SOS"
  • memory 的形状:[1, 4, 512]
  • ys 的形状:[1, 1],其值为7228 对应 SOS

第一轮自回归生成

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):["SOS"]
  • out 的形状:[1, 1, 512]
  • prob 的形状:[1, 12547]
  • 得到下一个词是:4890,对应是"我"
  • 拼接后的ys 的形状:[1, 2]

第二轮自回归生成

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):["SOS", "我"]
  • out 的形状:[1, 2, 512]
  • 得到下一个词是:1924,对应是"错"
  • 拼接后的ys 的形状:[1, 3]

第三轮自回归生成

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):["SOS", "我", "错"]
  • out 的形状:[1, 3, 512]
  • 得到下一个词是:"了"
  • 拼接后的ys 的形状:[1, 4]

第四轮自回归生成

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):["SOS", "我", "错", "了"]
  • 得到下一个词是:"吗"

第五轮自回归生成

  • 上文输入(源序列):[‘am’, ‘i’, ‘wrong’, ‘?’]
  • 下文输入(目标序列):["SOS", "我", "错", "了", "吗"]
  • 得到下一个词是:"?"

代码运行

切换至程序目录下

cd transformer_demo

安装相关依赖

pip install OpenCC

pip install jieba

运行程序

python main.py

运行效果:

内容小结

  • 预测过程注意事项有:
    • 首先需要对输入进行分词、编码、转张量处理。
    • 其次要将模型设置为eval模式,同时在无梯度模式下,调用模型推理函数greed_decode()进行预测。
    • 下文的输入序列在初始化时需要增加启动信号\,预测完毕后需要将启动信号从结果中去除。
    • 在自回归的过程中,每次预测下一个词时,需要将预测结果拼接到输入序列中,作为下一轮的输入。
    • 预测完毕后,通过get_real_output()函数,将预测结果进行处理,得到真实输出序列。

参考资料

The Annotated Transformer

发表评论

您的电子邮箱地址不会被公开。 必填项已用 * 标注

分类文章

personal_logo
Dongming
自由职业者

推荐活动

推荐文章

【项目实战】基于Agent的金融问答系统:RAG的检索增强之上下文重排和压缩
【项目实战】基于Agent的金融问答系统:RAG的检索增强之ElasticSearch
【项目实战】基于Agent的金融问答系统:前后端流程打通
【项目实战】基于Agent的金融问答系统:代码重构
【项目实战】基于Agent的金融问答系统:Agent框架的构建
【项目实战】基于Agent的金融问答系统:RAG检索模块初建成
【项目实战】基于Agent的金融问答系统:项目简介
【课程总结】day29:大模型之深入了解Retrievers解析器
【课程总结】day28:大模型之深入探索RAG流程
【课程总结】day30:大模型之Agent的初步了解
内容目录
滚动至顶部