K-fold Train Version3

# config.py
# Filesystem locations shared by the training scripts.

# CSV read by train.py; it is expected to carry a "label" column and a
# "kfold" column in addition to the features.
TRAINING_FILE: str = "../input/mnist_train_folds.csv"

# Directory into which train.py writes the fitted models.
MODEL_OUTPUT: str = "../models/"
# model_dispatcher.py
from sklearn import tree
from sklearn import ensemble
# Registry mapping a short model name (as passed on the command line via
# --model) to an unfitted scikit-learn estimator instance.
models = {
    # decision trees differing only in the split criterion
    "decision_tree_gini": tree.DecisionTreeClassifier(criterion="gini"),
    "decision_tree_entropy": tree.DecisionTreeClassifier(criterion="entropy"),
    # random forest with library-default hyper-parameters
    "rf": ensemble.RandomForestClassifier(),
}
# train.py
import argparse
import os
import joblib
import pandas as pd
from sklearn import metrics
import config
import model_dispatcher

def run(fold, model):
    """Train one model on every fold except ``fold`` and validate on ``fold``.

    Reads ``config.TRAINING_FILE``, fits the estimator registered under
    ``model`` in ``model_dispatcher.models``, prints the validation accuracy
    and saves the fitted estimator inside ``config.MODEL_OUTPUT``.

    Args:
        fold: Value of the ``kfold`` column that is held out for validation.
        model: Key into ``model_dispatcher.models``.

    Raises:
        KeyError: If ``model`` is not a key of ``model_dispatcher.models``.
        ValueError: If the requested fold leaves the training or the
            validation split without any rows.
    """
    # fetch the model first so an unknown name fails before any data is read
    clf = model_dispatcher.models[model]

    # make sure the output directory exists *before* spending time on
    # training, otherwise the dump at the end would fail and lose the work
    os.makedirs(config.MODEL_OUTPUT, exist_ok=True)

    # read the training data with folds
    df = pd.read_csv(config.TRAINING_FILE)

    # training data is where kfold is not equal to the provided fold,
    # validation data is where it is equal; the index is reset on both
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    if df_train.empty or df_valid.empty:
        raise ValueError(
            f"fold {fold!r} yields an empty split "
            f"(train rows={len(df_train)}, valid rows={len(df_valid)})"
        )

    # "label" is the target and "kfold" is only fold bookkeeping; neither
    # may be fed to the model as an input feature
    non_feature_columns = ["label", "kfold"]
    x_train = df_train.drop(columns=non_feature_columns).values
    y_train = df_train.label.values

    # similarly, for validation
    x_valid = df_valid.drop(columns=non_feature_columns).values
    y_valid = df_valid.label.values

    # fit the model on training data
    clf.fit(x_train, y_train)

    # create predictions for validation samples
    preds = clf.predict(x_valid)

    # calculate & print accuracy
    accuracy = metrics.accuracy_score(y_valid, preds)
    print(f"Fold={fold}, Accuracy={accuracy}")

    # save the model
    # NOTE(review): the file name uses the fixed prefix "dt" whatever `model`
    # is, so different models trained on the same fold overwrite each other.
    # Kept unchanged in case other scripts load these exact paths.
    joblib.dump(
        clf,
        os.path.join(config.MODEL_OUTPUT, f"dt_{fold}.bin"),
    )


if __name__ == "__main__":
    # initialize ArgumentParser class of argparse
    parser = argparse.ArgumentParser()

    # Both arguments are mandatory. Without `required=True` argparse would
    # silently hand None to run(), which is neither a usable fold nor a
    # valid model key.
    parser.add_argument(
        "--fold",
        type=int,
        required=True,
        help="value of the kfold column to hold out for validation",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # reject unknown names up front with a readable usage message
        choices=sorted(model_dispatcher.models),
        help="name of the model to train, as registered in model_dispatcher",
    )

    # read the arguments from the command line
    args = parser.parse_args()

    # run the fold/model combination specified by the command line arguments
    run(fold=args.fold, model=args.model)
================================================
#!/bin/sh
# run.sh
# Train the "rf" model on folds 0 through 4, one fold after another.
for fold in 0 1 2 3 4; do
    python train.py --fold "$fold" --model rf
done
================================================
sh run.sh
原文地址:https://www.cnblogs.com/songyuejie/p/14789476.html