머신러닝과 알고리즘의 이해 : K-Nearest Neighbors 알고리즘
KNN(K-Nearest Neighbors) 알고리즘은 데이터의 특성을 파악하여 가장 가까운 유사속성에 따라 분류하는 기법이다. 본 알고리즘은 학습을 위해 주어진 데이터가 가진 특징을 파악하여 새로운 데이터가 들어와도 어떠한 집합군인지를 어렵지 않게 파악해 낼 수 있다. 따라서 큰 분류인 머신러닝내의 머신러닝>지도학습>분류모델에 속한다고 할 수 있다. 일맥상통하는 사자성어는 유유상종. 주변사람을 보면 그 사람을 알 수 있는것과 같은 사람만이 아는 경험적 원리를 컴퓨터에게 가르치는 것 이다.
생선(도미와 빙어)의 무게와 길이에 대한 학습을 바탕으로 차후 들어오는 생선의 무게와 길이 데이터만 가지고 어떤 생선인지 맞추는 binary classification 문제를 아래의 코드와 같이 구현해보고자 한다.
# 데이터준비
# bream = 도미, smelt = 빙어
# length = 길이, weight = 무게
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0,29.7,30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0 ,33.5 ,33.5, 33.5, 34, 34, 34.5,35,35,35,35,36.0,36.0,37,38.5,38.5,39.5,41,41]
print(len(bream_length))
print(min(bream_length))
bream_weight = [242, 290, 340, 363, 430, 450, 500, 390, 450, 470, 500, 500, 340, 600, 700, 700, 610, 650, 575,600, 600, 680, 700, 725, 720, 710, 850, 1000, 920, 955, 925, 975, 950]
print(len(bream_weight))
print(min(bream_weight))
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12, 12.2, 12.4, 13.0, 14.3, 15]
print(len(smelt_length))
print(min(smelt_length))
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4 , 12.2, 19.7,19.9]
print(len(smelt_weight))
print(min(smelt_weight))
# 산점도 그래프 plotting
import matplotlib.pyplot as plt
plt.scatter(bream_length,bream_weight)
plt.scatter(smelt_length,smelt_weight)
plt.xlabel('lenght')
plt.ylabel('weight')
plt.show()
그래프를 plotting 하면 아래와 같은 그래프를 얻을 수 있다. (명백히 두 개의 group으로 나뉘어 있다. training을 통하여 두 그룹의 특성을 학습하고 차후 어떠한 데이터가 들어와도 분류할 수 있게 만드는게 KNN 알고리즘이 할 일.)
# 데이터 합치기
length = bream_length + smelt_length
weight = bream_weight + smelt_weight
print(length)
print(weight)
fish_data = [[l,w] for l,w in zip(length,weight)]
print(fish_data)
기존에 만들어져있는 함수에 맞춰서 데이터를 취합한다. zip이라는 함수를 가져와 데이터를 list안에 하나씩 가져와서 넣는 과정을 거친다. 최종 결과물은 [[length1,weight1],[length2,weight2] ... [length47,weight47]]과 같은 형태가 얻어진다.
fish_target = [1]*33 + [0]*14
print(fish_target)
본 머신러닝 기법은 결과값을 알려주는 지도학습이므로 컴퓨터에게 답을 알려주어야한다. (해당 특징을 가지고 있는 생선이 도미인지 혹은 빙어인지) 입력되는 데이터를 Input data 출력하고자 원하는 데이터를 Target data라 일컫으므로 fish_target는 변수를 사용한다. 앞선 33개의 데이터는 도미이고 이후에 있는 14개의 데이터는 빙어의 데이터이다. (도미 = 1, 빙어 = 0)
from sklearn.neighbors import KNeighborsClassifier # KNN 알고리즘 import
model1 = KNeighborsClassifier()
model1.fit(fish_data,fish_target)
model1.score(fish_data,fish_target)
피팅과 동시에 모델 평가까지 완료하였다. score를 통하여 만들어진 모델의 정확도를 출력할 수 있는데 출력치는 1.0으로 학습데이터를 기준으로 보았을 때 100%의 정확도를 가진 모델을 생성했다. 물론 overfitting일 가능성을 배제할 수 없으니 여러 test data를 가지고 실험을 해보아야 한다.
이제 만들어진 KNN알고리즘을 통하여 새로운 데이터가 Input으로 들어왔을 때 잘 예측하는지 살펴보자. 아래의 삼각형 심볼의 데이터는 기존 training데이터가 아닌 신규 데이터인데 (30, 600)라는 데이터를 가지고 있다.
model1.predict([[30,600]]) # 결과가 1이면 도미 0이면 빙어
해당 코드를 실행하여 보면 별다른 문제 없이 '1'이라는 값을 출력한다. 즉 정확하게 binary classification을 수행해냈다.
print(model1._fit_X)
print(model1._y)
아래 코드를 수행해보면 내부함수에 있는 X(input data)와 y(target data)값을 볼 수 있는데 학습 전과 후, (30, 600)이라는 데이터의 Classification 전과 후와 변화가 없음을 알 수 있다. 즉, KNN 알고리즘은 기존데이터의 변화없이 새로운 데이터가 다가 올 시 주변의 데이터들을 참조하여 어떤 집단인지를 구별하는 역할을 한다.
그렇다면 KNN알고리즘은 Default로 몇개의 데이터를 참고할까? 매개변수를 변환하지 않는다면 모델은 대개 5개의 주변 값을 참고한다. 알고리즘을 상황에 맞게 잘 사용하기 위해서는 Default 설정에 대해서 잘 알아둘 필요가 있다. 수천 수만개의 데이터를 가지는 상황에서 5개의 주변 값만을 참고하는 것 보단 주변 몇십개의 데이터를 참조하는 것이 더 정확도가 높을 수 있기 때문이다.
정확도를 낮추기 위해 아래와 같이 47개의 데이터를 참조하는 상황으로 바꾸어보자. (총 데이터 개수는 47개이므로 모든 데이터를 전부 참조하는 것 이다.)
model47 = KNeighborsClassifier(n_neighbors=47) #47개의 데이터 참조
model47.fit(fish_data, fish_target)
print(model49.score(fish_data, fish_target))
print(33/47) # 47개중 대다수를 차지하는 도미를 참조하기 때문에 무조건 정답은 도미라고한다.
47개의 데이터중 47개 모두를 참조하였으니 해당 모델은 모든 생선에 대해 과반수 이상인 도미라고 분류한다. 33개의 데이터를 가지는 도미는 당연 도미라고 분류하고 나머지 15개의 데이터 또한 빙어임에도 도미라고 분류하는 것 이다. 따라서 정확도(accuracy)는 33/47이 될 것 이다.
※ 추가 - 용어정리
훈련 : 데이터에서 규칙을 찾는 과정 (해당 예제에서는 사이킷런에서 fit()이 하는 역할)
모델 : 알고리즘이 구현된 객체
정확도 : 정확한 답을 몇 개 맞췄는지 백분율로 나타내는 값 (사이킷런에서는 0~1 사이의 값으로 나타내어짐)
※ 추가 - scikit-learn
- KneighborsClassifier() : KNN 분류 모델을 만드는 사이킷런 클래스이다.
- n-neighbors 매개변수로 이웃의 개수를 지정함 (기본값 = 5)
- P 매개변수로 거리를 재는 방법을 지정하며 1일 경우 맨해튼 거리, 2일 경우 유클리디안 거리를 사용. (기본값 = 2)
- n_jobs 매개변수로 사용할 CPU코어를 지정 가능. -1로 설정하면 모든 CPU 코어를 사용한다. 계산 속도를 높힐 수 있지만 fit() 매서드에서는 영향이 없다. (기본값 = 1)
- fit()은 사이킷런 모델을 훈련할 때 사용하는 메서드이다. 처음 두 매개변수로 훈련에 사용할 특성과 정답 데이터를 전달 한다.
- prdict()은 사이킷런 모델을 훈련하고 예측할 때 사용
- score()은 훈련된 모델의 성능을 측정한다.