打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
【干货】基于Keras的注意力机制实战

【导读】近几年,注意力机制(Attention)大量地出现在自动翻译、信息检索等模型中。可以把Attention看成模型中的一个特征选择组件,特征选择一方面可以增强模型的效果,另一方面,我们可以通过计算出的特征的权重来计算结果与特征之间的某种关联。例如在自动翻译模型中,Attention可以计算出不同语种词之间的关系。本文一个简单的例子,来展示Attention是怎么在模型中起到特征选择作用的。


代码




导入相关库

#coding=utf-8
import numpy as np
from keras.models import *
from keras.layers import Input, Dense, merge
import matplotlib.pyplot as plt
import pandas as pd


数据生成函数

# 输入维度
input_dim = 32


# 生成数据,数据的的第attention_column个特征由label决定,
# 即
label只与数据的第attention_column个特征相关
def get_data(n, input_dim, attention_column=1):
   x = np.random.standard_normal(size=(n, input_dim))
   y = np.random.randint(low=0, high=2, size=(n, 1))
   x[:, attention_column] = y[:, 0]
   return x, y


模型定义函数

将输入进行一次变换后,计算出Attention权重,将输入乘上Attention权重,获得新的特征。


# Attention模型
def build_model():
   inputs = Input(shape=(input_dim,))

   # 计算Attention权重
   
attention_probs = Dense(input_dim, activation='softmax',
name='attention_vec')(inputs)
   # 根据Attention权重更新特征
   
attention_mul = merge([inputs, attention_probs],
output_shape=32,
name='attention_mul', mode='mul')

   # 预测标签
   
attention_mul = Dense(64)(attention_mul)
   output = Dense(1, activation='sigmoid')(attention_mul)
   model = Model(input=[inputs], output=output)
   attention_vec_model = Model(input=[inputs],
output=attention_probs)
   return model, attention_vec_model


主函数

if __name__ == '__main__':
   # 生成训练数据
   
N = 10000
   
inputs_1, outputs = get_data(N, input_dim)

   # 获取模型,以及用于计算Attention权重的子模型
   
m, attention_vec_model = build_model()
   m.compile(optimizer='adam', loss='binary_crossentropy',
metrics=['accuracy'])
   print(m.summary())

   # 训练
   
m.fit([inputs_1], outputs, epochs=20, batch_size=64,
validation_split=0.5)

   # 生成测试数据
   
testing_inputs_1, testing_outputs = get_data(1, input_dim)

   # 根据测试数据计算Attention权重
   
attention_vector = attention_vec_model.
   predict([testing_inputs_1])[0].flatten()
   print('attention =', attention_vector)

   # 绘图
pd.DataFrame(attention_vector, columns=['attention (%)'])
.plot(kind='bar', title='Attention Mechanism as a function of
input dimensions.'
)
   plt.show()


运行结果

代码中,attention_column为1,也就是说,label只与数据的第1个特征相关。从运行结果中可以看出,Attention权重成功地获取了这个信息。


参考链接

https://github.com/philipperemy/keras-attention-mechanism


更多教程资料请访问:人工智能知识资料全集

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
Llama深入浅出
从零开始学自然语言处理(十一)——keras实现textCNN
使用keras和tensorflow保存为可部署的pb格式
“让Keras更酷一些!”:层与模型的重用技巧
使用LSTM神经网络进行音乐合成(数据格式,模型构建,完整源码)
Kaggle从零到实践:Bert中文多项选择
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服