본문 바로가기
Data Analytics with python/[Machine Learning ]

[Regression/Classification] 교차 검증과 하이퍼 파라미터 튜닝 + pipeline

by 보끔밥0130 2023. 2. 16.
728x90

GridSearchCV

교차검증과 하이퍼 파라미터 튜닝을 합니다.

 

랜덤 포레스트 하이퍼 파라미터 튜닝으로 예시

from sklearn.model_selection import GridSearchCV

params = {
    'n_estimators':[100],
    'max_depth' : [6, 8, 10], 
    'min_samples_leaf' : [8, 12, 18],
    'min_samples_split' : [8, 16, 20]
}
# model 객체 생성 후 GridSearchCV 수행
rf_clf = RandomForestClassifier(random_state=0, n_jobs=-1)
grid_cv = GridSearchCV(rf_clf , param_grid=params , cv=2, n_jobs=-1 )
grid_cv.fit(X_train , y_train)

print('최적 하이퍼 파라미터:\n', grid_cv.best_params_)
print('최고 예측 정확도: {0:.2f}'.format(grid_cv.best_score_))

 

Pipeline

피처 벡터화와 ML알고리즘 학습/예측을 위한 코드 작성을 한 번에 진행 가능합니다.

 

데이터의 가공, 변환 등의 전처리와 알고리즘 적용을 마치 '수도관(pipe)에서 물이 흐르듯' 한 번에 스트림 기반으로 처리합니다.

 

# Pipeline
from sklearn.pipeline import Pipeline

pipeline = Pipeline([ ('scaler',StandardScaler()),
('lr_clf', LogisticRegression(random_state = 42))					
])

pipeline.fit(X_train,y_train)
y_pred = pipeline.predict(X_test)

# GridSearchCV + Pipeline
pipeline = Pipeline([ ('scaler',StandardScaler()),
('lr_clf', LogisticRegression(random_state = 42))					  
])

#Pipeline에 기술된 각각의 객체 변수에 언더바(_) 2개를 연달아 붙여 GridSearchCV에 사용될
# 파라미터/하이퍼 파리미터 이름과 값을 설정
params = {
    'rf_clf__n_estimators':[100],
    'rf_clf__max_depth' : [6, 8, 10], 
    'rf_clf__min_samples_leaf' : [8, 12, 18],
    'rf_clf__min_samples_split' : [8, 16, 20]
}

# GridSearchCV의 생성자에 Estimator가 아닌 Pipeline 객체 입력
grid_cv_pipe = GridSearchCV(pipeline, param_grid=params, cv=3 , scoring='accuracy',verbose=1)
grid_cv_pipe.fit(X_train , y_train)
print(grid_cv_pipe.best_params_ , grid_cv_pipe.best_score_)

pred = grid_cv_pipe.predict(X_test)
print('Pipeline을 통한 Logistic Regression 의 예측 정확도는 {0:.2f}'.format(accuracy_score(y_test ,pred)))
728x90

'Data Analytics with python > [Machine Learning ]' 카테고리의 다른 글

RFC, GBM, XGBoost, LightGBM hyper-parameter tuning  (0) 2023.06.19
OLS 고급 수식 구문  (0) 2023.02.22
[Classification] Decision Tree  (0) 2023.02.15
[Clustering] DBSCAN  (0) 2023.02.15
[Clustering] GMM  (0) 2023.02.15

댓글