В машинном обучении дерево решений — это прогностическая модель, представляющая сопоставление между атрибутами объекта и значениями объекта. Каждый узел в дереве представляет собой объект, каждое ответвление представляет собой возможный атрибут, а каждый конечный узел соответствует значению объекта, представленному путем от корневого узла к конечному узлу. Деревья решений имеют только один выход.Если вам нужны сложные выходные данные, вы можете построить отдельные деревья решений для обработки разных входных данных. Деревья решений обычно используются в интеллектуальном анализе данных, который можно использовать для анализа данных и прогнозирования.
просто понять
Как показано выше, Первые два являются атрибутами, которые можно записать как
['no surfacing','flippers']
. Дерево решений можно просто построить следующим образом:По двум признакам можно судить, принадлежит ли оно рыбе.
Итак, сначала решите, какой атрибут выбрать в качестве начальной классификации? Самый простой — ID3. Улучшенный C4.5 будет понят позже в CART.
Деревья решений и ID3
Дерево решений похоже на древовидную структуру и имеет древовидную структуру. Каждый внутренний узел представляет собой тест атрибута, каждая ветвь представляет результат теста, а каждый конечный узел представляет категорию. Как показано выше. Дерево классификации (дерево решений) часто используется для классификации в машинном обучении и является методом обучения с учителем. Объекты этого типа классифицируются по признакам по ветвям дерева. Каждое дерево решений может полагаться на тестирование данных в разделении исходной базы данных, рекурсивно сокращая дерево. Зная, что к ветке применяется один класс, его нельзя разделить, сделать это рекурсивно. Функции:
- Многоуровневая форма дерева решений проста для понимания.
- Применяется только к номинальным данным строки, непрерывная обработка данных не годится.
Алгоритм ID3
Выше описано, как сначала выбрать, какой атрибут классифицировать в ряду атрибутов. Простое понимание, если какой атрибут более запутан, вы можете напрямую получить категорию, к которой он принадлежит. такие как вышеперечисленные свойства水下是否可以生存
, те, которые не могут выжить, могут быть классифицированы как не рыбы.
Итак, как количественно оценить и получить это свойство?
Ядром алгоритма ID3 являетсяИнформация о влажности, путем вычисления информационного прироста каждого атрибута считается, что атрибут с высоким коэффициентом усиления является хорошим атрибутом и его легко классифицировать. Каждое подразделение выбирает атрибут с наибольшим приростом информации в качестве критерия деления и повторяется до тех пор, пока не будет сгенерировано дерево решений, которое может идеально классифицировать обучающие выборки.
Приведенный выше алгоритм получения информации не совсем понятен, и его код легко просмотреть позже.
Процесс алгоритма ID3 и дерева решений
- Подготовка данных: необходимо дискретизировать числовые данные
- Алгоритм ID3 строит дерево решений:
- Если категории данных точно такие же, прекратите разделение.
- В противном случае продолжайте деление:
- Рассчитайте информационную энтропию и прирост информации, чтобы выбрать лучший метод разделения набора данных.
- Разделить набор данных
- Создать узел ответвления
- Категория решения одинакова для каждой ветви. Одни и те же перестают делиться, а разные делятся по вышеуказанной методике.
реализация кода на питоне
Создайте набор данных, используя приведенный выше пример.
def createDataSet():
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no sufacing', 'flippers']
return dataSet, labels
Рассчитать информационную энтропию, соответствующую первой формуле
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
# 为分类创建字典
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts.setdefault(currentLabel, 0)
labelCounts[currentLabel] += 1
# 计算香农墒
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt += prob * math.log2(1 / prob)
return shannonEnt
Рассчитайте максимальный прирост информации (уравнение 2) и разделите набор данных.
# 定义按照某个特征进行划分的函数 splitDataSet
# 输入三个变量(带划分数据集, 特征,分类值)
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reduceFeatVec)
return retDataSet #返回不含划分特征的子集
# 定义按照最大信息增益划分数据的函数
def chooseBestFeatureToSplit(dataSet):
numFeature = len(dataSet[0]) - 1
print(numFeature)
baseEntropy = calcShannonEnt(dataSet)
bestInforGain = 0
bestFeature = -1
for i in range(numFeature):
featList = [number[i] for number in dataSet] #得到某个特征下所有值
uniqualVals = set(featList) #set无重复的属性特征值
newEntrogy = 0
#求和
for value in uniqualVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet)) #即p(t)
newEntrogy += prob * calcShannonEnt(subDataSet) #对各子集求香农墒
infoGain = baseEntropy - newEntrogy #计算信息增益
print(infoGain)
# 最大信息增益
if infoGain > bestInforGain:
bestInforGain = infoGain
bestFeature = i
return bestFeature
Простой тест:
if __name__ == '__main__':
dataSet, labels = createDataSet()
r = chooseBestFeatureToSplit(dataSet)
print(r)
# 输出
# 2
# 0.41997309402197514
# 0.17095059445466865
# 0
Как и выше, вы можете видеть, что есть два свойства['no surfacing','flippers']
и его прирост информации, поэтому выберите более крупный объект (индекс 0), чтобы разделить набор данных (см. Начальный рисунок), и повторяйте шаги, пока не останется только одна категория.
Создать конструктор дерева решений
# 投票表决代码
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount.setdefault(vote, 0)
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=lambda i:i[1], reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# print(dataSet)
# print(classList)
# 类别相同,停止划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 判断是否遍历完所有的特征,是,返回个数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#按照信息增益最高选择分类特征属性
bestFeat = chooseBestFeatureToSplit(dataSet) #分类编号
bestFeatLabel = labels[bestFeat] #该特征的label
myTree = {bestFeatLabel: {}}
del (labels[bestFeat]) #移除该label
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:] #子集合
#构建数据的子集合,并进行递归
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
В коде есть взгляды, пытающиеся понять выполнение каждого шага и получить базовое представление о дереве решений.
if __name__ == '__main__':
dataSet, labels = createDataSet()
r = chooseBestFeatureToSplit(dataSet)
# print(r)
myTree = createTree(dataSet, labels)
print(myTree)
# --> {'no sufacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Вы можете видеть, что на выходе получается вложенный словарь, а дерево решений можно нарисовать вручную, что соответствует диаграмме в начале.
Используйте деревья решений для классификации
Постройте функцию классификации дерева решений:
def classify(inputTree, featLabels, testVec):
"""
:param inputTree: 决策树
:param featLabels: 属性特征标签
:param testVec: 测试数据
:return: 所属分类
"""
firstStr = list(inputTree.keys())[0] #树的第一个属性
sendDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
classLabel = None
for key in sendDict.keys():
if testVec[featIndex] == key:
if type(sendDict[key]).__name__ == 'dict':
classLabel = classify(sendDict[key], featLabels, testVec)
else:
classLabel = sendDict[key]
return classLabel
Видно, что функция классифицирует тестовые данные шаг за шагом в соответствии со значением атрибута, пока не будет достигнут конечный узел и не будет получена правильная классификация.
Кроме того, дерево решений можно сохранить, в отличие от kNN дерево решений строится без повторных вычислений и может быть использовано непосредственно в следующий раз.
def storeTree(inputTree,filename):
import pickle
fw=open(filename,'wb') #pickle默认方式是二进制,需要制定'wb'
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr=open(filename,'rb')#需要制定'rb',以byte形式读取
return pickle.load(fr)
Полный код дерева решений можно найти на github:github:decision_tree
Суммировать
- Деревья решений: ID3, C4.5, CART
- Теория информации: обогащение информации, получение информации
- хранилище объектов Python
Использованная литература:Алгоритм дерева решений машинного обучения (ID3) и реализация Python