【Deep Learning】BP网络手写识别

手写数字识别

1. BP神经网络

1.1 BP算法原理


BP即反向传播算法。利用输出后的误差计算上一层的误差,以此类推,直到输入层,然后再动态调节各层之间的连接权值来减小误差。

1.2 三层网络结构

1.2 Sigmoid激活函数

Sigmoid : f(x)=1/(1+e^−x)

f(x)的取值范围为0~1;

1.3 反向误差的计算及权值调整

BP神经网络的数学原理及其算法实现

2. 基于BP的手写数字识别

2.1 格式化输入

选取Mnist数据集中的0~9手写数字图片各28张。

将图像矩阵转化为BP网络可接受的一维数组形式:

//获取像素数组
    private double[] getImagePixel(String image) throws Exception {
        File file = new File(image);
        BufferedImage bi = null;
        try {
            bi = ImageIO.read(file);
        } catch (Exception e) {
            e.printStackTrace();
        }
        int width = bi.getWidth();
        int height = bi.getHeight();
        double[] vector = new double[width / 2 * height / 2];
        int k = 0;
        for (int i = 0; i < width; i += 2) {
            for (int j = 0; j < height; j += 2) {
                int whiteNum = 0;
                for (int m = 0; m < 2; m++) {
                    for (int n = 0; n < 2; n++) {
                        if (isWhite(bi, i + m, j + n)) {
                            whiteNum++;
                        }
                    }
                }
                vector[k++] = whiteNum / 4;
            }
        }
        // System.out.println(Arrays.toString(vector));
        return vector;
    }

为了减小运算量,选取一个2×2的矩阵区域,统计区域

内白色像素点的个数,将数据量缩小为原来的1/4。

//二值化并归一化数据
    private boolean isWhite(BufferedImage bi, int i, int j) {
        int pixel = bi.getRGB(i, j);
        int[] rgb = new int[3];
        rgb[0] = (pixel & 0xff0000) >> 16;
        rgb[1] = (pixel & 0xff00) >> 8;
        rgb[2] = (pixel & 0xff);
        double d = (double) ((rgb[0] * 38 + rgb[1] * 75 + rgb[2] * 15) >> 7) > 100 ? 1 : 0;
        if (d == 1) {
            return true;
        } else
            return false;
    }

2.创建神经网络并训练网络

对网络进行20000次迭代训练或直到误差小于0.001为止

String trianDataPath = "C:\Users\Administrator\Desktop\or_perceptron.nnet";
        String imgPath = "data/train2/";
        int maxLearn = 50000;
        double maxError = 0.0001;
        Stack<Long> stack = new Stack<>();
        stack.push(System.currentTimeMillis());
        System.out.println("->初始化多层网络...");
        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 100, 20, 10);
        BackPropagation bp = new BackPropagation();

        bp.setMaxIterations(maxLearn);
        bp.setMaxError(maxError);
        bp.setLearningRate(0.5);
        bp.setMinErrorChange(0.0001);
        //bp.setBatchMode(true);
        // bp.setMinErrorChangeIterationsLimit(1000);
        myMlPerceptron.setLearningRule(bp);
        System.out.println("->初始化完成");
        LearningRule learningRule = myMlPerceptron.getLearningRule();
        learningRule.addListener(new LearningEventListener() {
            @Override
            public void handleLearningEvent(LearningEvent event) {
                BackPropagation bp = (BackPropagation) event.getSource();
                int iteration = bp.getCurrentIteration();
                if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED && iteration % 100 == 0) {
                    System.out.print("->学习次数: " + iteration + ",当前误差: " + bp.getTotalNetworkError());
                    System.out.println(",用时:"+(System.currentTimeMillis()-stack.pop())+" ms");
                    stack.push(System.currentTimeMillis());
                }
            }
        });
        System.out.println("->创建数据集...");
        DataSet trainingSet = new DataSet(100, 10);
        // 添加训练数据到数据集
        File f = new File(imgPath);
        File[] list = f.listFiles();
        ImageToVector itv = new ImageToVector();
        for (int i = 0; i < list.length; i++) {
            String fileName = list[i].getName();
            double[] input = itv.imageToVector(imgPath + fileName, 20, 20);
            trainingSet.addRow(input, getTarget(getNumber(fileName)));
        }
        System.out.println("->数据集创建完成");
        System.out.println("->开始学习...");
        myMlPerceptron.learn(trainingSet);
        System.out.println("->学习完成");
        System.out.println("->保存学习数据...");
        myMlPerceptron.save(trianDataPath);
        System.out.println("->保存完成...");

2.3 测试网络

测试共选取4990张图片,识别正确2646张,正确率:0.530

2605。识别率较低。

原文地址:https://www.cnblogs.com/cnsec/p/13286769.html