본문 바로가기
TIL - 외/빅데이터

[머신러닝] K-최근접 이웃(KNN) 알고리즘 및 실습

by chaemj97 2023. 4. 9.
728x90

KNN 알고리즘이란

  • 가장 간단한 머신러닝 알고리즘, 분류(Classification) 알고리즘
  • 어떤 데이터에 대한 답을 구할 때 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 사용
  • 새로운 데이터에 대해 예측할 때는 가장 가까운 직선거리에 어떤 데이터가 있는지 살피기만 하면 된다.(k =1)
  • 단점
    • 데이터가 아주 많은 경우 사용하기 어렵다
    • 데이터가 크기 때문에 메모리가 많이 필요하고 직선 거리를 계산하는 데도 많은 시간이 필요
  • 실제로 k-최근접 알고리즘은 무언가 훈련되는 게 없다.
    • fit() 메서드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터를 참고하여 분류
      • 객체._fit_X
      • 객체._y
  • k의 기본값은 5
    • n_neighbors 매개변수로 바꿀 수 있다.
  • p 매개변수로 거리를 재는 방법 지정
    • 1일 경우 맨해튼 거리
      • 직선
    • 2일 경우 유클리디안 거리 (기본값)
      • 계단식(절댓값 합)

KNN 알고리즘 실습 - 도미와 빙어 분류하기

  • 도미 데이터 (35마리)
# 생선의 길이
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
# 생선의 무게
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
  • 빙어 데이터 (14마리)
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]
  • 데이터 분포 확인
import matplotlib.pyplot as plt

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight) 
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

파란색 -> 도미데이터 선형적, 길이가 길수록 무게가 많이 나감

주황색 -> 빙어데이터 무게가 길이에 영향을 덜 받는다.

 

  • 두 데이터 합치기 (사이킷런을 사용하기 위해 2차원 리스트로 변환)
length = bream_length+smelt_length
weight = bream_weight+smelt_weight

fish_data = [[l, w] for l, w in zip(length, weight)]
  • knn 알고리즘은 지도학습이기 때문에 정답을 알려줘야한다. (도미 1, 빙어 0)
fish_target = [1]*35 + [0]*14
  • knn 알고리즘
from sklearn.neighbors import KNeighborsClassifier
# 객체 생성
kn = KNeighborsClassifier()
# 학습
kn.fit(fish_data, fish_target)
# 정확도 == 훈련이 잘 되었는지 평가 (0~1)
kn.score(fish_data, fish_target) # 1.0
  • k 개수 변경
kn49 = KNeighborsClassifier(n_neighbors=49)
kn49.fit(fish_data, fish_target)
# 주변 49개를 통해 결정 -> 전체가 49개이므로 모두다 1로 예측
kn49.score(fish_data, fish_target) # 35/49 == 0.7142...

 

위의 과정은 훈련시킬 때 사용한 데이터를 그대로 테스트 했다. 이미 답을 알고 있는 데이터로 테스트하면 올바른 결과를 도출할 수 없다. 따라서 알고리즘의 성능을 제대로 평가하기 위해서는 훈련데이터와 테스트데이터가 달라야 한다.

(샘플링이 한쪽으로 치우쳐있는 경우를 샘플링 편향이라고 부른다.)

  • 데이터 섞기
import numpy as np

input_arr = np.array(fish_data)
target_arr = np.array(fish_target)

np.random.seed(42)
index = np.arange(49)
# 배열을 무작위로 섞기
np.random.shuffle(index)
  • 훈련데이터 테스트 데이터 나누기 1
train_input = input_arr[index[:35]]
train_target = target_arr[index[:35]]

test_input = input_arr[index[35:]]
test_target = target_arr[index[35:]]

import matplotlib.pyplot as plt

# 데이터가 잘 섞였는지 확인
plt.scatter(train_input[:, 0], train_input[:, 1]) # 파란색
plt.scatter(test_input[:, 0], test_input[:, 1]) # 주황색
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

데이터가 잘 섞였다.

  • 모델 훈련 + 예측
kn.fit(train_input, train_target)

# 정확도
kn.score(test_input, test_target) # 1.0

# 예측
kn.predict(test_input)
# array([0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0])

# 정답
test_target
# array([0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0])

 

  • 훈련데이터 테스트 데이터 나누기 2
from sklearn.model_selection import train_test_split

# stratify : 타깃 데이터를 전달하면 클래스 비율에 맞게 데이터 나눈다.
train_input, test_input, train_target, test_target = train_test_split(
    fish_data, fish_target, stratify=fish_target, random_state=42)

 

KNN 알고리즘 주의점

  • 훈련된 모델에 길이가 25cm, 무게가 150g인 생선 테스트해보자
print(kn.predict([[25, 150]])) # 0
  • 산점도로 보기
import matplotlib.pyplot as plt

plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25, 150, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

산점도로 보면 [25,150] 데이터는 도미에 가까워보인다. 하지만 predict()는 빙어라고 한다. 왜 이런 문제가 발생하는지 알아보자

  • [25,150] 데이터와 가까운 이웃 찾아보기
distances, indexes = kn.kneighbors([[25, 150]])

plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25, 150, marker='^')
plt.scatter(train_input[indexes,0], train_input[indexes,1], marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

그림이 뭔가 이상해보인다. 분명 도미가 더 가까워보이는데 빙어가 체크되었다. 그 이유는 x축과 y축의 스케일이 다르다. 스케일이 다를 경우 데이터 전처리를 해줘야 한다. 표준점수를 사용해보자

  • 데이터 전처리
mean = np.mean(train_input, axis=0)
std = np.std(train_input, axis=0)

# 표준화하기
train_scaled = (train_input - mean) / std

plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

  • [25,150]도 산점도에 표시
    • 표준화를 해줘야한다.
# 테스트 데이터도 표준화
new = ([25, 150] - mean) / std

plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0], new[1], marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

  • [25,150] 데이터 예측
kn.fit(train_scaled, train_target)

test_scaled = (test_input - mean) / std
kn.score(test_scaled, test_target) # 1.0

print(kn.predict([new])) # [1.]

도미로 예측 성공!

  • [25,150] 데이터의 가까운 이웃
distances, indexes = kn.kneighbors([new])

plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0], new[1], marker='^')
plt.scatter(train_scaled[indexes,0], train_scaled[indexes,1], marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

728x90
반응형

'TIL - 외 > 빅데이터' 카테고리의 다른 글

[머신러닝] 로지스틱 회귀  (0) 2023.05.07
[머신러닝] 회귀 알고리즘 및 실습  (0) 2023.04.13
불균형 데이터 (imbalanced data) 처리를 위한 샘플링 기법  (0) 2023.03.22
구름 php + MySQL  (0) 2022.11.02
DB  (0) 2022.11.02

댓글