Encog (Part 2): Training Neural Networks

Encog offers many training methods.

EncogUtility is a helper class that provides many convenience functions:

Method summary (modifier and type, signature: description):
static double calculateClassificationError(MLClassification method, MLDataSet data): Calculate the classification error.
static double calculateRegressionError(MLRegression method, MLDataSet data)
static void convertCSV2Binary(File csvFile, CSVFormat format, File binFile, int[] input, int[] ideal, boolean headers)
static void convertCSV2Binary(File csvFile, File binFile, int inputCount, int outputCount, boolean headers): Convert a CSV file to a binary training file.
static void convertCSV2Binary(String csvFile, String binFile, int inputCount, int outputCount, boolean headers): Convert a CSV file to a binary training file.
static void evaluate(MLRegression network, MLDataSet training): Evaluate the network and display (to the console) the output for every value in the training set.
static void explainErrorMSE(MLRegression method, MatrixMLDataSet training)
static void explainErrorRMS(MLRegression method, MatrixMLDataSet training)
static String formatNeuralData(MLData data): Format neural data as a list of numbers.
static MLDataSet loadCSV2Memory(String filename, int input, int ideal, boolean headers, CSVFormat format, boolean significance): Load CSV to memory.
static MLDataSet loadEGB2Memory(File filename)
static void saveCSV(File targetFile, CSVFormat format, MLDataSet set)
static void saveEGB(File f, MLDataSet data): Save a training set to an EGB file.
static BasicNetwork simpleFeedForward(int input, int hidden1, int hidden2, int output, boolean tanh): Create a simple feedforward neural network.
static void trainConsole(BasicNetwork network, MLDataSet trainingSet, int minutes): Train the neural network, using SCG training, and output status to the console.
static void trainConsole(MLTrain train, BasicNetwork network, MLDataSet trainingSet, int minutes): Train the network, using the specified training algorithm, and send the output to the console.
static void trainToError(MLMethod method, MLDataSet dataSet, double error): Train the method, to a specific error, send the output to the console.
static void trainToError(MLTrain train, double error): Train to a specific error, using the specified training method, send the output to the console.
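convertCSV2Binary and loadCSV2Memory are told how many input columns and how many ideal (output) columns each row has. The sketch below assumes the usual layout, where a row is read as the input columns followed by the ideal columns, and computes a mean squared error as a stand-in for what calculateRegressionError reports. It is plain Java, not Encog code: CsvPairSketch, splitRow and meanSquaredError are hypothetical names, and Encog's own error measure is configurable, so it may not be exactly MSE.

```java
import java.util.Arrays;

/**
 * Hypothetical helper, not part of Encog. It illustrates the assumed row layout
 * (inputCount input columns, then outputCount ideal columns) and a mean squared
 * error over input/ideal pairs.
 */
public class CsvPairSketch {

    /** Splits one CSV line into {input, ideal} arrays. */
    static double[][] splitRow(String line, int inputCount, int outputCount) {
        String[] cells = line.split(",");
        if (cells.length < inputCount + outputCount) {
            throw new IllegalArgumentException("Row has too few columns: " + line);
        }
        double[] input = new double[inputCount];
        double[] ideal = new double[outputCount];
        for (int i = 0; i < inputCount; i++) {
            input[i] = Double.parseDouble(cells[i].trim());
        }
        for (int j = 0; j < outputCount; j++) {
            ideal[j] = Double.parseDouble(cells[inputCount + j].trim());
        }
        return new double[][] {input, ideal};
    }

    /** Mean squared error over every output value of every row. */
    static double meanSquaredError(double[][] actual, double[][] ideal) {
        double sum = 0;
        int count = 0;
        for (int r = 0; r < actual.length; r++) {
            for (int c = 0; c < actual[r].length; c++) {
                double diff = ideal[r][c] - actual[r][c];
                sum += diff * diff;
                count++;
            }
        }
        return count == 0 ? 0 : sum / count;
    }

    public static void main(String[] args) {
        // XOR-style rows: two input columns followed by one ideal column.
        String[] csv = {"0,0,0", "0,1,1", "1,0,1", "1,1,0"};
        double[][] ideal = new double[csv.length][];
        for (int r = 0; r < csv.length; r++) {
            double[][] pair = splitRow(csv[r], 2, 1);
            ideal[r] = pair[1];
            System.out.println(Arrays.toString(pair[0]) + " -> " + Arrays.toString(pair[1]));
        }
        // Made-up network outputs, used only to exercise the error function.
        double[][] actual = {{0.1}, {0.9}, {0.8}, {0.2}};
        System.out.println("MSE = " + meanSquaredError(actual, ideal));
    }
}
```

With the made-up outputs above, every prediction is off by 0.1 or 0.2, so the printed MSE is about 0.025.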

The BasicTraining class is the parent class of all the training-method classes.

Constructors:
BasicTraining(): Used for serialization.
BasicTraining(TrainingImplementationType implementationType)


Methods (return type, signature: description):
void addStrategy(Strategy strategy): Training strategies can be added to improve the training results.
void finishTraining(): Should be called after training has completed and the iteration method will not be called any further.
double getError()
TrainingImplementationType getImplementationType()
int getIteration()
List<Strategy> getStrategies()
MLDataSet getTraining()
boolean isTrainingDone()
void iteration(int count): Perform the specified number of training iterations.
void postIteration(): Call the strategies after an iteration.
void preIteration(): Call the strategies before an iteration.
void setError(double error)
void setIteration(int iteration): Set the current training iteration.
void setTraining(MLDataSet training): Set the training object that this strategy is working with.
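The methods iteration(), getError(), isTrainingDone() and finishTraining() are meant to be driven by a loop such as the one EncogUtility.trainToError runs: iterate, report the error, stop once the error reaches the target or the trainer says it is done, then call finishTraining(). Below is a minimal plain-Java sketch of that loop, assuming a hypothetical Trainer interface that only mirrors those method names (including a no-argument iteration(), a simplification of ours) and a toy HalvingTrainer whose error halves on every step. It is not Encog's source.

```java
/**
 * Sketch of the "train until the error is low enough" loop. Trainer and
 * HalvingTrainer are hypothetical stand-ins, not Encog types.
 */
public class TrainLoopSketch {

    /** Hypothetical interface mirroring a few BasicTraining method names. */
    interface Trainer {
        void iteration();
        double getError();
        int getIteration();
        boolean isTrainingDone();
        void finishTraining();
    }

    /** Toy trainer: its error starts at 1.0 and halves on every iteration. */
    static class HalvingTrainer implements Trainer {
        private double error = 1.0;
        private int iteration = 0;

        public void iteration() { iteration++; error /= 2; }
        public double getError() { return error; }
        public int getIteration() { return iteration; }
        public boolean isTrainingDone() { return false; }
        public void finishTraining() { }
    }

    /** Iterates until the error is at or below the target, or the trainer stops itself. */
    static int trainToError(Trainer train, double targetError) {
        do {
            train.iteration();
            System.out.println("Iteration #" + train.getIteration() + " Error: " + train.getError());
        } while (train.getError() > targetError && !train.isTrainingDone());
        train.finishTraining(); // always release resources once the loop ends
        return train.getIteration();
    }

    public static void main(String[] args) {
        int used = trainToError(new HalvingTrainer(), 0.01);
        System.out.println("Stopped after " + used + " iterations");
    }
}
```

Checking isTrainingDone() as well as the error keeps the loop from running forever when a strategy or the trainer itself decides training cannot improve further.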

The Backpropagation class is a subclass of the Propagation class.

Constructors:
Backpropagation(ContainsFlat network, MLDataSet training): Create a class to train using backpropagation.
Backpropagation(ContainsFlat network, MLDataSet training, double theLearnRate, double theMomentum)

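Parameters of the four-argument constructor:

First parameter: the network to be trained.

Second parameter: the training set.

Third parameter: the learning rate, which scales how far each weight moves in a single update.

Fourth parameter: the momentum, a common acceleration technique for gradient descent. It is a coefficient that adds a fraction of the previous weight change to the current one; momentum = 0 means no acceleration, and larger values mean stronger acceleration.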
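Written as a formula, the classic backpropagation update with momentum that the third and fourth parameters control looks like the following. The notation is ours, not Encog's: ε is the learning rate, α the momentum, E the training error, and Δw(t) the change applied to weight w at iteration t.

```latex
\Delta w(t) = -\varepsilon \, \frac{\partial E}{\partial w}(t) + \alpha \, \Delta w(t-1),
\qquad
w(t+1) = w(t) + \Delta w(t)
```

With α = 0 only the current gradient term remains, which is plain gradient descent; with α > 0, steps that keep pointing in the same direction build up speed.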


Methods (return type, signature: description):
boolean canContinue()
double[] getLastDelta()
double getLearningRate()
double getMomentum()
void initOthers(): Perform training method specific init.
boolean isValidResume(TrainingContinuation state): Determine if the specified continuation object is valid to resume with.
TrainingContinuation pause(): Pause the training.
void resume(TrainingContinuation state): Resume training.
void setLearningRate(double rate): Set the learning rate; this value is essentially a percentage.
void setMomentum(double m): Set the momentum for training.
double updateWeight(double[] gradients, double[] lastGradient, int index): Update a weight.
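The methods updateWeight, getLastDelta, pause and resume fit together: the momentum term needs the previous change of every weight, so that array has to survive between iterations and be saved when training is paused. The plain-Java sketch below illustrates this under stated assumptions. MomentumSketch is hypothetical, its updateWeight drops the lastGradient argument of the real signature, and the gradient is assumed to already point downhill (the negative of dE/dw). It is not Encog's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of a backpropagation weight update with momentum.
 * lastDelta plays the role of getLastDelta(): it must persist between
 * iterations, which is why pausing and resuming has to carry it along.
 */
public class MomentumSketch {
    private final double learningRate;
    private final double momentum;
    private double[] lastDelta;

    MomentumSketch(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount]; // no previous change yet
    }

    /** Returns the change to add to weight[index] and remembers it for the next call. */
    double updateWeight(double[] gradients, int index) {
        double delta = gradients[index] * learningRate + lastDelta[index] * momentum;
        lastDelta[index] = delta;
        return delta;
    }

    /** Snapshot of the momentum state, similar in spirit to pause(). */
    double[] pause() {
        return lastDelta.clone();
    }

    /** Restores a snapshot, similar in spirit to resume(state). */
    void resume(double[] state) {
        lastDelta = state.clone();
    }

    public static void main(String[] args) {
        MomentumSketch trainer = new MomentumSketch(1, 0.7, 0.3);
        double[] gradients = {1.0};
        double d1 = trainer.updateWeight(gradients, 0); // 0.7 * 1.0 + 0.3 * 0.0
        double d2 = trainer.updateWeight(gradients, 0); // 0.7 * 1.0 + 0.3 * d1, roughly 0.91
        System.out.println(d1 + " " + d2 + " " + Arrays.toString(trainer.pause()));
    }
}
```

Because the second step reuses 30% of the first one, a constant gradient produces growing steps, which is the acceleration the momentum parameter describes.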
Original article: https://www.cnblogs.com/codeDog123/p/6754391.html