728x90
KNN 알고리즘이란
- 가장 간단한 머신러닝 알고리즘, 분류(Classification) 알고리즘
- 어떤 데이터에 대한 답을 구할 때 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 사용
- 새로운 데이터에 대해 예측할 때는 가장 가까운 직선거리에 어떤 데이터가 있는지 살피기만 하면 된다.(k =1)
- 단점
- 데이터가 아주 많은 경우 사용하기 어렵다
- 데이터가 크기 때문에 메모리가 많이 필요하고 직선 거리를 계산하는 데도 많은 시간이 필요
- 실제로 k-최근접 알고리즘은 무언가 훈련되는 게 없다.
- fit() 메서드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터를 참고하여 분류
- 객체._fit_X
- 객체._y
- fit() 메서드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터를 참고하여 분류
- k의 기본값은 5
- n_neighbors 매개변수로 바꿀 수 있다.
- p 매개변수로 거리를 재는 방법 지정
- 1일 경우 맨해튼 거리
- 직선
- 2일 경우 유클리디안 거리 (기본값)
- 계단식(절댓값 합)
- 1일 경우 맨해튼 거리
KNN 알고리즘 실습 - 도미와 빙어 분류하기
- 도미 데이터 (35마리)
# 생선의 길이
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
# 생선의 무게
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
- 빙어 데이터 (14마리)
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]
- 데이터 분포 확인
import matplotlib.pyplot as plt
plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
파란색 -> 도미데이터 선형적, 길이가 길수록 무게가 많이 나감
주황색 -> 빙어데이터 무게가 길이에 영향을 덜 받는다.
- 두 데이터 합치기 (사이킷런을 사용하기 위해 2차원 리스트로 변환)
length = bream_length+smelt_length
weight = bream_weight+smelt_weight
fish_data = [[l, w] for l, w in zip(length, weight)]
- knn 알고리즘은 지도학습이기 때문에 정답을 알려줘야한다. (도미 1, 빙어 0)
fish_target = [1]*35 + [0]*14
- knn 알고리즘
from sklearn.neighbors import KNeighborsClassifier
# 객체 생성
kn = KNeighborsClassifier()
# 학습
kn.fit(fish_data, fish_target)
# 정확도 == 훈련이 잘 되었는지 평가 (0~1)
kn.score(fish_data, fish_target) # 1.0
- k 개수 변경
kn49 = KNeighborsClassifier(n_neighbors=49)
kn49.fit(fish_data, fish_target)
# 주변 49개를 통해 결정 -> 전체가 49개이므로 모두다 1로 예측
kn49.score(fish_data, fish_target) # 35/49 == 0.7142...
위의 과정은 훈련시킬 때 사용한 데이터를 그대로 테스트 했다. 이미 답을 알고 있는 데이터로 테스트하면 올바른 결과를 도출할 수 없다. 따라서 알고리즘의 성능을 제대로 평가하기 위해서는 훈련데이터와 테스트데이터가 달라야 한다.
(샘플링이 한쪽으로 치우쳐있는 경우를 샘플링 편향이라고 부른다.)
- 데이터 섞기
import numpy as np
input_arr = np.array(fish_data)
target_arr = np.array(fish_target)
np.random.seed(42)
index = np.arange(49)
# 배열을 무작위로 섞기
np.random.shuffle(index)
- 훈련데이터 테스트 데이터 나누기 1
train_input = input_arr[index[:35]]
train_target = target_arr[index[:35]]
test_input = input_arr[index[35:]]
test_target = target_arr[index[35:]]
import matplotlib.pyplot as plt
# 데이터가 잘 섞였는지 확인
plt.scatter(train_input[:, 0], train_input[:, 1]) # 파란색
plt.scatter(test_input[:, 0], test_input[:, 1]) # 주황색
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
데이터가 잘 섞였다.
- 모델 훈련 + 예측
kn.fit(train_input, train_target)
# 정확도
kn.score(test_input, test_target) # 1.0
# 예측
kn.predict(test_input)
# array([0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0])
# 정답
test_target
# array([0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0])
- 훈련데이터 테스트 데이터 나누기 2
from sklearn.model_selection import train_test_split
# stratify : 타깃 데이터를 전달하면 클래스 비율에 맞게 데이터 나눈다.
train_input, test_input, train_target, test_target = train_test_split(
fish_data, fish_target, stratify=fish_target, random_state=42)
KNN 알고리즘 주의점
- 훈련된 모델에 길이가 25cm, 무게가 150g인 생선 테스트해보자
print(kn.predict([[25, 150]])) # 0
- 산점도로 보기
import matplotlib.pyplot as plt
plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25, 150, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
산점도로 보면 [25,150] 데이터는 도미에 가까워보인다. 하지만 predict()는 빙어라고 한다. 왜 이런 문제가 발생하는지 알아보자
- [25,150] 데이터와 가까운 이웃 찾아보기
distances, indexes = kn.kneighbors([[25, 150]])
plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25, 150, marker='^')
plt.scatter(train_input[indexes,0], train_input[indexes,1], marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
그림이 뭔가 이상해보인다. 분명 도미가 더 가까워보이는데 빙어가 체크되었다. 그 이유는 x축과 y축의 스케일이 다르다. 스케일이 다를 경우 데이터 전처리를 해줘야 한다. 표준점수를 사용해보자
- 데이터 전처리
mean = np.mean(train_input, axis=0)
std = np.std(train_input, axis=0)
# 표준화하기
train_scaled = (train_input - mean) / std
plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- [25,150]도 산점도에 표시
- 표준화를 해줘야한다.
# 테스트 데이터도 표준화
new = ([25, 150] - mean) / std
plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0], new[1], marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- [25,150] 데이터 예측
kn.fit(train_scaled, train_target)
test_scaled = (test_input - mean) / std
kn.score(test_scaled, test_target) # 1.0
print(kn.predict([new])) # [1.]
도미로 예측 성공!
- [25,150] 데이터의 가까운 이웃
distances, indexes = kn.kneighbors([new])
plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0], new[1], marker='^')
plt.scatter(train_scaled[indexes,0], train_scaled[indexes,1], marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
728x90
반응형
'TIL - 외 > 빅데이터' 카테고리의 다른 글
[머신러닝] 로지스틱 회귀 (0) | 2023.05.07 |
---|---|
[머신러닝] 회귀 알고리즘 및 실습 (0) | 2023.04.13 |
불균형 데이터 (imbalanced data) 처리를 위한 샘플링 기법 (0) | 2023.03.22 |
구름 php + MySQL (0) | 2022.11.02 |
DB (0) | 2022.11.02 |
댓글