NNLM神经网络语言模型简单实现词语预测(含python代码详解)_神经网络语言模型代码-程序员宅基地

技术标签: python  nlp  自然语言处理  

NNLM:Neural Network Language Model,神经网络语言模型。源自Bengio等人于2001年发表在NIPS上的《A Neural Probabilistic Language Model一文。

利用神经网络计算词向量的方法,根据(w{t-n+1}...w{t-1})来预测(w{t})是什么单词,即用前(n-1)个单词来预测第(n)个单词。

二、NNLM词语预测代码

1. 导入包

torch库——又称PyTorach,是一个以Python优先的深度学习框架,一个开源的Python机器学习库,用于自然语言处理等应用程序。

torch.nn包——nn全称为neural network,意思是神经网络,是torch中构建神经网络的模块。

torch.optim包——这个包里面有很多的优化算法,比如我们常用的随机梯度下降算法,添加动量的随机梯度下降算法。

import torch
import torch.nn as nn
import torch.optim as optim

2. 文本数据处理

输入三句短文本,"i like dog", "i love coffee", "i hate milk",作为模型预测的资料。

dtype = torch.FloatTensor
sentences = ["i like dog", "i love coffee", "i hate milk"]
word_list = " ".join(sentences).split()  # 提取句子中所有词语
#print(word_list)
word_list = list(set(word_list))  # 去除重复元素,得到词汇表
#print("去重后的word_list:", word_list)
word_dict = {w: i for i, w in enumerate(word_list)}  # 按照词汇表生成相应的词典 {‘word’:0,...}
number_dict = {i: w for i, w in enumerate(word_list)}  # 将每个索引对应于相应的单词{0:'word',...}
n_class = len(word_dict)  # 单词的总数,也是分类数

torch.FloatTensor——FloatTensor用于生成浮点类型的张量。 torch.FloatTensor()默认生成32位浮点数,dtype 为 torch.float32 或 torch.float。  

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。 

3. 自定义mini-batch迭代器

自定义函数:def make_batch(sentences),make_batch(sentences)函数是一个mini-batch迭代器,实现数据的输入输出,函数以sentences列表作为输入, 最终函数将输入数据集input_batch和输出数据集target_batch返回为结果。详见代码注释。 

def make_batch(sentences):
    input_batch = []
    target_batch = []
 
    for sen in sentences:
    #通过for循环遍历sentences中的每个句子

        word = sen.split()
        input = [word_dict[n] for n in word[:-1]]
        #设定输入为列表word中每个词汇对应的数字所组成的序列,一句话中最后一个词是要用来预测的,                                不作为输入。最后的:-1就表示取每个句子在最后一个单词之前的单词作为输入,通过word_dict取出这些单词的下标,作为整个网络的输入。

        target = word_dict[word[-1]]
        #将每句话的最后一个词作为目标值(target),以本次实验为例就是cat,coffee和milk,word_dict取出单词的下标,作为输出。

        input_batch.append(input)
        #input_batch是空列表,将每句话的输入放入列表中,形成输入数据集

        target_batch.append(target)
        #target_batch是空列表,将每句话的输出放入列表中,形成输出数据集
 
    return input_batch, target_batch

接下来调用make_batch函数进行数据输入和转化:

将sentences输入make_batch函数,使用make_batch从训练集中获得输入和对应的标记,将输入数据集用input_batch存储,将输出数据集target_batch用存储。

input_batch, target_batch = make_batch(sentences)

  

 4. 定义NNLM模型

1. 定义模型结构

# 定义模型
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM, self).__init__() #定义网络结构,继承nn.Module
        self.C = nn.Embedding(n_class, m) 
        self.H = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))
        self.W = nn.Parameter(torch.randn(n_step * m, n_class).type(dtype))
        self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))
        self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))
        self.b = nn.Parameter(torch.randn(n_class).type(dtype))
        #C: 词向量,计算词向量表,大小是len(word_dict) * m 词向量随机赋值,先使用one-hot,然后使用matrix C映射到词向量。
        #H: 隐藏层的权重; W: 输入层到输出层的权重;
        #d: 隐藏层的bias;  U: 输出层的weight;  b: 输出层的bias;
        #n_step为文中用n_step个词预测下一个词,在本程序中其值为2
        #n_hidden为隐藏层的神经元的数量
        #m为词向量的维度 



    def forward(self, X): 
        X = self.C(X)  # [batch_size, n_step] => [batch_size, n_step, m]
        #输入层的输入转换:x=x’* C==[C(wi−(n−1)), …,C(wi−1)];
根据词向量表,将输入数据X转换成三维数据,将每个单词替换成相应的词向量。X原本形式为[batch_size, n_step],转换后为[batch_size, n_step, m]

        X = X.view(-1, n_step * m)  # [batch_size, n_step * m]
        #将替换后的词向量表的相同行进行拼接,view函数的第一个参数为-1表示自动判断需要合并成几行。

        hidden_out = torch.tanh(self.d + torch.mm(X, self.H))  # [batch_size, n_hidden]
        #隐藏层的计算,主要计算h=tanh(d+Hx)。其中,H表示输入层
到隐藏层的权重矩阵,其维度为|V| * |h|。|V|表示词表的大小,d表示偏置,torch.mm表示矩阵的相乘。输出为[batch_size, n_hidden]

        output = self.b + torch.mm(X, self.W) + torch.mm(hidden_out, self.U)  # [batch_size, n_class]
        #输出层的计算:主要计算y=b+Uh。其中,U表示隐藏层到输出层的权重矩阵,b表示偏置,y表示输出的一个|V|的向量,向量中内容是下一个词wi是词表中每一个词的可能性。输出为[batch_size, n_class],最终return返回output。
        return output

代码中的:

torch.nn.Embedding()函数是指torch.nn包下的Embedding,作为训练的一层,随模型训练得到适合的词向量。

torch.nn.Parameter()函数含义是将一个固定不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,所以经过类型转换这个self.H变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

torch.randn()函数用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)。例如torch.randn(size),size可以是一个整数,也可以是一个元组。

输入层的输入:将词序列wi−(n-1)…wi−1中的n-1个词,每一个词 进行one-hot编码,得到向量1*V;词向量按照顺序进行拼接, 获的输入向量x’=[V(wi−(n−1)), …,V(wi−1)];

总之就是将将输入的 n-1 个单词索引转为词向量,然后将这 n-1 个词向量进行 concat,形成一个 (n-1)*w 的输入向量。接下来将向量作为X送入隐藏层进行计算,hidden = tanh(d + X * H) 这就涉及到了自定义函数forward,使NNLM模型可以训练并完成向量的迭代更新,forword函数的代码解释详见代码注释。

2. NNLM参数设置

# NNLM参数设置
n_step = 2   # 设定n_gram为2,即根据当前词的前两个词语预测当前单词
n_hidden = 2  # 设定隐藏层神经元的个数为2
m = 2  # 设定词向量的维度为2
model = NNLM() #将之前建立的NNLM模型实例化为model
criterion = nn.CrossEntropyLoss() #使用交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)  #优化器 选择Adam

其中分类问题用交叉熵作为损失函数; nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层。

优化器使用Adam。所谓的优化器,实际上就是你用什么方法去更新网路中的参数。 torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用。 Adam算法本质上是带有动量项的RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。

5. 输入数据并完成训练

输入数据:

# 数据输入
input_batch, target_batch = make_batch(sentences)
input_batch = torch.LongTensor(input_batch)
target_batch = torch.LongTensor(target_batch)

其中使用make_batch从训练集中获得输入和对应的标记;

input_batch:一组batch中前n_steps个单词的索引;

target_batch:一组batch中每句话待预测单词的索引 torch.FloatTensor是32位浮点类型数据,而torch.LongTensor是64位整型;

开始训练:

# 开始训练
for epoch in range(5000):  #设定训练5000轮
    optimizer.zero_grad()  #梯度清零,也就是把loss关于weight的导数变成0
    output = model(input_batch)  #模型训练 tensor(3,7)
    # output : [batch_size, n_class], target_batch : [batch_size] (LongTensor, not one-hot)
    
    loss = criterion(output, target_batch) 
    #计算损失,criterion()为损失函数,用来计算出loss
    if (epoch + 1) % 1000 == 0:
        print("Epoch:{}".format(epoch + 1), "Loss:{:.3f}".format(loss))
        #每到1000输出一次损失值
    loss.backward() #反向传播
    optimizer.step() #更新参数,optimizer实现了step()方法,这个方法会更新对应的参数。只有用了optimizer.step(),模型才会更新。

其中重点解释output = model(input_batch):

计算预测值,对之前建立的NNLM模型集进行训练,形式为tensor(3,7)。 一行代表一个输入对应的七个输出,这七个值对应着7类,也就是词典个数,对应最大值的位置序号就是最终预测值。

 6. 预测

# 预测
predict = model(input_batch).data.max(1, keepdim=True)[1]  #tensor (3,1)获取最大值对应的(序号)单词,也就是预测值 [batch_size, n_class]
# print("predict: \n", predict)
# 测试
print([sentence.split()[:2] for sentence in sentences], "---->",
      [number_dict[n.item()] for n in predict.squeeze()])  #predict.squeeze 的 tensor(3)

先获取预测值最大者对应的(序号)单词,也就是预测值 [batch_size, n_class] max()取的是最内层维度中最大的那个数的值和索引,[1]表示取索引。

squeeze()表示将数组中维度为1的维度去掉,squeeze():对张量的维度进行减少的操作,假设原来:tensor([[0],[6],[5]]),squeeze()操作后变成tensor([0, 6, 5])。

最终通过for循环将每个句子的前两个词组成元素放在列表中,再通过for循环将预测出来的序号对应词汇放入列表中,中间用"---->"连接。

验证一下,发现tensor([0, 6, 5])正是对应number_dict中的dog, coffee, milk:

     

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_50706330/article/details/127708430

智能推荐

FTP命令字和返回码_ftp 登录返回230-程序员宅基地

文章浏览阅读3.5k次,点赞2次,收藏13次。为了从FTP服务器下载文件,需要要实现一个简单的FTP客户端。FTP(文件传输协议) 是 TCP/IP 协议组中的应用层协议。FTP协议使用字符串格式命令字,每条命令都是一行字符串,以“\r\n”结尾。客户端发送格式是:命令+空格+参数+"\r\n"的格式服务器返回格式是以:状态码+空格+提示字符串+"\r\n"的格式,代码只要解析状态码就可以了。读写文件需要登陆服务器,特殊用..._ftp 登录返回230

centos7安装rabbitmq3.6.5_centos7 安装rabbitmq3.6.5-程序员宅基地

文章浏览阅读648次。前提:systemctl stop firewalld 关闭防火墙关闭selinux查看getenforce临时关闭setenforce 0永久关闭sed-i'/SELINUX/s/enforcing/disabled/'/etc/selinux/configselinux的三种模式enforcing:强制模式,SELinux 运作中,且已经正确的开始限制..._centos7 安装rabbitmq3.6.5

idea导入android工程,idea怎样导入Android studio 项目?-程序员宅基地

文章浏览阅读5.8k次。满意答案s55f2avsx2017.09.05采纳率:46%等级:12已帮助:5646人新版Android Studio/IntelliJ IDEA可以直接导入eclipse项目,不再推荐使用eclipse导出gradle的方式2启动Android Studio/IntelliJ IDEA,选择 import project3选择eclipse 项目4选择 create project f..._android studio 项目导入idea 看不懂安卓项目

浅谈AI大模型技术:概念、发展和应用_ai大模型应用开发-程序员宅基地

文章浏览阅读860次,点赞2次,收藏6次。AI大模型技术已经在自然语言处理、计算机视觉、多模态交互等领域取得了显著的进展和成果,同时也引发了一系列新的挑战和问题,如数据质量、计算效率、知识可解释性、安全可靠性等。城市运维涉及到多个方面,如交通管理、环境监测、公共安全、社会治理等,它们需要处理和分析大量的多模态数据,如图像、视频、语音、文本等,并根据不同的场景和需求,提供合适的决策和响应。知识搜索有多种形式,如语义搜索、对话搜索、图像搜索、视频搜索等,它们可以根据用户的输入和意图,从海量的数据源中检索出最相关的信息,并以友好的方式呈现给用户。_ai大模型应用开发

非常详细的阻抗测试基础知识_阻抗实部和虚部-程序员宅基地

文章浏览阅读8.2k次,点赞12次,收藏121次。为什么要测量阻抗呢?阻抗能代表什么?阻抗测量的注意事项... ...很多人可能会带着一系列的问题来阅读本文。不管是数字电路工程师还是射频工程师,都在关注各类器件的阻抗,本文非常值得一读。全文13000多字,认真读完大概需要2小时。一、阻抗测试基本概念阻抗定义:阻抗是元器件或电路对周期的交流信号的总的反作用。AC 交流测试信号 (幅度和频率)。包括实部和虚部。​图1 阻抗的定义阻抗是评测电路、元件以及制作元件材料的重要参数。那么什么是阻抗呢?让我们先来看一下阻抗的定义。首先阻抗是一个矢量。通常,阻抗是_阻抗实部和虚部

小学生python游戏编程arcade----基本知识1_arcade语言 like-程序员宅基地

文章浏览阅读955次。前面章节分享试用了pyzero,pygame但随着想增加更丰富的游戏内容,好多还要进行自己编写类,从今天开始解绍一个新的python游戏库arcade模块。通过此次的《连连看》游戏实现,让我对swing的相关知识有了进一步的了解,对java这门语言也有了比以前更深刻的认识。java的一些基本语法,比如数据类型、运算符、程序流程控制和数组等,理解更加透彻。java最核心的核心就是面向对象思想,对于这一个概念,终于悟到了一些。_arcade语言 like

随便推点

【增强版短视频去水印源码】去水印微信小程序+去水印软件源码_去水印机要增强版-程序员宅基地

文章浏览阅读1.1k次。源码简介与安装说明:2021增强版短视频去水印源码 去水印微信小程序源码网站 去水印软件源码安装环境(需要材料):备案域名–服务器安装宝塔-安装 Nginx 或者 Apachephp5.6 以上-安装 sg11 插件小程序已自带解析接口,支持全网主流短视频平台,搭建好了就能用注:接口是公益的,那么多人用解析慢是肯定的,前段和后端源码已经打包,上传服务器之后在配置文件修改数据库密码。然后输入自己的域名,进入后台,创建小程序,输入自己的小程序配置即可安装说明:上传源码,修改data/_去水印机要增强版

verilog进阶语法-触发器原语_fdre #(.init(1'b0) // initial value of register (1-程序员宅基地

文章浏览阅读557次。1. 触发器是FPGA存储数据的基本单元2. 触发器作为时序逻辑的基本元件,官方提供了丰富的配置方式,以适应各种可能的应用场景。_fdre #(.init(1'b0) // initial value of register (1'b0 or 1'b1) ) fdce_osc (

嵌入式面试/笔试C相关总结_嵌入式面试笔试c语言知识点-程序员宅基地

文章浏览阅读560次。本该是不同编译器结果不同,但是尝试了g++ msvc都是先计算c,再计算b,最后得到a+b+c是经过赋值以后的b和c参与计算而不是6。由上表可知,将q复制到p数组可以表示为:*p++=*q++,*优先级高,先取到对应q数组的值,然后两个++都是在后面,该行运算完后执行++。在电脑端编译完后会分为text data bss三种,其中text为可执行程序,data为初始化过的ro+rw变量,bss为未初始化或初始化为0变量。_嵌入式面试笔试c语言知识点

57 Things I've Learned Founding 3 Tech Companies_mature-程序员宅基地

文章浏览阅读2.3k次。57 Things I've Learned Founding 3 Tech CompaniesJason Goldberg, Betashop | Oct. 29, 2010, 1:29 PMI’ve been founding andhelping run techn_mature

一个脚本搞定文件合并去重,大数据处理,可以合并几个G以上的文件_python 超大文本合并-程序员宅基地

文章浏览阅读1.9k次。问题:先讲下需求,有若干个文本文件(txt或者csv文件等),每行代表一条数据,现在希望能合并成 1 个文本文件,且需要去除重复行。分析:一向奉行简单原则,如无必要,绝不复杂。如果数据量不大,那么如下两条命令就可以搞定合并:cat a.txt >> new.txtcat b.txt >> new.txt……去重:cat new...._python 超大文本合并

支付宝小程序iOS端过渡页DFLoadingPageRootController分析_类似支付宝页面过度加载页-程序员宅基地

文章浏览阅读489次。这个过渡页是第一次打开小程序展示的,点击某个小程序前把手机的开发者->network link conditioner->enable & very bad network 就会在停在此页。比如《支付宝运动》这个小程序先看这个类的.h可以看到它继承于DTViewController点击左上角返回的方法- (void)back;#import "DTViewController.h"#import "APBaseLoadingV..._类似支付宝页面过度加载页

推荐文章

热门文章

相关标签