관리 메뉴

Joonas' Note

[딥러닝 일지] 학습 조기 종료 (Early Stop) 본문

AI/딥러닝

[딥러닝 일지] 학습 조기 종료 (Early Stop)

joonas 2022. 3. 13. 03:37

    이전 글 - [딥러닝 일지] 과적합 문제, 그리고 배치 전략

     

    [딥러닝 일지] 과적합 문제, 그리고 배치 전략 (교차 검증)

    이전 글 - [딥러닝 일지] 이진 분류를 위한 CNN 모델 작성 (개 vs 고양이) [딥러닝 일지] 이진 분류를 위한 CNN 모델 작성 (개 vs 고양이) 이전 글 : [딥러닝 기록] 시작하기 - 개 vs 고양이 분류 [딥러닝

    blog.joonas.io


    이번 글은 Version 26을 기준으로 설명한다.

     

    Dogs vs. Cats Classification

    Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

    www.kaggle.com

     

    적당할 때 끝내기

    실험을 하다보면 분명히 학습이 잘못되고 있는 경우가 있다.

    전체 학습 데이터로 학습을 돌려놓고 자러갈 때, 30 에포크를 전부 돌면 엄청 오래 걸린다. 그럼 나의 아까운 GPU 가용량을 엄청 소비하게 된다.

    그리고 너무 오래 걸린다!!!! epoch 하나에 6분 정도 걸렸는 데, 30 epoch를 다 돌면 180분(3시간)이다...

    그럼 언제 끝내는 것이 좋을까 고민하다가 아래 3가지 조건을 세워보았다.

    1. 그래프가 점점 벌어지는 경우 (과적합) → train과 valid의 loss 차이가 너무 크면 종료하자.
    2. 충분히 학습한 경우 → train과 valid를 합쳐서 봤을 때, loss가 굉장히 작으면 종료하자.
    3. 학습에 더 이상 진전이 없는 경우 → 이전 3개의 epoch를 봤을 때, loss가 비슷하면 종료하자.
      • 3번은 코드만 길어지고 한번씩 튀어오르는 경우를 배제하는 것 같아서 제거했다.

    [11]번 노트에서는 조기 종료를 이런 식으로 작성했다.

    # Train set으로 학습하는 부분
    model.train()
    for inputs, labels in tqdm(train_loader, desc=f'train model ({epoch+1}/{EPOCHS} epoch, fold={fold})'):
    	...
    
    # Valid set으로 정확도를 계산하는 부분
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            ...
        ...
    
    # Early stop - 학습의 loss가 검증의 loss보다 너무 크면(3배) 종료
    if train_loss * 3 < valid_loss:
        early_stop = True
        break
    
    # Early stop - 학습과 검증의 평균 loss가 0.1 이하면 종료
    if np.array([train_loss, valid_loss]).mean() < 0.1:
        early_stop = True
        break
    
    if early_stop: break

    텐서플로나 파이토치에서 조기 종료를 위해서 따로 제공하는 함수가 있을텐데, 뭐가 있는 지도 모르고 어떻게 사용하는 지도 모른다.

    그래서 적당히 그래프가 괜찮게 그려지고, 만족할만한 결과가 나오도록 (어차피 파이토치 연습용 스크립트이므로) 작성했다.

    조기 종료하고 그린 loss 그래프

    원래 30번의 epoch를 전부 돌아야하는데, 16번만 돌고 끊긴 모습이다.

     

    많이 쓰이는 기본 형태

    UPDATED 2022/06/03

    여러 노트북이나 툴을 보니 EarlyStopping을 콜백 형태로 작성하고 사용하는 곳이 많았다.

    그리고 대부분은 아주 단순하게 작성하고 있었다.
    이전에 가장 좋았던 loss보다 patience 만큼 참아보고 아무리 참아도 더 나은 결과가 나오지 않으면 끝내버리는 식이다.

    class EarlyStopping:
        def __init__(self, patience=5):
            self.loss = np.inf
            self.patience = 0
            self.patience_limit = patience
            
        def step(self, loss):
            if self.loss > loss:
                self.loss = loss
                self.patience = 0
            else:
                self.patience += 1
        
        def is_stop(self):
            return self.patience >= self.patience_limit

    위처럼 클래스를 구현해서 학습 중간에 아래와 같이 종료시킬 수 있다.

    # 5번 안에 더 좋아지지 않으면 종료
    early_stop = EarlyStopping(patience=5)
    
    for epoch in range(epochs):
    	# ...
    	loss = criterion(y_true, y_hat)
        early_stop.step(loss.item())
        # ...
        if early_stop.is_stop():
        	break
    반응형
    0 Comments
    댓글쓰기 폼