利用 PyTorch 实现基于前馈网络意图分类的 chatbot

一.目的

利用 PyTorch 实现一个基于前馈神经网络的意图分类器，并在此基础上实现一个简单的 chatbot。

二.数据

数据为英文语料（即程序读取的 intents.json），内容如下：

{'intents': [{'tag': 'greeting', 'patterns': ['Hi there', 'How are you', 'Is anyone there?', 'Hey', 'Hola', 'Hello', 'Good day'], 'responses': ['Hello, thanks for asking', 'Good to see you again', 'Hi there, how can I help?'], 'context': ['']}, {'tag': 'goodbye', 'patterns': ['Bye', 'See you later', 'Goodbye', 'Nice chatting to you, bye', 'Till next time'], 'responses': ['See you!', 'Have a nice day', 'Bye! Come back again soon.'], 'context': ['']}, {'tag': 'thanks', 'patterns': ['Thanks', 'Thank you', "That's helpful", 'Awesome, thanks', 'Thanks for helping me'], 'responses': ['Happy to help!', 'Any time!', 'My pleasure'], 'context': ['']}, {'tag': 'noanswer', 'patterns': [], 'responses': ["Sorry, can't understand you", 'Please give me more info', 'Not sure I understand'], 'context': ['']}, {'tag': 'options', 'patterns': ['How you could help me?', 'What you can do?', 'What help you provide?', 'How you can be helpful?', 'What support is offered'], 'responses': ['I can guide you through Adverse drug reaction list, Blood pressure tracking, Hospitals and Pharmacies', 'Offering support for Adverse drug reaction, Blood pressure, Hospitals and Pharmacies'], 'context': ['']}, {'tag': 'adverse_drug', 'patterns': ['How to check Adverse drug reaction?', 'Open adverse drugs module', 'Give me a list of drugs causing adverse behavior', 'List all drugs suitable for patient with adverse reaction', 'Which drugs dont have adverse reaction?'], 'responses': ['Navigating to Adverse drug reaction module'], 'context': ['']}, {'tag': 'blood_pressure', 'patterns': ['Open blood pressure module', 'Task related to blood pressure', 'Blood pressure data entry', 'I want to log blood pressure results', 'Blood pressure data management'], 'responses': ['Navigating to Blood Pressure module'], 'context': ['']}, {'tag': 'blood_pressure_search', 'patterns': ['I want to search for blood pressure result history', 'Blood pressure for patient', 'Load patient blood pressure result', 'Show blood pressure results for patient', 'Find blood pressure results by ID'], 'responses': ['Please provide Patient ID', 'Patient ID?'], 'context': ['search_blood_pressure_by_patient_id']}, {'tag': 'search_blood_pressure_by_patient_id', 'patterns': [], 'responses': ['Loading Blood pressure result for Patient'], 'context': ['']}, {'tag': 'pharmacy_search', 'patterns': ['Find me a pharmacy', 'Find pharmacy', 'List of pharmacies nearby', 'Locate pharmacy', 'Search pharmacy'], 'responses': ['Please provide pharmacy name'], 'context': ['search_pharmacy_by_name']}, {'tag': 'search_pharmacy_by_name', 'patterns': [], 'responses': ['Loading pharmacy details'], 'context': ['']}, {'tag': 'hospital_search', 'patterns': ['Lookup for hospital', 'Searching for hospital to transfer patient', 'I want to search hospital data', 'Hospital lookup for patient', 'Looking up hospital details'], 'responses': ['Please provide hospital name or location'], 'context': ['search_hospital_by_params']}, {'tag': 'search_hospital_by_params', 'patterns': [], 'responses': ['Please provide hospital type'], 'context': ['search_hospital_by_type']}, {'tag': 'search_hospital_by_type', 'patterns': [], 'responses': ['Loading hospital details'], 'context': ['']}]}

其中，tag 表示意图类别，patterns 为该类别下的训练样本，responses 为识别出该意图后的回复语句（程序从中随机选取一条）；context 字段在本程序中未被使用。

三.程序

完整的程序和数据见：https://github.com/jiangnanboy/chatbot 。下面的程序运行时需要当前工作目录中的 intents.json、words.pkl、classes.pkl、classes_index.pkl 与 chatbot_model.h5 文件。

import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import numpy as np
import os
import json
import random

def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): pickle.load can execute arbitrary code; only point this at
    files you produced yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Intent definitions (tags, patterns, responses) read from the working directory.
intent_json_path = os.path.join(os.getcwd(), "intents.json")
with open(intent_json_path, 'r', encoding='utf-8') as intents_file:
    intents = json.load(intents_file)

# Pickled artefacts loaded from the working directory, in this order:
# vocabulary, class list, and the tag -> index mapping.
words_path = os.path.join(os.getcwd(), "words.pkl")
words = _load_pickle(words_path)

classes_path = os.path.join(os.getcwd(), "classes.pkl")
classes = _load_pickle(classes_path)

classes_index_path = os.path.join(os.getcwd(), "classes_index.pkl")
classes_index = _load_pickle(classes_index_path)

# Invert the mapping so a predicted class index can be turned back into its tag.
index_classes = {idx: tag for tag, idx in classes_index.items()}
print(f'index_classes:{index_classes}')
class classifyModel(nn.Module):
    """Feed-forward intent classifier over a bag-of-words vector.

    Layers: Linear(input_size, 128) -> ReLU -> Dropout -> Linear(128, 64)
    -> ReLU -> Dropout -> Linear(64, num_classes). They are kept in a single
    ``nn.Sequential`` named ``model`` so state-dict keys (``model.0.*``,
    ``model.3.*``, ``model.6.*``) stay compatible with existing checkpoints.

    Args:
        input_size: length of the bag-of-words input; defaults to
            ``len(words)`` from the module-level vocabulary.
        num_classes: number of output classes; defaults to ``len(classes)``.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, input_size=None, num_classes=None, dropout=0.5):
        super().__init__()
        # Fall back to the module-level vocabulary / class list so existing
        # ``classifyModel()`` calls keep working unchanged.
        if input_size is None:
            input_size = len(words)
        if num_classes is None:
            num_classes = len(classes)
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class scores for *x*.

        *x* is presumably a float tensor of shape (batch, input_size); the
        caller in this file passes a single-row batch.
        """
        return self.model(x)
        
# Build the classifier and restore its trained weights from the working directory.
model = classifyModel()
# NOTE(review): despite the ".h5" suffix this is loaded with torch.load, i.e. a
# PyTorch checkpoint, not a Keras/HDF5 file.
model_path = os.path.join(os.getcwd(), "chatbot_model.h5")
# Map all tensors to CPU: inference in this script feeds CPU tensors, and a
# checkpoint saved from a GPU run would otherwise fail to load on a CPU-only box.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
import nltk
from nltk.stem import WordNetLemmatizer

# Shared WordNet lemmatizer, used by clean_up_sentence() to normalise tokens.
# NOTE(review): nltk.word_tokenize and WordNetLemmatizer usually need the NLTK
# 'punkt' and 'wordnet' data packages to be downloaded — confirm they are installed.
lemmatizer = WordNetLemmatizer()

def clean_up_sentence(sentence):
    """Split *sentence* into tokens and return each one lower-cased and lemmatised.

    Uses nltk.word_tokenize for tokenisation and the module-level WordNet
    ``lemmatizer`` for normalisation (lemmatisation, not stemming).
    """
    tokens = nltk.word_tokenize(sentence)
    return [lemmatizer.lemmatize(token.lower()) for token in tokens]

def bow(sentence, words, show_detail=True):
    """Encode *sentence* as a binary bag-of-words vector over *words*.

    Args:
        sentence: raw input text; tokenised and lemmatised by clean_up_sentence.
        words: vocabulary list; position i of the vector corresponds to words[i].
        show_detail: when true, print each vocabulary word that occurs in the
            sentence.

    Returns:
        A one-row batch ``[bag]`` where ``bag[i]`` is 1 if ``words[i]`` occurs
        among the sentence tokens and 0 otherwise.
    """
    # A set gives O(1) membership tests instead of rescanning every token
    # for every vocabulary entry.
    sentence_words = set(clean_up_sentence(sentence))
    bag = []
    for w in words:
        found = w in sentence_words
        bag.append(1 if found else 0)
        # Report only actual matches, not every vocabulary word.
        if found and show_detail:
            print("found in bag:{}".format(w))
    return [bag]

def predict_class(sentence, model):
    """Classify *sentence* with *model* and return the most likely intent.

    Args:
        sentence: raw user text.
        model: the trained classifier; called on the bag-of-words batch.

    Returns:
        A one-element list ``[{'intent': tag, 'prob': prob_str}]`` where
        ``tag`` is looked up in the module-level ``index_classes`` and
        ``prob_str`` is the softmax probability of that class as a string.
        Intermediate values are printed for debugging.
    """
    sentence_bag = bow(sentence, words, False)
    model.eval()  # disable dropout for inference
    # Inference needs no autograd graph; no_grad avoids building one.
    with torch.no_grad():
        outputs = model(torch.FloatTensor(sentence_bag))
        # Probability and index of the top class for the single input row.
        predicted_prob, predicted_index = torch.max(F.softmax(outputs, 1), 1)
    print('outputs:{}'.format(outputs))
    print('softmax_prob:{}'.format(predicted_prob))
    print('softmax_index:{}'.format(predicted_index))
    results = [{
        'intent': index_classes[predicted_index[0].item()],
        'prob': str(predicted_prob.numpy()[0]),
    }]
    print('result:{}'.format(results))
    return results
 
def get_response(predict_result, intents_json, default_tag='noanswer'):
    """Pick a random response for the predicted intent.

    Args:
        predict_result: list whose first element has an ``'intent'`` key
            (the shape returned by predict_class).
        intents_json: parsed intents data with an ``'intents'`` list of dicts
            carrying ``'tag'`` and ``'responses'``.
        default_tag: tag to fall back on when the predicted tag has no entry
            in *intents_json*. The sample data defines a ``'noanswer'`` intent
            whose responses are apologies. Pass None to disable the fallback.

    Returns:
        One response string chosen with random.choice from the first intent
        whose tag matches.

    Raises:
        ValueError: if neither the predicted tag nor *default_tag* is present.
    """
    tag = predict_result[0]['intent']
    list_of_intents = intents_json['intents']
    for wanted in (tag, default_tag):
        match = next((i for i in list_of_intents if i['tag'] == wanted), None)
        if match is not None:
            return random.choice(match['responses'])
    # Previously this path crashed with UnboundLocalError; make it explicit.
    raise ValueError(
        'no intent with tag {!r} (fallback {!r} also missing)'.format(tag, default_tag)
    )

def chatbot_response(text):
    """Return the bot's reply to *text*.

    Classifies the text with the global model, then picks a response from
    the global intents data.
    """
    return get_response(predict_class(text, model), intents)

# Smoke test executed at import/run time.
print(chatbot_response("Lookup for hospital"))
import tkinter
from tkinter import *

def send(event=None):
    """Post the user's message and the bot's reply to the chat log.

    Reads and clears ``EntryBox``. If the stripped message is non-empty, it
    appends the message and ``chatbot_response(msg)`` to ``ChatLog``, then
    scrolls to the end.

    Args:
        event: ignored. Accepted so the function can also serve as a Tk key
            binding callback, which receives an event object; the Button
            ``command`` calls it with no arguments.
    """
    msg = EntryBox.get("1.0", 'end-1c').strip()
    EntryBox.delete("1.0", END)  # "1.0" is the first character of a Text widget
    if not msg:
        return
    ChatLog.config(state=NORMAL)  # temporarily make the log writable
    try:
        ChatLog.insert(END, "你: " + msg + '\n\n')
        ChatLog.config(foreground="#442265", font=("Verdana", 12))
        res = chatbot_response(msg)
        ChatLog.insert(END, "机器人: " + res + '\n\n')
    finally:
        # Restore read-only mode even if the model raises.
        ChatLog.config(state=DISABLED)
    ChatLog.yview(END)
# Main window: fixed 400x500 and not resizable.
base = Tk()
base.title("Hello")
base.geometry("400x500")
base.resizable(width=FALSE, height=FALSE)

# Conversation log; read-only except while send() writes into it.
ChatLog = Text(base, bd=0, bg="white", height="8", width="50", font="Arial")
ChatLog.config(state=DISABLED)

# Vertical scrollbar, linked in both directions with the conversation log.
scrollbar = Scrollbar(base, command=ChatLog.yview, cursor="heart")
ChatLog['yscrollcommand'] = scrollbar.set

# "Send" button; clicking it calls send().
SendButton = Button(
    base,
    font=("Verdana", 12, 'bold'),
    text="发送",
    width="12",
    height=5,
    bd=0,
    bg="#32de97",
    activebackground="#3c9d9b",
    fg='#ffffff',
    command=send,
)

# Multi-line input box for the user's message.
EntryBox = Text(base, bd=0, bg="white", width="29", height="5", font="Arial")
# Return-key sending is disabled. Tk passes an event object to bound
# callbacks, so send needs an event parameter before enabling this:
# EntryBox.bind("<Return>", send)

# Absolute placement of each widget inside the fixed-size window.
for _widget, _geometry in (
    (scrollbar, {'x': 376, 'y': 6, 'height': 386}),
    (ChatLog, {'x': 6, 'y': 6, 'height': 386, 'width': 370}),
    (EntryBox, {'x': 128, 'y': 401, 'height': 90, 'width': 265}),
    (SendButton, {'x': 6, 'y': 401, 'height': 90}),
):
    _widget.place(**_geometry)

# Enter the Tk event loop; blocks until the window is closed.
base.mainloop()

四.结果

原文地址:https://www.cnblogs.com/little-horse/p/14033132.html