Дерево классификации и регрессии машинного обучения (python реализует CART)

машинное обучение искусственный интеллект Python алгоритм

Деревья решений (ID3) были представлены в предыдущих статьях. Краткий обзор: ID3 каждый раз выбирает лучшую функцию для сегментации данных, и принцип оценки лучшей функции достигается за счет получения информации. После сегментации данных по определенному признаку этот признак не будет использоваться при сегментации набора данных позже, поэтому возникает проблема слишком быстрой сегментации. Алгоритм ID3 еще не может обрабатывать непрерывные функции. Вот краткое введение в другие алгоритмы:

屏幕快照 2018-03-03 14.05.44.png

Дерево регрессии классификации CART

CART — это аббревиатура от «Деревья классификации и регрессии», которые могут обрабатывать как задачи классификации, так и задачи регрессии.

image.png
Типичным представителем дерева CART является бинарное дерево, которое будет классифицироваться в соответствии с различными условиями.
image.png

Алгоритм построения дерева CART Подобно методу построения дерева решений ID3, процесс построения дерева CART задается напрямую. Во-первых, аналогично ID3, используется структура данных словарного дерева, включающая следующие 4 элемента:

  • Особенности, которые нужно сегментировать
  • Собственные значения для сегментации
  • правое поддерево. Когда разбиение больше не требуется, это также может быть одно значение
  • Левое поддерево аналогично правому поддереву.

Процесс выглядит следующим образом:

  1. Поиск наиболее подходящих функций сегментации
  2. Если набор данных нельзя разделить, набор данных используется как конечный узел.
  3. Разделите набор данных на два
  4. Повторите шаги 1, 2 и 3 для разделенного набора данных 1, чтобы создать правильное поддерево.
  5. Повторите шаги 1, 2 и 3 для разделенного набора данных 2, чтобы создать левое поддерево.

Очевидный рекурсивный алгоритм.

Разделите набор данных с помощью фильтрации данных и верните два подмножества.

def splitDatas(rows, value, column):
    # 根据条件分离数据集(splitDatas by value, column)
    # return 2 part(list1, list2)

    list1 = []
    list2 = []

    if isinstance(value, int) or isinstance(value, float):
        for row in rows:
            if row[column] >= value:
                list1.append(row)
            else:
                list2.append(row)
    else:
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)
    return list1, list2

Разделить точки данных

Создание бинарного дерева решений — это, по сути, процесс рекурсивного разделения входного пространства.

image.png

код показывает, как показано ниже:

# gini()
def gini(rows):
    # 计算gini的值(Calculate GINI)

    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp

построить дерево

def buildDecisionTree(rows, evaluationFunction=gini):
    # 递归建立决策树, 当gain=0,时停止回归
    # build decision tree bu recursive function
    # stop recursive function when gain = 0
    # return tree
    currentGain = evaluationFunction(rows)
    column_lenght = len(rows[0])
    rows_length = len(rows)

    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the best gain
    for col in range(column_lenght - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)
    dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}
    #
    # stop or not stop

    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value = best_value[1], trueBranch = trueBranch, falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)

Функция приведенного выше кода состоит в том, чтобы сначала найти наилучшую позицию для разделения набора данных и разбить набор данных. После этого рекурсивно строится все дерево приведенной выше картинки.

обрезка

При обучении дерева решений иногда дерево решений имеет слишком много ветвей, что означает удаление некоторых ветвей, чтобы уменьшить переоснащение. Процесс предотвращения переобучения из-за сложности дерева решений называется обрезкой. Постобрезка требует создания полного дерева решений из обучающего набора, а затем изучения неконечных узлов снизу вверх. Используйте тестовый набор, чтобы определить, следует ли заменить поддерево, соответствующее узлу, конечным узлом. код показывает, как показано ниже:

def prune(tree, miniGain, evaluationFunction=gini):
    # 剪枝 when gain < mini Gain, 合并(merge the trueBranch and falseBranch)
    if tree.trueBranch.results == None:
        prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results == None:
        prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results != None and tree.falseBranch.results != None:
        len1 = len(tree.trueBranch.data)
        len2 = len(tree.falseBranch.data)
        len3 = len(tree.trueBranch.data + tree.falseBranch.data)

        p = float(len1) / (len1 + len2)

        gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)

        if gain < miniGain:
            tree.data = tree.trueBranch.data + tree.falseBranch.data
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None

Когда усиление узла меньше заданного мини-усиления, два узла объединяются.

И, наконец, код для построения дерева:

if __name__ == '__main__':
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    prune(decisionTree, 0.4)
    test_data = [5.9,3,4.2,1.5]
    r = classify(test_data, decisionTree)
    print(r)

Вы можете распечатать решениеTree, чтобы построить дерево решений, как показано на рисунке выше. Найдите набор данных для проверки позже, чтобы увидеть, сможете ли вы получить правильную классификацию.

Полный код и набор данных см.
гитхаб: КОРЗИНА

Суммировать:

  • Дерево решений КОРЗИНЫ
  • Разделить набор данных
  • рекурсивно создать дерево

Справочная статья:
Анализ регрессионного дерева классификации CART и реализация Python
Реализация исходного кода Python дерева решений CART (Decision Tree)