Cross-validation experiments with Weka

Generating cross-validation folds (Java approach)

Reference:

http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29


This article describes how to generate train/test splits for cross-validation using the Weka API directly. 

The following variables are given:

 Instances data =  ...;   // the full dataset we want to create train/test sets from

 int seed = ...;          // the seed for randomizing the data

 int folds = ...;         // the number of folds to generate, >=2



 Randomize the data

First, randomize your data:

 Random rand = new Random(seed);   // create seeded number generator

 Instances randData = new Instances(data);   // create copy of original data

 randData.randomize(rand);         // randomize data with number generator

If your data has a nominal class and you want to perform stratified cross-validation:

 randData.stratify(folds);
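Stratification reorders the dataset so that each of the folds generated afterwards receives approximately the same class distribution as the whole dataset. The following is a minimal plain-Java sketch of the idea (an illustration only, not Weka's actual implementation; all names are made up for this example):

```java
import java.util.*;

public class StratifySketch {
    // Sketch of stratified fold assignment: sort instances by class so that
    // identical classes are grouped, then deal them out round-robin across
    // the folds; each fold then keeps roughly the overall class proportions.
    static List<List<String>> stratifiedFolds(List<String> labels, int folds) {
        List<String> sorted = new ArrayList<>(labels);
        Collections.sort(sorted);                   // group identical classes together
        List<List<String>> out = new ArrayList<>();
        for (int f = 0; f < folds; f++) out.add(new ArrayList<>());
        for (int i = 0; i < sorted.size(); i++) {
            out.get(i % folds).add(sorted.get(i));  // round-robin assignment
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> labels = new ArrayList<>();
        for (int i = 0; i < 30; i++) labels.add("red");
        for (int i = 0; i < 60; i++) labels.add("white");
        for (List<String> fold : stratifiedFolds(labels, 3)) {
            long reds = fold.stream().filter("red"::equals).count();
            // each fold: 30 instances, 10 red (same 1:2 ratio as the full data)
            System.out.println(fold.size() + " instances, " + reds + " red");
        }
    }
}
```

With 30 "red" and 60 "white" labels and 3 folds, every fold ends up with 10 red and 20 white instances, mirroring the overall 1:2 ratio.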



 Generate the folds

 Single run

Next, we create the train and test sets for each fold:

 for (int n = 0; n < folds; n++) {

   Instances train = randData.trainCV(folds, n);

   Instances test = randData.testCV(folds, n);

 

   // further processing, classification, etc.

   ...

 }

Note:

  • the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter
  • the weka.classifiers.Evaluation class and the Explorer/Experimenter use the following call to obtain the train set:

 Instances train = randData.trainCV(folds, n, rand);
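Conceptually, testCV(folds, n) carves the (randomized) dataset into k nearly equal contiguous blocks and returns block n, while trainCV returns everything else. The index arithmetic can be sketched in plain Java as follows (an illustration of the scheme, not Weka's actual code):

```java
public class FoldRanges {
    // Half-open index range [start, end) of the n-th test fold when `total`
    // instances are split into `folds` contiguous blocks; the first
    // (total % folds) folds receive one extra instance each.
    static int[] testRange(int total, int folds, int n) {
        int base = total / folds, rem = total % folds;
        int start = n * base + Math.min(n, rem);
        int len = base + (n < rem ? 1 : 0);
        return new int[] { start, start + len };
    }

    public static void main(String[] args) {
        int total = 23, folds = 10;
        for (int n = 0; n < folds; n++) {
            int[] r = testRange(total, folds, n);
            // the train set for fold n is every instance outside [r[0], r[1])
            System.out.println("fold " + n + ": test = [" + r[0] + ", " + r[1] + ")");
        }
    }
}
```

For 23 instances and 10 folds, the first three test folds get 3 instances and the rest get 2, and together the blocks partition the whole dataset.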



 Multiple runs

The example above performs only a single run of cross-validation. If you want to perform 10 runs of 10-fold cross-validation, use the following loop:

 Instances data = ...;  // our dataset again, obtained from somewhere

 int runs = 10;

 for (int i = 0; i < runs; i++) {

   seed = i+1;  // every run gets a new, but defined seed value

 

   // see: randomize the data

   ...

 

   // see: generate the folds

   ...

 }

A small experiment:

We continue classifying the red and white wines from the previous section. The classifier itself is unchanged; only the repeated-runs procedure has been added.

package assignment2;

import weka.core.Instances;

import weka.core.converters.ConverterUtils.DataSource;

import weka.core.Utils;

import weka.classifiers.Classifier;

import weka.classifiers.Evaluation;

import weka.classifiers.trees.J48;

import weka.filters.Filter;

import weka.filters.unsupervised.attribute.Remove;

 

import java.io.FileReader;

import java.util.Random;

public class cv_rw {

    public static Instances getFileInstances(String filename) throws Exception{

       FileReader frData = new FileReader(filename);

       Instances data = new Instances(frData);

       // remove the last attribute; attribute indices in the -R option are 1-based
       int length = data.numAttributes();

       String[] options = new String[2];

       options[0] = "-R";

       options[1] = Integer.toString(length);

       Remove remove = new Remove();

       remove.setOptions(options);

       remove.setInputFormat(data);

       Instances newData = Filter.useFilter(data, remove);

       return newData;

    }

    public static void main(String[] args) throws Exception {

        // loads data and set class index

       Instances data = getFileInstances("D://Weka_tutorial//WineQuality//RedWhiteWine.arff");

//     System.out.println(instances);

       data.setClassIndex(data.numAttributes()-1);

 

        // classifier

//      String[] tmpOptions;

//      String classname;

//      tmpOptions     = Utils.splitOptions(Utils.getOption("W", args));

//      classname      = tmpOptions[0];

//      tmpOptions[0]  = "";

//      Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);

//

//      // other options

//      int runs  = Integer.parseInt(Utils.getOption("r", args)); // number of repeated runs

//      int folds = Integer.parseInt(Utils.getOption("x", args));

       int runs = 1;

       int folds = 10;

       J48 j48 = new J48();

//     j48.buildClassifier(instances);

 

        // perform cross-validation

        for (int i = 0; i < runs; i++) {

          // randomize data

          int seed = i + 1;

          Random rand = new Random(seed);

          Instances randData = new Instances(data);

          randData.randomize(rand);

//        if (randData.classAttribute().isNominal())   // stratify nominal classes so every fold keeps the overall class distribution

//          randData.stratify(folds);

 

          Evaluation eval = new Evaluation(randData);

          for (int n = 0; n < folds; n++) {

            Instances train = randData.trainCV(folds, n);

            Instances test = randData.testCV(folds, n);

            // the above code is used by the StratifiedRemoveFolds filter, the

            // code below by the Explorer/Experimenter:

            // Instances train = randData.trainCV(folds, n, rand);

 

            // build and evaluate classifier

            Classifier j48Copy = Classifier.makeCopy(j48);  // (AbstractClassifier.makeCopy in Weka 3.7+)

            j48Copy.buildClassifier(train);

            eval.evaluateModel(j48Copy, test);

          }

 

          // output evaluation

          System.out.println();

          System.out.println("=== Setup run " + (i+1) + " ===");

          System.out.println("Classifier: " + j48.getClass().getName());

          System.out.println("Dataset: " + data.relationName());

          System.out.println("Folds: " + folds);

          System.out.println("Seed: " + seed);

          System.out.println();

          System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i+1) + "===", false));

        }

 

    }

}

Running the program produces the following results:

 

=== Setup run 1 ===

Classifier: weka.classifiers.trees.J48

Dataset: RedWhiteWine-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.attribute.Remove-R13

Folds: 10

Seed: 1

 

=== 10-fold Cross-validation run 1===

Correctly Classified Instances        6415               98.7379 %

Incorrectly Classified Instances        82                1.2621 %

Kappa statistic                          0.9658

Mean absolute error                      0.0159

Root mean squared error                  0.1109

Relative absolute error                  4.2898 %

Root relative squared error             25.7448 %

Total Number of Instances             6497     
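As a quick sanity check (not part of the original program), the percentages in the summary follow directly from the instance counts:

```java
public class AccuracyCheck {
    public static void main(String[] args) {
        int correct = 6415, incorrect = 82;
        int total = correct + incorrect;           // 6497 instances in total
        double acc = 100.0 * correct / total;      // ≈ 98.7379 %
        double err = 100.0 * incorrect / total;    // ≈ 1.2621 %
        System.out.println(acc + " / " + err);
    }
}
```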

Original article: https://www.cnblogs.com/7899-89/p/3667330.html