본문 바로가기

Study/ML 공부

원핫인코더 fit은 한번만 할 껀데? (feat.check_is_fitted, try-except , if - else)

사이킷런의 OneHotEncoder.fit을 예외처리와 조건문을 통해 한번만 한다.

코드는 블로그 코드 블록에서 날 코딩한 것, 불편해도 넘어가자

일단 원핫인코딩을 하기 전에 '머신러닝 모델을 사용한 학습부터 예측까지'라고 하면 아래와 같은 일련의 과정이 필요하다고 해보자

  1. Data load
  2. Data preprocessing
  3. Model.fit
  4. Model.predict

1번부터 4번까지의 파이프라인을 클래스로 작성할 것이다.

아래와 같이 대충 HGomForest 모델의 메서드를 정의했다.
Class HGomForest:
    def __init__(self):
    	pass
    def dataload(self, data):
    	pass
    
    def preprocessing(self, data):
    	pass
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	pass
그리고 데이터를 불러왔다 가정하고 이제 preprocessing에서 데이터를 원핫인코딩을 할 것이다.
그러기 위해선 fit을 통해 OneHotEncoder 객체를 학습하고 인코딩할 데이터를 transform 시키는 것이다.
OneHotEncoder를 설명하는 글이 아니니 자세한 설명은 넘어간다. 아무튼 그렇다 
from sklearn.preprocessing import OneHotEncoder
Class HGomForest:
    def __init__(self):
    	self.ohe = OneHotEncoder()
    	pass
    def dataload(self, data):
    	pass
    
    def preprocessing(self, data):
    	self.ohe.fit(data)
        data = self.ohe.transform(data)
    	pass
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	pass
위 코드에서 우리는 인코딩까지 하고 Model.fit까지 완료했다.
마지막은 predict를 통해 데이터를 예측할 것이다. 인코딩이 전혀 안된 데이터이기 때문에 OneHotEncoder를 통해 변형시켜주어야 한다.
여기 주의할 점은, train_dataset으로 fit 한 OneHotEncoder 객체로 predict_dataset을 transform 시켜야 한다.  다시 한번 말하지만, OneHotEncoder를 설명하는 글이 아니니 넘어간다.

from sklearn.preprocessing import OneHotEncoder
Class HGomForest:
    def __init__(self):
    	self.ohe = OneHotEncoder()
    	pass
    def dataload(self, data):
    	pass
    
    def preprocessing(self, data):
    	self.ohe.fit(data)
        data = self.ohe.transform(data)
    	return data
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	data = self.ohe.transform(data)
        return data
위 코드처럼 각 메서드마다 OneHotEncoding 하는 부분이 1줄 혹은 2줄만 나올 리가.. 없다.
인코딩한 데이터를 가지고 이것도 하고 저것도 해서 점점 코드가 길어지고 preprocessing 메서드와 predict 메서드에서 중복 코드가 증가하게 될 것이다
그래서 아래와 같이  중복되는 부분을 메서드로 빼는 경우가 올 것이다.
from sklearn.preprocessing import OneHotEncoder
Class HGomForest:
    def __init__(self):
    	self.ohe = OneHotEncoder()
    	pass
    def dataload(self, data):
    	pass
    def do_one(self, data):
    	if ~~:
    	``self.ohe.fit(data)
        data = self.ohe.transform(data)
    	data = ~~~
        data = ~~~~
        
    def preprocessing(self, data):
        data = self.do_one(data)
		
    	return data
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	data = self.do_one(data)
        return data

OneHotEncoder가 학습되어있다면 Transform만 수행, 아니면 학습까지 하는 경우를 조건문을 통해 구현해보겠다.
if_fitted라는 bool값은 가진 변수를 선언해, fit이 되면 True로 변환시키며 조건문을 수행한다.
from sklearn.preprocessing import OneHotEncoder
Class HGomForest:
    def __init__(self):
    	self.ohe = OneHotEncoder()
        self.is_fitted = False
    	pass
    def dataload(self, data):
    	pass
    def do_one(self, data):
    	if not self.is_fitted:
    	    self.ohe.fit(data)
        data = self.ohe.transform(data)
    	data = ~~~
        data = ~~~~
        
    def preprocessing(self, data):
        data = self.do_one(data)
        self.is_fitted = True
		
    	return data
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	data = self.do_one(data)
        return data
위와 같은 수행을 사이킷런의 check_is_fitted메서드와 Try-Except를 같이 사용함으로써 수행할 수 있다.
check_is_fitted 메서드는 해당 객체가 fit 되어있나를 확인하는 사이킷런 메서드이다. 
객체가 학습되어있지 않다면 NotFittedError를 발생해서 try-except를 활용해 원하는 코드 수행을 하면 된다.
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.validation import check_is_fitted, NotFittedError
Class HGomForest:
    def __init__(self):
    	self.ohe = OneHotEncoder()
    	pass
    def dataload(self, data):
    	pass
    def do_one(self, data):
        try:
            check_is_fitted(self.ohe)
        except:
            ~~
    def preprocessing(self, data):
        data = self.do_one(data)
		
    	return data
    
    def fit(self, data):
    	pass
    
    def predict(self, data):
    	data = self.do_one(data)
        return data