선형회귀(Linear Regression) – 파이썬 코드 예제

본 포스팅에서는 파이썬 라이브러리 scikit-learn을 통해 선형회귀(Linear Regression) 분석을 직접 수행하는 예제를 소개한다. 누구나 쉽게 따라할 수 있는 수준으로 작성했다.

이전 포스팅에서는 선형 회귀의 기초적인 개념에 대해서는 간단히 짚어봤다.

이제는 직접 돌려봐야지.

sklearn LinearRegression 사용법

실제 데이터 돌려보기 전에 사용법부터 익히고 가자.

일단 그 유명한 파이썬 머신러닝 라이브러리 싸이킷런을 불러오자.

from sklearn.linear_model import LinearRegression

이제 LinearRegression 모델을 생성하고, 그 안에 X, y 데이터를 fit 시킨다. 이렇게.

line_fitter = LinearRegression()
line_fitter.fit(X, y)

fit() 메서드는 선형 회귀 모델에 필요한 두 가지 변수를 전달하는 거다.

  • 기울기: line_fitter.coef_
  • 절편: line_fitter.intercept_

어쨌든 이게 끝이다. 이렇게 하면 새로운 X 값을 넣어 y값을 예측할 수 있게 된다.

y_predicted = line_fitter.predict(X)

만약 기울기와 절편을 알고 싶다면 line_fitter.coef_ , line_fitter.intercept_를 직접 찍어보면 된다.

쉽다.

아, 그리고 이전 포스팅에서 수렴할 때까지 얼마나 반복할 것인지(num_iterations), 얼마나 꼼꼼히 학습할 것인지(learning_rate) 정해줘야 한다고 했는데, 일단은 신경 쓰지 말자. scikit-learn에서 알아서 기본 값을 제공한다.

sklearn LinearRegression 실전 예제

일단 필요한 라이브러리를 불러온다.

sklearn 외에도 데이터를 불러올 때 필요한 pandas, 배열을 바꿀 때 필요한 numpy, 시각화를 위한 matplotlib를 함께 불러왔다.

from sklearn.linear_model import LinearRegression
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

일단 csv 파일로 키와 몸무게가 들어있는 파일을 준비했다. pandas로 윗 부분만 찍어보면 아래와 같이 나온다.

df = pd.read_csv("heights.csv")
df.head()

키와 몸무게가 우리나라 상식에는 좀 어긋난데, 아마 미국 기준의 데이터라서 그런 것 같다. 아무튼… 그냥 넘어가자.

matplotlib으로 시각화를 해보면 아래와 같이 나온다.

X = df["height"]
y = df["weight"]
plt.plot(X, y, 'o')<br>plt.show()

대충 우상향 하는 패턴이 보인다.

이제 위에서 배운대로 모델을 생성하고 데이터를 fit 시킨다.

line_fitter = LinearRegression()
line_fitter.fit(X.values.reshape(-1,1), y)

여기서 주의해야 할 점은 X데이터를 넣을 때 .values.reshape(-1,1)를 해줬다는 거다. 왜냐하면 X는 2차원 array 형태여야 하기 때문이다. 이런 식으로 [[x1], [x2], [x3], ... , [xn]] . (이렇게 넣는 이유는 X 변수가 하나가 아니라 여러개일 때 다중회귀분석을 실시하기 위함인데, 이는 다른 포스팅에서 소개한다.)

아무튼 이렇게 하면 끝이다.

이제 한 번 예측을 해보자. 키가 70인 사람을 예측한다고 치면 이렇게.

line_fitter.predict([[70]])

이런 값을 돌려준다. array([134.2596226]). 몸무게가 134 정도 되나보다.

그렇다면 기울기를 알려달라고 해보자.

line_fitter.coef_

array([3.43267613])를 돌려준다.

이번엔 절편을 알려달라고 해보자.

line_fitter.intercept_

-106.02770644878126라고 한다.

이번엔 기존 X 값으로 y를 예측하게 해서 그래프를 그려보자. 당연히 선이 나올 거다.

plt.plot(X, y, 'o')
plt.plot(X,line_fitter.predict(X.values.reshape(-1,1)))
plt.show()

여기까지다.

다음엔 이렇게 1차 방정식으로 설명이 가능한 단순선형회귀 분석 말고, 조금 더 복잡한 다중선형회귀 분석(Multiple Linear Regression)을 알아보자.

추천 글


댓글 남기기