# DGA (domain generation algorithm) detection model: training and test code

# _*_coding:UTF-8_*_

import operator
import tldextract
import random
import pickle
import os
import tflearn

from math import log
from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_1d, max_pool_1d
from tflearn.layers.estimator import regression
from tflearn.layers.normalization import batch_normalization
from sklearn.model_selection import train_test_split


def get_cnn_model(max_len, volcab_size=None):
    """Build a tflearn CNN that maps padded character-id sequences to 2 classes.

    Args:
        max_len: Length of every (padded) input sequence.
        volcab_size: Number of rows in the embedding table; when None a
            default of 10240000 is used (callers in this file pass an
            explicit size).

    Returns:
        An untrained ``tflearn.DNN`` wrapping the network.
    """
    vocab_rows = 10240000 if volcab_size is None else volcab_size

    # Integer ids -> 32-dimensional learned embeddings.
    net = tflearn.input_data(shape=[None, max_len], name='input')
    net = tflearn.embedding(net, input_dim=vocab_rows, output_dim=32)

    # Two identical convolution + max-pooling stages.
    for _ in range(2):
        net = conv_1d(net, 64, 3, activation='relu', regularizer="L2")
        net = max_pool_1d(net, 2)

    # Dense head with dropout, ending in a 2-way softmax.
    net = batch_normalization(net)
    net = fully_connected(net, 64, activation='relu')
    net = dropout(net, 0.5)
    net = fully_connected(net, 2, activation='softmax')

    optimizer = tflearn.SGD(learning_rate=0.1, lr_decay=0.96, decay_step=1000)
    net = regression(net, optimizer=optimizer, loss='categorical_crossentropy')

    return tflearn.DNN(net, tensorboard_verbose=0)


def get_data_from(file_name):
    """Read one entry per line from ``file_name``.

    Surrounding whitespace (including the trailing newline) is stripped from
    every line, and blank lines are skipped so that no empty "domain" ends up
    in the training data.

    Args:
        file_name: Path of a UTF-8 text file.

    Returns:
        List of the non-empty stripped lines, in file order.
    """
    # Explicit encoding: the default is locale-dependent and differs across platforms.
    with open(file_name, encoding="utf-8") as f:
        return [entry for entry in (line.strip() for line in f) if entry]


def get_local_data(tag="labeled", black_file="dga_360_sorted.txt", white_file="top-1m.csv"):
    """Load the malicious (black) and benign (white) domain lists.

    Args:
        tag: Unused; kept for backward compatibility.
        black_file: File of DGA domains; ``get_data()`` labels these 1.
        white_file: File of benign domains; ``get_data()`` labels these 0.

    Returns:
        ``(black_data, white_data)``, each a list of lines read by
        ``get_data_from``.
    """
    # The DGA list must be the black (label 1) set: test_model() scores the DGA
    # file and reports class-1 predictions as "dga found".
    black_data = get_data_from(file_name=black_file)
    white_data = get_data_from(file_name=white_file)
    return black_data, white_data


def get_data():
    """Load labelled domains, encode them as padded id sequences and save the vocabulary.

    Black domains get label 1 and white domains label 0. Each character is
    mapped to a positive integer id; 0 is reserved for padding. Sequences are
    passed to ``pad_sequences`` with ``maxlen = min(longest domain, 256)``.
    Labels are one-hot encoded for 2 classes.

    The character map, the sequence length and the vocabulary size are
    pickled to ``volcab.pkl`` so ``test_model()`` can encode new domains the
    same way.

    Returns:
        ``(X, Y, maxlen, max_features)``: padded id sequences, one-hot labels,
        the padding length, and the vocabulary size (distinct characters + 1
        for the padding id).

    Raises:
        ValueError: if no domains were loaded.
    """
    black_x, white_x = get_local_data()
    X = black_x + white_x
    labels = [1] * len(black_x) + [0] * len(white_x)
    if not X:
        raise ValueError("no training domains were loaded")

    # Sort the characters so ids are reproducible across runs; iterating a set
    # of str depends on per-process hash randomization.
    valid_chars = {ch: idx + 1 for idx, ch in enumerate(sorted(set(''.join(X))))}

    max_features = len(valid_chars) + 1  # +1 for the padding id 0
    print("max_features:", max_features)
    maxlen = max(len(x) for x in X)
    print("max_len:", maxlen)
    maxlen = min(maxlen, 256)  # cap the sequence length fed to the model

    # Convert characters to ids and pad.
    X = [[valid_chars[ch] for ch in x] for x in X]
    X = pad_sequences(X, maxlen=maxlen, value=0.)

    # One-hot encode the 0/1 labels.
    Y = to_categorical(labels, nb_classes=2)

    volcab_file = "volcab.pkl"
    data = {"valid_chars": valid_chars, "max_len": maxlen, "volcab_size": max_features}
    # Default pickle protocol; `with` closes the file even if dumping fails.
    with open(volcab_file, 'wb') as output:
        pickle.dump(data, output)

    return X, Y, maxlen, max_features


def train_model():
    """Train the CNN on the local data, save it, and print predictions for 3 held-out samples.

    20% of the data (``random_state=42``) is held out as the validation set.
    The model is fit with batch size 1024, written to
    ``finalized_model.tflearn``, and reloaded from that file before predicting.
    """
    X, Y, max_len, volcab_size = get_data()
    print("X len:", len(X), "Y len:", len(Y))

    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=42)
    # Peek at the first training example and the last validation example.
    for preview in (trainX[:1], trainY[:1], testX[-1:], testY[-1:]):
        print(preview)

    model = get_cnn_model(max_len, volcab_size)
    model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True, batch_size=1024)

    model_path = 'finalized_model.tflearn'
    model.save(model_path)
    # Reload from disk before predicting, which exercises the saved file.
    model.load(model_path)

    print("Just review 3 sample data test result:")
    print(model.predict(testX[0:3]))


def test_model(domain_file="dga_360_sorted.txt", threshold=0.5, batch_size=1000):
    """Score every line of ``domain_file`` with the saved model and print likely DGA domains.

    The model uses the vocabulary saved by ``get_data()`` (``volcab.pkl``) and
    the weights in ``finalized_model.tflearn``. Characters not in the
    vocabulary map to the padding id 0.

    Args:
        domain_file: Text file with one domain per line.
        threshold: A domain is reported when its class-1 probability is
            strictly greater than this value.
        batch_size: Number of lines sent to ``predict`` at once.

    Returns:
        List of the domains reported as DGA, in file order.

    Raises:
        FileNotFoundError: if ``volcab.pkl`` does not exist (train first).
    """
    volcab_file = "volcab.pkl"
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(volcab_file):
        raise FileNotFoundError(volcab_file + " not found; run train_model() first")
    with open(volcab_file, 'rb') as pkl_file:
        # NOTE: pickle.load can execute arbitrary code; only load files this script wrote.
        data = pickle.load(pkl_file)
    valid_chars, max_document_length, max_features = data["valid_chars"], data["max_len"], data["volcab_size"]

    print("max_features:", max_features)
    print("max_len:", max_document_length)

    cnn_model = get_cnn_model(max_document_length, max_features)
    filename = 'finalized_model.tflearn'
    cnn_model.load(filename)
    print("predict domains:")
    bls = []

    with open(domain_file) as f:
        lines = f.readlines()
    print("domain_list len:", len(lines))

    for start in range(0, len(lines), batch_size):
        batch_lines = lines[start:start + batch_size]
        domain_list = [line.strip() for line in batch_lines]

        # Convert characters to ids (unknown -> 0) and pad.
        X = [[valid_chars.get(ch, 0) for ch in domain] for domain in domain_list]
        X = pad_sequences(X, maxlen=max_document_length, value=0.)

        result = cnn_model.predict(X)
        for raw_line, domain, probs in zip(batch_lines, domain_list, result):
            if probs[1] > threshold:
                print(raw_line.strip() + "\t" + domain, probs[1])
                bls.append(domain)

    print(len(bls), "dga found!")
    return bls


if __name__ == "__main__":
    # Train first (writes volcab.pkl and the model file), then score the domain list.
    for banner, step in (("train model...", train_model), ("test model...", test_model)):
        print(banner)
        step()
# Original article: https://www.cnblogs.com/bonelee/p/11958214.html