BP神经网络的Java实现(转)

http://fantasticinblur.iteye.com/blog/1465497

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

Java代码  收藏代码
  1. package ml;  
  2.   
  3. import java.util.Random;  
  4.   
  5. /** 
  6.  * BPNN. 
  7.  *  
  8.  * @author RenaQiu 
  9.  *  
  10.  */  
  11. public class BP {  
  12.     /** 
  13.      * input vector. 
  14.      */  
  15.     private final double[] input;  
  16.     /** 
  17.      * hidden layer. 
  18.      */  
  19.     private final double[] hidden;  
  20.     /** 
  21.      * output layer. 
  22.      */  
  23.     private final double[] output;  
  24.     /** 
  25.      * target. 
  26.      */  
  27.     private final double[] target;  
  28.   
  29.     /** 
  30.      * delta vector of the hidden layer . 
  31.      */  
  32.     private final double[] hidDelta;  
  33.     /** 
  34.      * output layer of the output layer. 
  35.      */  
  36.     private final double[] optDelta;  
  37.   
  38.     /** 
  39.      * learning rate. 
  40.      */  
  41.     private final double eta;  
  42.     /** 
  43.      * momentum. 
  44.      */  
  45.     private final double momentum;  
  46.   
  47.     /** 
  48.      * weight matrix from input layer to hidden layer. 
  49.      */  
  50.     private final double[][] iptHidWeights;  
  51.     /** 
  52.      * weight matrix from hidden layer to output layer. 
  53.      */  
  54.     private final double[][] hidOptWeights;  
  55.   
  56.     /** 
  57.      * previous weight update. 
  58.      */  
  59.     private final double[][] iptHidPrevUptWeights;  
  60.     /** 
  61.      * previous weight update. 
  62.      */  
  63.     private final double[][] hidOptPrevUptWeights;  
  64.   
  65.     public double optErrSum = 0d;  
  66.   
  67.     public double hidErrSum = 0d;  
  68.   
  69.     private final Random random;  
  70.   
  71.     /** 
  72.      * Constructor. 
  73.      * <p> 
  74.      * <strong>Note:</strong> The capacity of each layer will be the parameter 
  75.      * plus 1. The additional unit is used for smoothness. 
  76.      * </p> 
  77.      *  
  78.      * @param inputSize 
  79.      * @param hiddenSize 
  80.      * @param outputSize 
  81.      * @param eta 
  82.      * @param momentum 
  83.      * @param epoch 
  84.      */  
  85.     public BP(int inputSize, int hiddenSize, int outputSize, double eta,  
  86.             double momentum) {  
  87.   
  88.         input = new double[inputSize + 1];  
  89.         hidden = new double[hiddenSize + 1];  
  90.         output = new double[outputSize + 1];  
  91.         target = new double[outputSize + 1];  
  92.   
  93.         hidDelta = new double[hiddenSize + 1];  
  94.         optDelta = new double[outputSize + 1];  
  95.   
  96.         iptHidWeights = new double[inputSize + 1][hiddenSize + 1];  
  97.         hidOptWeights = new double[hiddenSize + 1][outputSize + 1];  
  98.   
  99.         random = new Random(19881211);  
  100.         randomizeWeights(iptHidWeights);  
  101.         randomizeWeights(hidOptWeights);  
  102.   
  103.         iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];  
  104.         hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];  
  105.   
  106.         this.eta = eta;  
  107.         this.momentum = momentum;  
  108.     }  
  109.   
  110.     private void randomizeWeights(double[][] matrix) {  
  111.         for (int i = 0, len = matrix.length; i != len; i++)  
  112.             for (int j = 0, len2 = matrix[i].length; j != len2; j++) {  
  113.                 double real = random.nextDouble();  
  114.                 matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;  
  115.             }  
  116.     }  
  117.   
  118.     /** 
  119.      * Constructor with default eta = 0.25 and momentum = 0.3. 
  120.      *  
  121.      * @param inputSize 
  122.      * @param hiddenSize 
  123.      * @param outputSize 
  124.      * @param epoch 
  125.      */  
  126.     public BP(int inputSize, int hiddenSize, int outputSize) {  
  127.         this(inputSize, hiddenSize, outputSize, 0.25, 0.9);  
  128.     }  
  129.   
  130.     /** 
  131.      * Entry method. The train data should be a one-dim vector. 
  132.      *  
  133.      * @param trainData 
  134.      * @param target 
  135.      */  
  136.     public void train(double[] trainData, double[] target) {  
  137.         loadInput(trainData);  
  138.         loadTarget(target);  
  139.         forward();  
  140.         calculateDelta();  
  141.         adjustWeight();  
  142.     }  
  143.   
  144.     /** 
  145.      * Test the BPNN. 
  146.      *  
  147.      * @param inData 
  148.      * @return 
  149.      */  
  150.     public double[] test(double[] inData) {  
  151.         if (inData.length != input.length - 1) {  
  152.             throw new IllegalArgumentException("Size Do Not Match.");  
  153.         }  
  154.         System.arraycopy(inData, 0, input, 1, inData.length);  
  155.         forward();  
  156.         return getNetworkOutput();  
  157.     }  
  158.   
  159.     /** 
  160.      * Return the output layer. 
  161.      *  
  162.      * @return 
  163.      */  
  164.     private double[] getNetworkOutput() {  
  165.         int len = output.length;  
  166.         double[] temp = new double[len - 1];  
  167.         for (int i = 1; i != len; i++)  
  168.             temp[i - 1] = output[i];  
  169.         return temp;  
  170.     }  
  171.   
  172.     /** 
  173.      * Load the target data. 
  174.      *  
  175.      * @param arg 
  176.      */  
  177.     private void loadTarget(double[] arg) {  
  178.         if (arg.length != target.length - 1) {  
  179.             throw new IllegalArgumentException("Size Do Not Match.");  
  180.         }  
  181.         System.arraycopy(arg, 0, target, 1, arg.length);  
  182.     }  
  183.   
  184.     /** 
  185.      * Load the training data. 
  186.      *  
  187.      * @param inData 
  188.      */  
  189.     private void loadInput(double[] inData) {  
  190.         if (inData.length != input.length - 1) {  
  191.             throw new IllegalArgumentException("Size Do Not Match.");  
  192.         }  
  193.         System.arraycopy(inData, 0, input, 1, inData.length);  
  194.     }  
  195.   
  196.     /** 
  197.      * Forward. 
  198.      *  
  199.      * @param layer0 
  200.      * @param layer1 
  201.      * @param weight 
  202.      */  
  203.     private void forward(double[] layer0, double[] layer1, double[][] weight) {  
  204.         // threshold unit.  
  205.         layer0[0] = 1.0;  
  206.         for (int j = 1, len = layer1.length; j != len; ++j) {  
  207.             double sum = 0;  
  208.             for (int i = 0, len2 = layer0.length; i != len2; ++i)  
  209.                 sum += weight[i][j] * layer0[i];  
  210.             layer1[j] = sigmoid(sum);  
  211.         }  
  212.     }  
  213.   
  214.     /** 
  215.      * Forward. 
  216.      */  
  217.     private void forward() {  
  218.         forward(input, hidden, iptHidWeights);  
  219.         forward(hidden, output, hidOptWeights);  
  220.     }  
  221.   
  222.     /** 
  223.      * Calculate output error. 
  224.      */  
  225.     private void outputErr() {  
  226.         double errSum = 0;  
  227.         for (int idx = 1, len = optDelta.length; idx != len; ++idx) {  
  228.             double o = output[idx];  
  229.             optDelta[idx] = o * (1d - o) * (target[idx] - o);  
  230.             errSum += Math.abs(optDelta[idx]);  
  231.         }  
  232.         optErrSum = errSum;  
  233.     }  
  234.   
  235.     /** 
  236.      * Calculate hidden errors. 
  237.      */  
  238.     private void hiddenErr() {  
  239.         double errSum = 0;  
  240.         for (int j = 1, len = hidDelta.length; j != len; ++j) {  
  241.             double o = hidden[j];  
  242.             double sum = 0;  
  243.             for (int k = 1, len2 = optDelta.length; k != len2; ++k)  
  244.                 sum += hidOptWeights[j][k] * optDelta[k];  
  245.             hidDelta[j] = o * (1d - o) * sum;  
  246.             errSum += Math.abs(hidDelta[j]);  
  247.         }  
  248.         hidErrSum = errSum;  
  249.     }  
  250.   
  251.     /** 
  252.      * Calculate errors of all layers. 
  253.      */  
  254.     private void calculateDelta() {  
  255.         outputErr();  
  256.         hiddenErr();  
  257.     }  
  258.   
  259.     /** 
  260.      * Adjust the weight matrix. 
  261.      *  
  262.      * @param delta 
  263.      * @param layer 
  264.      * @param weight 
  265.      * @param prevWeight 
  266.      */  
  267.     private void adjustWeight(double[] delta, double[] layer,  
  268.             double[][] weight, double[][] prevWeight) {  
  269.   
  270.         layer[0] = 1;  
  271.         for (int i = 1, len = delta.length; i != len; ++i) {  
  272.             for (int j = 0, len2 = layer.length; j != len2; ++j) {  
  273.                 double newVal = momentum * prevWeight[j][i] + eta * delta[i]  
  274.                         * layer[j];  
  275.                 weight[j][i] += newVal;  
  276.                 prevWeight[j][i] = newVal;  
  277.             }  
  278.         }  
  279.     }  
  280.   
  281.     /** 
  282.      * Adjust all weight matrices. 
  283.      */  
  284.     private void adjustWeight() {  
  285.         adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);  
  286.         adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);  
  287.     }  
  288.   
  289.     /** 
  290.      * Sigmoid. 
  291.      *  
  292.      * @param val 
  293.      * @return 
  294.      */  
  295.     private double sigmoid(double val) {  
  296.         return 1d / (1d + Math.exp(-val));  
  297.     }  
  298. }  

 为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

Java代码  收藏代码
  1. package ml;  
  2.   
  3. import java.io.IOException;  
  4. import java.util.ArrayList;  
  5. import java.util.List;  
  6. import java.util.Random;  
  7.   
  8. public class Test {  
  9.   
  10.     /** 
  11.      * @param args 
  12.      * @throws IOException 
  13.      */  
  14.     public static void main(String[] args) throws IOException {  
  15.         BP bp = new BP(32, 15, 4);  
  16.   
  17.         Random random = new Random();  
  18.         List<Integer> list = new ArrayList<Integer>();  
  19.         for (int i = 0; i != 1000; i++) {  
  20.             int value = random.nextInt();  
  21.             list.add(value);  
  22.         }  
  23.   
  24.         for (int i = 0; i != 200; i++) {  
  25.             for (int value : list) {  
  26.                 double[] real = new double[4];  
  27.                 if (value >= 0)  
  28.                     if ((value & 1) == 1)  
  29.                         real[0] = 1;  
  30.                     else  
  31.                         real[1] = 1;  
  32.                 else if ((value & 1) == 1)  
  33.                     real[2] = 1;  
  34.                 else  
  35.                     real[3] = 1;  
  36.                 double[] binary = new double[32];  
  37.                 int index = 31;  
  38.                 do {  
  39.                     binary[index--] = (value & 1);  
  40.                     value >>>= 1;  
  41.                 } while (value != 0);  
  42.   
  43.                 bp.train(binary, real);  
  44.             }  
  45.         }  
  46.   
  47.         System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");  
  48.   
  49.         while (true) {  
  50.             byte[] input = new byte[10];  
  51.             System.in.read(input);  
  52.             Integer value = Integer.parseInt(new String(input).trim());  
  53.             int rawVal = value;  
  54.             double[] binary = new double[32];  
  55.             int index = 31;  
  56.             do {  
  57.                 binary[index--] = (value & 1);  
  58.                 value >>>= 1;  
  59.             } while (value != 0);  
  60.   
  61.             double[] result = bp.test(binary);  
  62.   
  63.             double max = -Integer.MIN_VALUE;  
  64.             int idx = -1;  
  65.   
  66.             for (int i = 0; i != result.length; i++) {  
  67.                 if (result[i] > max) {  
  68.                     max = result[i];  
  69.                     idx = i;  
  70.                 }  
  71.             }  
  72.   
  73.             switch (idx) {  
  74.             case 0:  
  75.                 System.out.format("%d是一个正奇数 ", rawVal);  
  76.                 break;  
  77.             case 1:  
  78.                 System.out.format("%d是一个正偶数 ", rawVal);  
  79.                 break;  
  80.             case 2:  
  81.                 System.out.format("%d是一个负奇数 ", rawVal);  
  82.                 break;  
  83.             case 3:  
  84.                 System.out.format("%d是一个负偶数 ", rawVal);  
  85.                 break;  
  86.             }  
  87.         }  
  88.     }  
  89.   
  90. }  

 运行结果截图如下:



 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

原文地址:https://www.cnblogs.com/bnuvincent/p/6476040.html