# Description

The decision tree is built by a recursive algorithm, implemented in the function buildDecisionTree() defined in functions.cpp. The algorithm works as follows:

if the sub-table passed to the algorithm is empty
    return NULL;    // since there is no data in the table
if the sub-table passed to the algorithm is homogeneous (if all the rows have the same value for the last column)
    mark this node as a leaf node;
    label this node with the value of the last column
    return a pointer to this node
else
    decide a column to split the table on based on information gain
    set the node's splitOn value to this column's name
    for all the values that the splitting column can take:
        create a new node
        set the new node as the current node's child node
        prune the sub-table so that it keeps only the rows with this value of the splitting column (the column itself is dropped)
        recursively call the function by passing it the new pruned table and the new node

The splitting column is chosen by information gain, which is computed from entropy, a measure of randomness in the data. The less randomness a split leaves behind in the sub-tables, the more information we gain from it. We therefore split on the attribute with the highest information gain, which corresponds to the lowest weighted entropy.
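
In symbols (the standard ID3 definitions, which are the quantities the code computes): a table $S$ with class proportions $p_i$ has entropy

$$H(S) = -\sum_i p_i \log_2 p_i$$

and the information gain of splitting on attribute $A$ is

$$Gain(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v)$$

Since $H(S)$ is the same for every candidate column, maximizing the gain is equivalent to minimizing the weighted average entropy of the sub-tables, which is exactly what decideSplittingColumn() below computes. As a worked example, the full 17-row melon table below has 8 "yes" and 9 "no" rows, so $H(S) = -(8/17)\log_2(8/17) - (9/17)\log_2(9/17) \approx 0.998$ bits.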

Build and run (the program takes the training file as its first argument and the test file as its second):

make clean
make
./dtree train.txt test.txt

The original dataset, as provided by the instructor:

色泽 根蒂 敲声 纹理 脐部 触感 好瓜
青绿 卷缩 浊响 清晰 凹陷 硬滑 是
乌黑 卷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 卷缩 浊响 清晰 凹陷 硬滑 是
青绿 卷缩 沉闷 清晰 凹陷 硬滑 是
浅白 卷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍卷 浊响 清晰 稍凹 软黏 是
乌黑 稍卷 浊响 稍糊 稍凹 软黏 是
乌黑 稍卷 浊响 清晰 稍凹 硬滑 是
乌黑 稍卷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软黏 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 卷缩 浊响 模糊 平坦 软黏 否
青绿 稍卷 浊响 稍糊 凹陷 硬滑 否
浅白 稍卷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍卷 浊响 清晰 稍凹 软黏 否
浅白 卷缩 浊响 模糊 平坦 硬滑 否
青绿 卷缩 沉闷 稍糊 稍凹 硬滑 否

After manual preprocessing (values translated to ASCII, rows shuffled):

color,root,knock,texture,navel,touch,goodMelon
green,slightlyRolled,turbidSound,slightlyMushy,concave,hardSlide,no
black,slightlyRolled,turbidSound,clear,slightlyConcave,hardSlide,yes
green,rolled,dull,slightlyMushy,slightlyConcave,hardSlide,no
white,stiff,crisp,mushy,flat,hardSlide,no
white,rolled,turbidSound,mushy,flat,hardSlide,no
black,rolled,turbidSound,clear,concave,hardSlide,yes
black,slightlyRolled,turbidSound,clear,slightlyConcave,sticky,no
black,slightlyRolled,dull,slightlyMushy,slightlyConcave,hardSlide,no
green,rolled,dull,clear,concave,hardSlide,yes
white,rolled,turbidSound,clear,concave,hardSlide,yes
black,rolled,dull,clear,concave,hardSlide,yes
white,rolled,turbidSound,mushy,flat,sticky,no
green,stiff,crisp,clear,flat,sticky,no
black,slightlyRolled,turbidSound,slightlyMushy,slightlyConcave,sticky,yes
green,rolled,turbidSound,clear,concave,hardSlide,yes
green,slightlyRolled,turbidSound,clear,slightlyConcave,sticky,yes
white,slightlyRolled,dull,slightlyMushy,concave,hardSlide,no

train.txt

color,root,knock,texture,navel,touch,goodMelon
green,slightlyRolled,turbidSound,slightlyMushy,concave,hardSlide,no
black,slightlyRolled,turbidSound,clear,slightlyConcave,hardSlide,yes
green,rolled,dull,slightlyMushy,slightlyConcave,hardSlide,no
white,stiff,crisp,mushy,flat,hardSlide,no
white,rolled,turbidSound,mushy,flat,hardSlide,no
black,rolled,turbidSound,clear,concave,hardSlide,yes
black,slightlyRolled,turbidSound,clear,slightlyConcave,sticky,no
black,slightlyRolled,dull,slightlyMushy,slightlyConcave,hardSlide,no
green,rolled,dull,clear,concave,hardSlide,yes
white,rolled,turbidSound,clear,concave,hardSlide,yes
black,rolled,dull,clear,concave,hardSlide,yes

test.txt

color,root,knock,texture,navel,touch,goodMelon
white,rolled,turbidSound,mushy,flat,sticky,no
green,stiff,crisp,clear,flat,sticky,no
black,slightlyRolled,turbidSound,slightlyMushy,slightlyConcave,sticky,yes
green,rolled,turbidSound,clear,concave,hardSlide,yes
green,slightlyRolled,turbidSound,clear,slightlyConcave,sticky,yes
white,slightlyRolled,dull,slightlyMushy,concave,hardSlide,no

decisionTreeOutput.txt

#     Given Class                Predicted Class
--------------------------------------------------
  1              no  ------------  no
  2              no  ------------  no
  3             yes  xxxxxxxxxxxx  no
  4             yes  ------------  yes
  5             yes  xxxxxxxxxxxx  no
  6              no  ------------  no
--------------------------------------------------
Total number of instances in test data = 6
Number of correctly predicted instances = 4
Accuracy of decision tree classifier = 66.6667%
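
With 4 of the 6 test instances classified correctly, the accuracy is 4/6 ≈ 66.67%.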

For the complete project, see: decisionTree

# Code

# header.h

#include <iostream>
#include <string>
#include <vector>
#include <fstream>
#include <map>
#include <cmath>
#include <cfloat>
#include <cstdlib>
#include <iomanip>
using namespace std;
typedef vector<string> vs;
typedef vector<vs> vvs;
typedef vector<int> vi;
typedef map<string, int> msi;
typedef vector<double> vd;
struct node													// struct node defines the structure of a node of the decision tree
{
	string splitOn;											// Stores which attribute to split on at a particular node
	string label;											// Stores the class label for leaf nodes. For non-leaf nodes, it stores the value of the parent's splitting attribute that leads to this node
	bool isLeaf;											// Boolean flag for leaf nodes
	vector<string> childrenValues;							// Stores, for each child, the attribute value that leads to it
	vector<node*> children;									// Stores pointers to the children of a node
};
void parse(string&, vvs&);									// Parses a single line from the input file and stores the information into a vector of vector of strings 
void printAttributeTable(vvs&);								// For debugging purposes only. Prints a data table
vvs pruneTable(vvs&, string&, string);						// Prunes a table on an attribute's name and one of its values: keeps only the rows with that value and removes the column itself
node* buildDecisionTree(vvs&, node*, vvs&);					// Builds the decision tree based on the table it is passed
bool isHomogeneous(vvs&);									// Returns true if all instances in a subtable at a node have the same class label
vi countDistinct(vvs&, int);								// Returns a vector of integers containing the counts of all the various values of an attribute/column
string decideSplittingColumn(vvs&);							// Returns the column to split on. The decision is based on entropy
int returnColumnIndex(string&, vvs&);						// Returns the index of a column in a subtable
bool tableIsEmpty(vvs&);									// Returns true if a subtable is empty
void printDecisionTree(node*);								// For debugging purposes only. Recursively prints the decision tree
string testDataOnDecisionTree(vs&, node*, vvs&, string);	// Runs a single instance of the test data through the decision tree. Returns the predicted class label
int returnIndexOfVector(vs&, string);						// Returns the index of a string in a vector of strings
double printPredictionsAndCalculateAccuracy(vs&, vs&);		// Outputs the predictions to file and returns the accuracy of the classification
vvs generateTableInfo(vvs &dataTable);						// Generates information about the table in a vector of vector of strings
string returnMostFrequentClass(vvs &dataTable);				// Returns the most frequent class from the training data. This class is used as the default class during the testing phase
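
As a quick illustration of how struct node hangs together, here is a hypothetical standalone demo (not part of the project; the split attribute, branch values, and labels are made up, and the tree actually learned from train.txt may differ). It hand-builds a one-level tree and classifies one attribute value:

#include "header.h"
// Hypothetical demo: hand-build a one-level tree that splits on "texture".
int main()
{
	node* root = new node;
	root->isLeaf = false;
	root->splitOn = "texture";									// Attribute this node tests
	string values[] = {"clear", "slightlyMushy", "mushy"};
	string labels[] = {"yes", "no", "no"};
	for (int iii = 0; iii < 3; iii++) {
		node* leaf = new node;
		leaf->isLeaf = true;
		leaf->label = labels[iii];								// Class label stored in the leaf
		root->childrenValues.push_back(values[iii]);			// Branch value, kept parallel to children
		root->children.push_back(leaf);
	}
	string value = "clear";										// Classify: follow the matching branch
	for (int iii = 0; iii < root->childrenValues.size(); iii++) {
		if (root->childrenValues[iii] == value) {
			cout << root->children[iii]->label << endl;			// Prints "yes"
		}
	}
	return 0;
}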

# DecisionTree.cpp

#include "header.h"
int main(int argc, const char *argv[])
{
	ifstream inputFile;												// Input file stream
	string singleInstance;											// Single line read from the input file 
	vvs dataTable;													// Input data in the form of a vector of vector of strings
	inputFile.open(argv[1]);
	if (!inputFile)													// If input file does not exist, print error and exit
	{
		cerr << "Error: Training data file not found!" << endl;
		exit(-1);
	}
	/* 
	 * Decision tree training phase
	 * In this phase, the training data is read
	 * from the file and stored into a vvs using
	 * the parse() function. The generateTableInfo()
	 * function extracts the attribute (column) names
	 * and also the values that each column can take.
	 * This information is also stored in a vvs.
	 * buildDecisionTree() function recursively
	 * builds the decision tree.
	 */
	while (getline(inputFile, singleInstance))						// Read from file, parse and store data
	{
		parse(singleInstance, dataTable);
	}
	inputFile.close(); 												// Close input file
	vvs tableInfo = generateTableInfo(dataTable);					// Stores all the attributes and their values in a vector of vector of strings named tableInfo
	node* root = new node;											// Declare and assign memory for the root node of the Decision Tree
	root = buildDecisionTree(dataTable, root, tableInfo);			// Recursively build and train decision tree
	string defaultClass = returnMostFrequentClass(dataTable);		// Stores the most frequent class in the training data. This is used as the default class label
	dataTable.clear(); 												// clear dataTable of training data to store testing data
	/*
	 * Decision tree testing phase
	 * In this phase, the test data is read
	 * from the file, parsed and stored.
	 * Each row in the table is made to
	 * traverse down the decision tree
	 * till a class label is found.
	 */
	inputFile.clear();
	inputFile.open(argv[2]); 										// Open test file
	if (!inputFile) 												// Exit if test file is not found
	{
		cerr << "Error: Testing data file not found!" << endl;
		exit(-1);
	}
	while (getline(inputFile, singleInstance)) 						// Store test data in a table
	{
		parse(singleInstance, dataTable);
	}
	inputFile.close(); 												// Close test file
	vs predictedClassLabels;										// Stores the predicted class labels for each row
	vs givenClassLabels;											// Stores the given class labels in the test data
	for (int iii = 1; iii < dataTable.size(); iii++)				// Store given class labels in vector of strings named givenClassLabels
	{
		string data = dataTable[iii][dataTable[0].size()-1];
		givenClassLabels.push_back(data);
	}
	for (int iii = 1; iii < dataTable.size(); iii++)				// Predict class labels based on the decision tree
	{
		string someString = testDataOnDecisionTree(dataTable[iii], root, tableInfo, defaultClass);
		predictedClassLabels.push_back(someString);
	}
	dataTable.clear();
	/* Print output */
	double accuracy = printPredictionsAndCalculateAccuracy(givenClassLabels, predictedClassLabels);			// Writes the predictions to file and computes the accuracy
	ofstream outputFile;
	outputFile.open("decisionTreeOutput.txt", ios::app);			// Reopen the output file in append mode
	outputFile << "Accuracy of decision tree classifier = " << accuracy << "%" << endl;						// Append the accuracy line to the output file
	outputFile.close();
	return 0;
}

# functions.cpp

#include "header.h"
/* 
 * Parses a string and stores data
 * into a vector of vector of strings
 */
void parse(string& someString, vvs &attributeTable)
{
	int attributeCount = 0;
	vs vectorOfStrings;
	while (someString.length() != 0 && someString.find(',') != string::npos)
	{
		size_t pos;
		string singleAttribute;
		pos = someString.find_first_of(',');
		singleAttribute = someString.substr(0, pos);
		vectorOfStrings.push_back(singleAttribute);
		someString.erase(0, pos+1);
	}
	vectorOfStrings.push_back(someString);
	attributeTable.push_back(vectorOfStrings);
	vectorOfStrings.clear();
}
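
A quick sanity check of parse() (a hypothetical driver, to be compiled together with functions.cpp; note that parse() consumes the string it is given):

#include "header.h"
// Hypothetical driver: one CSV line becomes one row of the table.
int main()
{
	vvs table;
	string line = "green,rolled,dull,clear,concave,hardSlide,yes";
	parse(line, table);											// Splits on ',' into 7 fields
	cout << table[0][0] << " ... " << table[0][6] << endl;		// Prints "green ... yes"
	return 0;
}
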
/*
 * Prints a vector of vector of strings
 * For debugging purposes only.
 */
void printAttributeTable(vvs &attributeTable)
{
	int inner, outer;
	for (outer = 0; outer < attributeTable.size(); outer++) {
		for (inner = 0; inner < attributeTable[outer].size(); inner++) {
			cout << attributeTable[outer][inner] << "\t";
		}
		cout << endl;
	}
}
/*
 * Prunes a table based on a column/attribute's name
 * and a value of that attribute. Keeps only the rows
 * that have that value, and removes the column itself.
 */
vvs pruneTable(vvs &attributeTable, string &colName, string value)
{
	int iii, jjj;
	vvs prunedTable;
	int column = -1;
	vs headerRow;
	for (iii = 0; iii < attributeTable[0].size(); iii++) {
		if (attributeTable[0][iii] == colName) {
			column = iii;
			break;
		}
	}
	for (iii = 0; iii < attributeTable[0].size(); iii++) {
		 if (iii != column) {
		 	headerRow.push_back(attributeTable[0][iii]);
		 }
	}
	prunedTable.push_back(headerRow);
	for (iii = 0; iii < attributeTable.size(); iii++) {
		vs auxRow;
		if (attributeTable[iii][column] == value) {
			for (jjj = 0; jjj < attributeTable[iii].size(); jjj++) {
				if(jjj != column) {
					auxRow.push_back(attributeTable[iii][jjj]);
				}
			}
			prunedTable.push_back(auxRow);
		}
	}
	return prunedTable;
}
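
For example, pruning this illustrative mini-table with colName = "A" and value = "x"

A,B,class
x,1,yes
y,2,no
x,3,no

returns the two "x" rows with column A dropped:

B,class
1,yes
3,no
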
/*
 * Recursively builds the decision tree based on
 * the data that it is passed and the table info.
 */
node* buildDecisionTree(vvs &table, node* nodePtr, vvs &tableInfo)
{
	if (tableIsEmpty(table)) {
		delete nodePtr;											// Nothing reached this branch; discard the node and return NULL
		return NULL;
	}
	if (isHomogeneous(table)) {
		nodePtr->isLeaf = true;
		nodePtr->label = table[1][table[1].size()-1];
		return nodePtr;
	} else {
		string splittingCol = decideSplittingColumn(table);
		nodePtr->splitOn = splittingCol;
		int colIndex = returnColumnIndex(splittingCol, tableInfo);
		int iii;
		for (iii = 1; iii < tableInfo[colIndex].size(); iii++) {
			node* newNode = new node;
			newNode->label = tableInfo[colIndex][iii];
			nodePtr->childrenValues.push_back(tableInfo[colIndex][iii]);
			newNode->isLeaf = false;
			newNode->splitOn = splittingCol;
			vvs auxTable = pruneTable(table, splittingCol, tableInfo[colIndex][iii]);
			nodePtr->children.push_back(buildDecisionTree(auxTable, newNode, tableInfo));
		}
	}
	return nodePtr;
}
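
Note that an empty subtable yields a NULL child pointer rather than a node; testDataOnDecisionTree() below treats a NULL child as "no information" and falls back to the most frequent class from the training data.
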
/*
 * Returns true if all rows in a subtable
 * have the same class label.
 * This means that that node's class label
 * has been decided.
 */
bool isHomogeneous(vvs &table)
{
	int iii;
	int lastCol = table[0].size() - 1;
	string firstValue = table[1][lastCol];
	for (iii = 1; iii < table.size(); iii++) {
		if (firstValue != table[iii][lastCol]) {
			return false;
		}
	}
	return true;
}
/*
 * Returns a vector of integers containing the counts
 * of all the various values of an attribute/column.
 */
vi countDistinct(vvs &table, int column)
{
	vs vectorOfStrings;
	vi counts;
	bool found = false;
	int foundIndex;
	for (int iii = 1; iii < table.size(); iii++) {
		for (int jjj = 0; jjj < vectorOfStrings.size(); jjj++) {
			if (vectorOfStrings[jjj] == table[iii][column]) {
				found = true;
				foundIndex = jjj;
				break;
			} else {
				found = false;
			}
		}
		if (!found) {
			counts.push_back(1);
			vectorOfStrings.push_back(table[iii][column]);
		} else {
			counts[foundIndex]++;
		}
	}
	int sum = 0;
	for (int iii = 0; iii < counts.size(); iii++) {
		sum += counts[iii];
	}
	counts.push_back(sum);
	return counts;
}
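
The returned vector carries the counts in order of first appearance, with the total appended as the last element. For instance, a column whose data rows hold a, b, a, c yields {2, 1, 1, 4}; decideSplittingColumn() relies on that trailing total to weight the per-value entropies.
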
/*
 * Decides which column to split on
 * based on entropy. Returns the column
 * with the least entropy.
 */
string decideSplittingColumn(vvs &table)
{
	int column, iii;
	double minEntropy = DBL_MAX;
	int splittingColumn = 0;
	for (column = 0; column < table[0].size() - 1; column++) {
		string colName = table[0][column];
		msi tempMap;
		vi counts = countDistinct(table, column);
		vd attributeEntropy;
		double columnEntropy = 0.0;
		for (iii = 1; iii < table.size(); iii++) {
			double entropy = 0.0;
			if (tempMap.find(table[iii][column]) != tempMap.end()) { 	// IF ATTRIBUTE VALUE HAS ALREADY BEEN SEEN IN THIS COLUMN, UPDATE ITS FREQUENCY
				tempMap[table[iii][column]]++;
			} else { 							// IF ATTRIBUTE VALUE IS SEEN FOR THE FIRST TIME IN THIS COLUMN, PROCESS IT AND CALCULATE ITS ENTROPY
				tempMap[table[iii][column]] = 1;
				vvs tempTable = pruneTable(table, colName, table[iii][column]);
				vi classCounts = countDistinct(tempTable, tempTable[0].size()-1);
				int jjj;
				for (jjj = 0; jjj < classCounts.size() - 1; jjj++) {	// The last element of classCounts is the total, not a class count
					double temp = (double) classCounts[jjj];
					entropy -= (temp/classCounts[classCounts.size()-1])*(log(temp/classCounts[classCounts.size()-1]) / log(2));
				}
				attributeEntropy.push_back(entropy);
				entropy = 0.0;
			}
		}
		for (iii = 0; iii < counts.size() - 1; iii++) {
			columnEntropy += ((double) counts[iii] * (double) attributeEntropy[iii]);
		}
		columnEntropy = columnEntropy / ((double) counts[counts.size() - 1]);
		if (columnEntropy <= minEntropy) {
			minEntropy = columnEntropy;
			splittingColumn = column;
		}
	}
	return table[0][splittingColumn];
}
/*
 * Returns an integer which is the
 * index of a column passed as a string
 */
int returnColumnIndex(string &columnName, vvs &tableInfo)
{
	int iii;
	for (iii = 0; iii < tableInfo.size(); iii++) {
		if (tableInfo[iii][0] == columnName) {
			return iii;
		}
	}
	return -1;
}
/*
 * Returns true if the table is empty
 * returns false otherwise
 */
bool tableIsEmpty(vvs &table)
{
	return (table.size() == 1);								// Only the header row is left
}
/*
 * Recursively prints the decision tree
 * For debugging purposes only
 */
void printDecisionTree(node* nodePtr)
{
	if(nodePtr == NULL) {
		return;
	}
	if (!nodePtr->children.empty()) {
		cout << " Value: " << nodePtr->label << endl;
		cout << "Split on: " << nodePtr->splitOn;
		int iii;
		for (iii = 0; iii < nodePtr->children.size(); iii++) {   
			cout << "\t";
			printDecisionTree(nodePtr->children[iii]);
		}
		return;
	} else {
		cout << "Predicted class = " << nodePtr->label;
		return;
	}
}
/*
 * Takes a row and traverses that row through
 * the decision tree to find out the 
 * predicted class label. If none is found
 * returns the default class label which is
 * the class label with the highest frequency.
 */
string testDataOnDecisionTree(vs &singleLine, node* nodePtr, vvs &tableInfo, string defaultClass)
{
	while (!nodePtr->isLeaf && !nodePtr->children.empty()) {
		int index = returnColumnIndex(nodePtr->splitOn, tableInfo);
		string value = singleLine[index];
		int childIndex = returnIndexOfVector(nodePtr->childrenValues, value);
		if (childIndex == -1) {									// Attribute value never seen during training
			return defaultClass;
		}
		nodePtr = nodePtr->children[childIndex];
		if (nodePtr == NULL) {									// An empty subtable during training left a NULL child
			return defaultClass;
		}
	}
	return nodePtr->isLeaf ? nodePtr->label : defaultClass;	// A leaf's label is the predicted class
}
/*
 * Returns an integer which is the index
 * of a string in a vector of strings
 */
int returnIndexOfVector(vs &stringVector, string value)
{
	int iii;
	for (iii = 0; iii < stringVector.size(); iii++) {
		if (stringVector[iii] == value)	{
			return iii;
		}
	}
	return -1;
}
/*
 * Outputs the predictions to file
 * and returns the accuracy of the classification
 */
double printPredictionsAndCalculateAccuracy(vs &givenData, vs &predictions)
{
	ofstream outputFile;
	outputFile.open("decisionTreeOutput.txt");
	int correct = 0;
	outputFile << setw(3) << "#" << setw(16) << "Given Class" << setw(31) << right << "Predicted Class" << endl;
	outputFile << "--------------------------------------------------" << endl;
	for (int iii = 0; iii < givenData.size(); iii++) {
		outputFile << setw(3) << iii+1 << setw(16) << givenData[iii];
		if (givenData[iii] == predictions[iii]) {
			correct++;
			outputFile << "  ------------  ";
		} else {
			outputFile << "  xxxxxxxxxxxx  ";
		}
		outputFile << predictions[iii] << endl;
	}
	outputFile << "--------------------------------------------------" << endl;
	outputFile << "Total number of instances in test data = " << givenData.size() << endl;
	outputFile << "Number of correctly predicted instances = " << correct << endl;
	outputFile.close();
	return (double) correct / givenData.size() * 100;		// Percentage of correctly classified test instances
}
/*
 * Returns a vvs which contains information about
 * the data table. The vvs contains the names of
 * all the columns and the values that each
 * column can take
 */
vvs generateTableInfo(vvs &dataTable)
{
	vvs tableInfo;
	for (int iii = 0; iii < dataTable[0].size(); iii++) {
		vs tempInfo;
		msi tempMap;
		for (int jjj = 0; jjj < dataTable.size(); jjj++) {
			if (tempMap.count(dataTable[jjj][iii]) == 0) {
				tempMap[dataTable[jjj][iii]] = 1;
				tempInfo.push_back(dataTable[jjj][iii]);
			} else	{
				tempMap[dataTable[jjj][iii]]++;
			}
		}
		tableInfo.push_back(tempInfo);
	}
	return tableInfo;
}
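
Each row of tableInfo starts with the column's header name followed by its distinct values in order of first appearance; for the train.txt above, tableInfo[0] is {color, green, black, white}. This is why buildDecisionTree() iterates a column's values from index 1 and returnColumnIndex() matches on element 0.
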
/*
 * Returns the most frequent class from the training data
 * This class will be used as the default class label
 */
string returnMostFrequentClass(vvs &dataTable)
{
	msi trainingClasses;           													 // Stores the class labels and their frequencies
	for (int iii = 1; iii < dataTable.size(); iii++) {
		if (trainingClasses.count(dataTable[iii][dataTable[0].size()-1]) == 0) {
			trainingClasses[dataTable[iii][dataTable[0].size()-1]] = 1;
		} else {
			trainingClasses[dataTable[iii][dataTable[0].size()-1]]++;
		}
	}   
	msi::iterator mapIter;
	int highestClassCount = 0;
	string mostFrequentClass;
	for (mapIter = trainingClasses.begin(); mapIter != trainingClasses.end(); mapIter++) {
		if (mapIter->second >= highestClassCount) {
			highestClassCount = mapIter->second;
			mostFrequentClass = mapIter->first;
		}   
	}
	return mostFrequentClass;
}
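
For the train.txt above this returns "no" (6 "no" rows versus 5 "yes"), so an instance that falls off the tree during testing is predicted as "no".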