Posts List

[Data Mining] 9. Classification - 트리 기반 모형


Decision Tree (의사결정나무)

트리기반 모형은 분류문제에 많이 쓰는 대표적인 모형 중 하나이다.
(앞서 AI 수업에서 Tree Search를 배웠기 때문에 어느정도 익숙할 것이라 예상된다. 그때 배웠던 pruning이란 단어도 다시 등장한다.)

트리 모형은 독립변수 공간을 한번에 한 변수의 경계치를 기준으로 분할(분지) 하는 것이다. 위 그림을 보면 뿌리노드에서는 X1을 기준으로 분지되었고 다음으로는 X2를 기준으로 분지되었다.

트리모형에서 가장 중요한 것은 분지기준 잘 찾아서 가장 깊이가 짧은 트리모형을 형성하는 것이다. 이 목적을 달성하기 위해 분지기준은 불순도(impurity)가 감소하는 방향으로 분지한다. 직관적으로 표현하면 자식노드에 샘플이 최대한 한쪽으로 쏠리게 하는 것이다. (자식노드 중 하나에 1개 분류만 모여있으면 순도 100%, 불순도는 0%)

트리 모형을 학습하는 과정은 크게 아래와 같으며, 이번 챕터에서는 아래 모든 과정에 대해서 자세히 알아볼 예정이다.

  • 트리의 형성(Growing the tree)
    - 분지 기준(splitting criteria)
    - 불순도 함수(impurity function)
  • 가지치기(Pruning)
    - 과적합(overfitting) 방지
    - 비용-복잡도 척도(cost-complexity measure)
  • 분류(Classification)
    - 분류 규칙(classification rule)


※ CART : Classification and Regression Tree

트리 모형은 상당히 오래전에 나온 모형이지만, 앞서 그림에서 보았듯 예측 결과의 원인에 대한 추론이 매우 직관적이다. 그래서 지금까지도 여타 알고리즘에 비해 사랑받고 있고, 위 표와 같이 다양한 하이퍼파라미터 조합을 이용해서 성능을 올리는 등 연구활동이 활발하다.


1. 트리의 형성

이제 본격적으로 트리모형에 대해 알아볼텐데, 역시 트리모형에서 가장 중요한 것은 "어떻게 분지를 할 것인가" 이다.

앞서 이미 언급했지만, 분지는 한번에 한 변수를 기준으로만 이루어지고, 불순도(impurity)를 감소하는 방향으로 분지시키는 것이 기본이다.

불순도의 개념에 앞서 위의 두 케이스 중 어떤 경우가 잘 분류된 것일까? 정답은 왼쪽이다. 왼쪽과 같이 자식노드에 1개 범주를 몰아넣을 수 있다면 "순도가 높다"라고 포현하며, 이렇게 분지됐을 때 "불순도가 감소됐다."라고 표현한다.

불순도 함수

그럼 이제 필요한 것은 불순도를 정량적으로 구하는 것이다.

불순도 함수 Φ는 아래의 조건을 만족하면 어떤 함수든 쓸 수 있다. 

  • Φ(1,0,0,...,0) = Φ(0,1,0,...,0) = ... = Φ(0,0,0,...1) = 0
    : 한개 노드에 한개 범주가 몰려 있으면 불순도는 0이다. (순도 100%)
  • Φ(1/J, 1/J, ..., 1/J) 가 최대값
    : 모든 노드에 공평하게 1/J개의 각 범주샘플을 가지고 있으면 불순도가 가장 높은상태이다.
  • Φ는 (p1, ..., pJ)에 대한 대칭함수 (pJ = 각 자식 노드에 있는 샘플의 비율)
    ex) 2범주의 경우 Φ(0.4, 0.6)와 Φ(0.6, 04) 두 불순도는 같다.
         → 불순도(A범주 샘플 4개 + B범주 6개) = 불순도(A범주 샘플 6개 + B범주 4개)

위 조건을 만족하는 대표적인 불순도 함수로 지니계수엔트로피 지수가 있다.

  • 지니계수 G(p1, ..., pJ) = 1 - Σj=1~J pj2
  • 엔트로피 지수 E(p1, ..., pJ) = -Σj=1~J Pjlog2Pj

위 불순도 함수에 의하면 아래와 같이 각 노드의 불순도 i 를 구할 수 있다.(지니 계수 가정)

i(t) = G(p(1|t), ... , p(J|t) = 1 - Σj=1~J p(j)2
※ p(j|t) : 노드 t에서 범주 j에 속한 객체수 / 노드 t의 전체 객체 수

예를 들어 위와 같이 노드 t1, t2, t3에 객체가 들어 있을 때 불순도는 아래와 같이 구해진다.

i(t1) = G(5/10, 5/10) = 1 - (5/10)2 + (5/10)2 = 5/10

i(t2) = G(2/7, 5/7) = 1 - (2/7)2 + (5/7)2 = 20/49

i(t3) = G(3/3, 0/3) = 1 - (3/3)2 + (0/0)2 = 0

이렇게 각 노드의 불순도가 구해졌을 때, 과연 이 트리가 잘 분지시켰는지를 파악하기 위해 트리의 불순도도 파악해야한다. 트리의 불순도는 아래와 같이 구할 수 있다.

I(T) = Σt∈A(T) i(t) p(t) : 최종 노드에서의 불순도 가중 평균

※ p(t) = t노드의 객체수 / 전체 객체수
   A(T) = 최종 노드의 집합

그럼 위 예제에서 트리 T1의 불순도는 아래와 같이 계산할 수 있다.

I(T1) = i(t2)p(t2) + i(t3)p(t3) = (20/49)*(7/10) + (0)*(3/10) = 2/7

분지 기준

앞서 트리 모형은 불순도를 감소시키는 방향으로 분지된다고 하였다. 즉, 부모노드가 불순도 i(t)에서 분지(s)되어 트리의 불순도 I(T) 만큼 불순도가 감소했을 때, 이 불순도 감소량을 최대화 하는 분지기준 s*을 선택해야 한다.

Δi(s, t) = i(t) - pLi(tL) - pRi(tR) = i(t) - I(T)

Δi(s*, t) = maxΔi(s, t)

※ pL, pR : 왼쪽(L), 오른쪽(R) 자식노드의 분지 비율


2. 가지치기 (Pruning)

가지치기는 overfitting을 방지를 위한 행위다. 실제로 트리모형은 깊이가 깊어질 수록 성능이 향상되지만, full-depth까지 내려가면 모든 경우의 수를 표현하기 때문에 정확도가 100%로 수렴한다. 즉, 트리모형의 깊이가 너무 깊으면 학습데이터셋에 overfitting 될 수 있다. 그래서 실제 나무의 가지를 치는 것 처럼 불순도가 적정 수준에 다다르면 더 분지시키지 않게 하는 것이 바로 가지치기이다.
※ AI시간에는 search 속도 향상을 위해 pruning을 했지만 여기서는 overfitting 방지 효과의 목적이 더 크다.

그럼 중요한 것은 어떤 기준으로 가지를 칠것인지 정하는 것이다. 이를 위해 나온 개념이 바로 비용-복잡도 개념이다. (비용과 복잡도를 함께 고려하는 개념)

  • 오분류 비용 : 오분류 시 얻는 penalty = 오분류 확률
    ex. 노드 t에 A범주 3개, B범주 5개의 object가 있다면 B범주를 예측한다.
         이 때 노드 t의 비용(오분류 확률) r(t) = 3/5
  • 복잡도 : 트리가 얼마나 복잡한가? = 트리의 최종(terminal)노드수

위에 짧게 설명했지만, 노드와 트리의 오분류 비용을 일반화하면 아래와 같다.

노드 t의 오분류 비용 : r(t) = 1 - max p(j | t)

트리 T의 오분류 비용 : Σt∈A r(t)p(t)

결과적으로 트리 T의 비용-복잡도 척도(cost-complexity measure)Ra(T)는 아래와 같이 비용과 복잡도의 합으로 나타낸다.

Ra(T) = R(T) + α|T|
※ α > 0 : 복잡도 계수   |T| : 트리 T의 최종 노드 개수

비용-복잡도 척도는 노드별로 계산 된다. 예를 들어 위 그림에서 노드 t를 가지치기 했을때의 척도, 가지치기 하지 않았을 때의 척도를 비교해서, 가지치기를 했을 때 비용-복잡도 척도가 낮으면 가지치기를 수행한다. 

결국 최대트리에서 가지치기를 1개, 2개...뿌리노드 까지 모두 수행하면서 가지친 트리들 T1, T2 ... 들이 생기는데, 그 중에서 학습데이터가 아닌 별도의 테스트표본에서 오분류율이 가장 낮은 트리가 best estimator로 선정된다. (cross validation)


3. 트리의 분류 규칙

트리의 분류 규칙은 사실 별로 설명할 게 없다. 앞서 잠깐 언급했듯, 최종 노드(terminal node에 도착했을 때 다수결에 의해서 예측 결과를 도출한다.

댓글 쓰기

0 댓글