Joonas' Note

Joonas' Note

Quick, Draw! 클론 코딩 해보기 본문

개발

Quick, Draw! 클론 코딩 해보기

2024. 5. 26. 18:50 joonas

    먼저 Quick, Draw! 는 구글에서 위와 같은 낙서 데이터 셋을 학습하여 345개의 주제 중 하나를 그리면 머신 러닝 모델을 통해 정답을 맞추는 게임이다. 그리고 이 데이터셋을 오픈소스로 공유했다.


    목표 설정

    시계열 데이터를 Online 으로 처리하는 RNN 모델을 다루고 싶었고 최종적으로는 브라우저에서 돌아가도록 포팅하는 것이 목표였으나, 아래 서술될 이유로 개발 방향을 잠시 수정하였다.

    345개의 클래스로 적지 않은 클래스를 분류하는 모델이지만 각 클래스당 최소 1만개 이상의 Dataset이 있으므로 학습에는 어려움이 없을 것이라고 판단하였다.


    목표 수정

    간단한 RNN 모델을 만들었는데 학습이 잘 되지 않았었고, 이를 해결하는 과정에서 RNN 구조의 문제인지 데이터 전처리의 문제인지 파악하기가 어려워, 방향을 우선 이미지 기반의 CNN 모델로 변경하기로 했다. (개발자 역량 이슈)

    그럼에도 구글에서 제공하는 공식 데이터셋이 어떤 그림을 이미지로 제공하는게 아니라, RDP(Ramer–Douglas–Peucker) 알고리즘으로 단순화 된 선분을 제공한다. 그래서 CNN 구조로 학습하고 싶다면 RDP 데이터를 그림으로 복원시키는 코드를 작성해야 했고 작성했었다.

    정작 CNN 모델은 작성하지 않은채로 몇 개월 방치했다가, 우연히 QuickDraw API 파이썬 패키지를 발견했고, 이미지 데이터를 쉽게 받아오고 학습할 수 있어서 우선은 CNN 모델을 사용하여 전체 앱 파이프라인부터 빠르게 완성하기로 결정했다.


    전체 구성

    앱부터 AI 모델까지의 구성

    모델은 Python 으로 작성되었고 실행 역시 Python 으로 되는데, 이것을 브라우저에서 실행하기란 어려운 일이다.

    그래서 여러 플랫폼과 여러 프로세서에서 모델이 동작할 수 있도록 하는 표준인 ONNX 로 변환이 필요하다.

    ONNX Runtime을 기준으로 사용자가 그림을 그리면 모델에 넘겨 실행하는 부분과, 그러한 모델을 학습하는 부분으로 나누었다.

    모델 학습

    PyTorch 베이스로 AI 모델을 작성하고, Lightning AI로 반복적인 학습 루프 코드나 학습을 중단하는 등의 콜백을 쉽게 추가했다.

    매번 스크래치로 학습 루프를 직접 다 작성했는데 이번에 Lightning AI + W&B 조합으로 돌려보니 모델과 파라미터에 집중하기 좋았다.

    wandb 연동 페이지

    모델은 Image classifier로 RGB Image를 입력으로 받아 345개 클래스에 대한 확률을 예측하는 모델이다.

    출처: https://www.datacareer.de/blog/quick-draw-classifying-drawings-with-python/

    여러 실험을 해보았는데, 아래와 같이 Convolutional Block 4개를 적절히 쌓은 구조가 학습 속도도 괜찮으면서 정확도가 가장 높게 나왔다.

    CNNModel(
      (conv_layer): Sequential(
        # input: (3, 32, 32)
        (0): Conv2d(3, 32, kernel_size=(2, 2), stride=(1, 1))    # (32, 31, 31)
        (1): ReLU()
        (2): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1))   # (64, 30, 30)
        (3): ReLU()
        (4): MaxPool2d(kernel_size=2, stride=2)                  # (64, 15, 15)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))  # (128, 13, 13)
        (6): ReLU()
        (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)) # (256, 11, 11)
        (8): ReLU()
        (9): MaxPool2d(kernel_size=3, stride=2)                  # (256, 5, 5)
      )
      # flatten it into 256x5x5 = 6400
      (classifier): Sequential(
        (0): Dropout(p=0.1, inplace=False)
        (1): Linear(in_features=6400, out_features=345, bias=True) # (1, 6400) -> (1, 345)
        (2): LogSoftmax(dim=1)
      )
      # output: (1, 345)
    )

    위 모델(이하 68번 모델) 구조로 학습했을 때, valid set에서 82.7%의 정확도, test set에서 83.8%의 정확도가 나왔다.

    64x64 크기나 128x128 크기의 이미지를 입력으로도 해보았지만 conv block 이 너무 깊었거나, 뒤 쪽 FC layer의 파라미터가 너무 많아서인지 학습이 잘 안되었다.

    68번 모델의 경우 파라미터의 크기는 conv layer 377K, classifier 2.2M 으로 합쳐서 2.6M 정도고, 학습은 14시간정도 걸렸다.


    웹 페이지

    유저가 직접 그림을 그려서 모델에 입력으로 넣을 수 있도록 웹 페이지를 게임 형태로 만든다.

    아래 서술된 글의 순서는 개발 순서와 동일하다. 이미지가 있다고 가정하고 모델을 먼저 포팅했고, 유저가 이미지를 그릴 수 있도록 UI를 작업한 이후 두 파트를 합쳐서 동작하는 지 확인했다.

    모델 로드

    먼저 CNN Classifier 모델을 브라우저에서 돌릴 수 있도록 해야한다. 이를 위해서 먼저 ONNX 모델로 변환하고 브라우저에서 ONNX Runtime Web (구 ONNX.js) 라이브러리를 통해 불러온다. ONNX Runtime 예제 리포지토리를 보고 참고할 수 있다.

    아래와 같이 ONNX Runtime Web 라이브러리를 불러온 후에,

    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>

    ONNX 모델을 불러와서 세션을 만들어 Tensor로 변환하고 inference 까지 할 수 있는 클래스를 다음과 같이 작성했다.

    class Model {
        constructor() { }
    
        async load() {
            this.session = await ort.InferenceSession.create('./public/model_cnn_74.optimized.onnx');
        }
    
        /**
         *
         * @param {Float32Array} rgb_input
         * @returns probablities of classes
         */
        async infer(rgb_input) {
            const dims = [1, 3, 32, 32];
            const tensor = new ort.Tensor('float32', rgb_input, dims);
            const result = await this.session.run({ 'input.1': tensor });
            const probs = Array.from(result[Object.keys(result)].cpuData);
            return probs.map((x, i) => ({ logit: x, probability: Math.exp(x), index: i }));
        }
    }

    배치 크기는 1, 입력 차원은 32x32 사이즈의 3채널 RGB Image 이므로 3x32x32 으로 맞춰준다.

    생성한 세션에 입력 형식을 맞추고 실행하면 CPU 에서 실행한 결과를 받을 수 있다. 문서에서는 CPU 실행 시 내부적에서 WebAssembly로 실행되기 때문에 속도가 빠르다고 설명되고 실제로도 20ms 정도 내로 빨랐다.

    UI 작업

    그림 그리기

    HTML Canvas 에다가 마우스 이벤트 리스너를 붙여서 그림을 그릴 수 있도록 작업한다. 이 부분은 워낙 예제도 많고 코드가 전형적이라서 크게 다른 부분이 없다.

    이미지 잘라내기

    UI를 화면 전체가 캔버스가 되도록 구성했기 때문에, 캔버스의 크기는 무척 크다. 하지만 모델이 학습한 이미지를 살펴보면 그림이 그려진 부분만 입력이 되어야한다.

    모델이 학습한 이미지 배치 (32x32 크기)

    즉, 그려진 그림의 상하좌우 여백이 거의 없도록 이미지를 잘라내어야 한다.
    그래서 아래와 같이 마우스 이벤트가 발동했던 부분의 좌표만 추려내서 그림이 그려진 영역만 계산했다.

    canvas.addEventListener('mousemove', (evt) => {
        // ...
        const { clientX, clientY } = evt;
        cursor.px = clientX;
        cursor.py = clientY;
        const PADDING = 5;
        cropBox.top = Math.min(cropBox.top, clientY - PADDING);
        cropBox.bottom = Math.max(cropBox.bottom, clientY + PADDING);
        cropBox.left = Math.min(cropBox.left, clientX - PADDING);
        cropBox.right = Math.max(cropBox.right, clientX + PADDING);
    });

    이미지 크기 조정하기

    모델은 32x32 크기의 이미지를 학습했다. 즉, 모델의 입력 역시 32x32 픽셀 이미지, 엄밀히는 3x32x32 차원의 텐서를 입력받아야한다
    따라서 캔버스에 그려진 그림은 모델에 들어가기 전에 32x32 크기로 조정되어야한다.

    이러한 과정에서 그려진 그림이 많이 찌그러지고 손상되는데, 디버그 버튼을 클릭하면 어떤 이미지가 모델로 넘어가는 지 볼 수 있다.

    모델과 자연스러운 상호작용

    구글의 Quick, Draw! 사이트를 보면 모델이 계속 정답을 맞추려는 듯이 멘트를 한다. 선을 하나 긋고 커서를 떼었을 때 대답하는 방식이 아니라 적절한 시간이 지나면 그렸던 데이터를 모델에 넣는 것처럼 보였다.

    비슷한 UX를 위해서 커서가 떨어지지 않더라도 1~2초 주기로 모델에 현재까지 그려진 그림을 넣고 유저에게 결과를 보여주도록 했다.
    대신 캔버스에 변경이 있는 경우에만 입력을 전달해서, 유저가 아무 행동도 하지 않는 경우에는 모델이 혼잣말을 중얼거리는 듯한 효과만 보여주기로 했다. 


    결과물

    실제 애플리케이션 화면

     

    Quick, Draw! Clone

    다음을 그려주세요. 30초 이내 알겠어요!

    www.joonas.io

    감상 그리고 향후 계획

    모델 학습 단계에서 데이터셋과의 정확도도 84% 아래로 낮은 편이었는데, 실제로 웹 서비스에 적용하면서 생기는 이미지 전처리까지 합쳐지면서 정확도가 더 낮아진 것이 체감된다.

    사실 이러한 부분을 염려하여 RNN 기반으로 정답을 추론하고 싶었다. 그래서 앞으로 2가지의 업데이트를 모두 해보고싶은데, 하나는 CNN 모델에 Residual block 을 추가하여 모델의 정확도를 높이는 것, 그리고 RNN 기반으로 모델을 만들고 이미지 손실 없이 선분을 구성하는 점의 순서, 위치만으로 추론하는 모델을 만드는 것.


    전체 코드

     

    GitHub - joonas-yoon/quick-draw-clone

    Contribute to joonas-yoon/quick-draw-clone development by creating an account on GitHub.

    github.com


    참고

     

    onnxruntime-inference-examples/js/quick-start_onnxruntime-web-script-tag/index.html at main · microsoft/onnxruntime-inference-e

    Examples for using ONNX Runtime for machine learning inferencing. - microsoft/onnxruntime-inference-examples

    github.com

     

    Issue loading model using onnx web · Issue #9322 · microsoft/onnxruntime

    Describe the bug A clear and concise description of what the bug is. To avoid repetition please make sure this is not one of the known issues mentioned on the respective release page. I have a proj...

    github.com

     

    Comments