관리 메뉴

나는 문어~ 꿈을 꾸는 문어~

[Instance Segmentation] 캐글 - Sartorius (2) YOLOX 학습 본문

대회 & 프로젝트

[Instance Segmentation] 캐글 - Sartorius (2) YOLOX 학습

harrykur139 2022. 4. 4. 13:56

대회 링크 : https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview

코드 링크 : https://github.com/euiraekim/kaggle-sartorius-cell-instance-segmentation

 

 

Segmentation task에 웬 yolox인가.. box based segmentation을 하려고 한다. 말 그대로 detection 모델을 학습하고, 각각의 이미지를 학습하는 것이 아니라 이미지의 박스들을 segmentation 모델로 학습한다. inference 역시 bounding box를 얻고 그 box를 segmentation 한다.

 

이런 복잡한 프로세스를 거치는 이유는 성능이 좋기 때문이다. 그렇기 때문에 많은 상위 솔루션이 이 방법을 사용했다. 물론 실무에서 이러한 task를 마주쳤다면 정확도 뿐만 아니라 모델의 속도도 고려해야한다. 하지만 대회는 속도보다 정확도가 중요하므로 나도 이와 같은 방법을 사용해보고자 한다.

 

 

목표

모델은 YOLOX를 사용한다. 여기서도 병변 검출 대회처럼(해당 블로그 포스팅 참조) 여러 고성능 모델을 사용하여 앙상블을 하면 좋겠지만 목표는 Instance Segmentation이기 때문에, 여기에 크게 시간을 쏟기보단 좋다고 알려진 SOTA모델을 사용한다.(공부가 목적임) 그냥 사용하는 것은 아니고 해당 대회 1등 솔루션도 똑같은 YOLOX를 사용했고, 결과 성능도 좋다고 하였다.

 

coco dataset으로 사전 학습된 모델을 livecell dataset으로 fine tuning하고, 캐글 데이터셋으로 다시 fine tuning한다. 시간이 허락한다면 coco dataset을 바로 캐글 dataset으로 fine tuning하여 성능 비교도 해보고자 한다.

 

 

LIVECell YOLOX 학습

mmdetection을 사용하여 학습할 것이다. 아마 segmentation도 mmsegmentation을 사용할 것 같다. 좋은 프레임워크를 만들어주신 mmlab 여러분들 사는 국가는 다를 것이지만 정말 감사합니다.

 

mmdetection의 공식 github에서 tools/train.py 파일을 내 프로젝트의 utils/detection/train.py로 가져온다. yolox의 coco pretrained 모델은 여기를 클릭하여 받아 data/checkpoints/yolox_x_coco.pth로 넣어주고 학습 때 사용한다.

 

모델은 yolox 중 가장 무거운 모델인 yolox-x를 사용하였고 config 파일은 mmdetection의 기본 yolox-x의 설정을 일부만 바꾸고 대부분의 하이퍼 파라미터를 그대로 사용했다. 내가 사용한 config 파일은 다음과 같다. config 파일의 경로는configs/detection/yolox_x_livecell.py다.

img_scale = (960, 960)  # height, width

# model settings
model = dict(
    type='YOLOX',
    input_size=img_scale,
    random_size_range=(15, 25),
    random_size_interval=10,
    backbone=dict(type='CSPDarknet', deepen_factor=1.33, widen_factor=1.25),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[320, 640, 1280],
        out_channels=320,
        num_csp_blocks=4),
    bbox_head=dict(
        type='YOLOXHead', num_classes=8, in_channels=320, feat_channels=320),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # In order to align the source code, the threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

# dataset settings
data_root = 'data/'
dataset_type = 'CocoDataset'
classes = ('shsy5y', 'a172', 'bt474', 'bv2', 'huh7', 'mcf7', 'skov3', 'skbr3')

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', flip_ratio=0.5),
    # According to the official implementation, multi-scale
    # training is not considered here but in the
    # 'mmdet/models/detectors/yolox.py'.
    dict(type='Resize', img_scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'LIVECell_dataset_2021/train_8class.json',
        img_prefix=data_root +
        'LIVECell_dataset_2021/images/livecell_train_val_images',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_empty_gt=False,
    ),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Pad',
                pad_to_square=True,
                pad_val=dict(img=(114.0, 114.0, 114.0))),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    persistent_workers=True,
    train=train_dataset,
    val=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'LIVECell_dataset_2021/val_8class.json',
        img_prefix=data_root +
        'LIVECell_dataset_2021/images/livecell_train_val_images',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'LIVECell_dataset_2021/val_8class.json',
        img_prefix=data_root +
        'LIVECell_dataset_2021/images/livecell_train_val_images',
        pipeline=test_pipeline))

# optimizer
# default 8 gpu
optimizer = dict(
    type='SGD',
    lr=0.01 / 64,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)

max_epochs = 100
num_last_epochs = 15
resume_from = None
interval = 1

# learning policy
lr_config = dict(
    # _delete_=True,
    policy='YOLOX',
    warmup='exp',
    by_epoch=False,
    warmup_by_epoch=True,
    warmup_ratio=1,
    warmup_iters=5,  # 5 epoch
    num_last_epochs=num_last_epochs,
    min_lr_ratio=0.05)

runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=num_last_epochs,
        interval=interval,
        priority=48),
    dict(
        type='ExpMomentumEMAHook',
        resume_from=resume_from,
        momentum=0.0001,
        priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(
    save_best='auto',
    # The evaluation interval is 'interval' when running epoch is
    # less than ‘max_epochs - num_last_epochs’.
    # The evaluation interval is 1 when running epoch is greater than
    # or equal to ‘max_epochs - num_last_epochs’.
    interval=interval,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
    metric='bbox')
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook')
    ])

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = data_root + 'checkpoints/yolox_x_coco.pth'
workflow = [('train', 1)]

# # disable opencv multithreading to avoid system being overloaded
# opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

바꾼 부분을 생각나는 대로 적으면 다음과 같다.

  • label 데이터, 이미지 등 각종 경로들
  • load_from에 pretrained 모델 경로 넣어줌
  • image size
  • optimizer의 learning rate를 기존보다 작게 바꿔줌 (fine tuning이기 때문)

 

이제 mmdetection을 설치하고, 다음 명령어를 입력하여 학습을 시작한다.

pip install openmim
mim install mmdet
python utils/detection/train.py configs/detection/yolox_x_livecell.py​

 

모델은 10epoch에서 가장 높은 성능을 보였고 evaluation 결과는 다음과 같다.

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.230
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.690
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.409
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.375
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.298
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.246
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.275
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.429
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.534
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.511
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.531
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.626

 

 

Kaggle YOLOX 학습

결론부터 말하자면 나의 경우 LIVECell pretrained 모델을 사용하는 것보다 coco pretrained에서 바로 kaggle 데이터를 학습하는 것이 더 성능이 좋았다. 그러므로 포스팅도 그에 맞춰 진행한다.

 

config 파일은 LIVECell config 파일에서 데이터셋의 경로들만 바꿔주면 되기 때문에 생략한다. 코드는 최상단 전체 코드 링크의 configs/detection/yolox_x_kaggle.py에서 확인할 수 있다.

 

다음 명령어를 입력해 학습을 시작하도록 하자.

python utils/detection/train.py configs/detection/yolox_x_kaggle.py

 

모델의 성능은 15epoch에서 가장 높았고 evaluation 결과는 다음과 같다.

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.228
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=2000 ] = 0.631
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=2000 ] = 0.221
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=2000 ] = 0.250
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=2000 ] = 0.202
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=2000 ] = 0.379
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.319
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.407
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=2000 ] = 0.476
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=2000 ] = 0.449
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=2000 ] = 0.371
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=2000 ] = 0.627

2022-04-04 02:37:54,572 - mmdet - INFO - 
+----------+-------+----------+-------+----------+-------+
| category | AP    | category | AP    | category | AP    |
+----------+-------+----------+-------+----------+-------+
| shsy5y   | 0.227 | astro    | 0.260 | cort     | 0.372 |
+----------+-------+----------+-------+----------+-------+

 

Inference 결과 시각화

먼저 시각화 코드를 보자.

 

모델을 불러온다.

from mmdet.apis import init_detector, inference_detector

config_path = './work_dirs/yolox_x_kaggle/yolox_x_kaggle.py'
checkpoint_path = './work_dirs/yolox_x_kaggle/epoch_15.pth'

model = init_detector(config_path, checkpoint_path, device='cuda:0')

 

pycocotools 패키지를 이용하여 validation coco file인 dval_g0.json에서 이미지 파일명들을 가져온다. 그리고 이미지 폴더에서 모든 valid image들에 대해 inference를 진행하고 그에 대한 결과들을 result 리스트에 담는다.

센스있게 1장 당 inference time이 몇 초인지도 출력해본다.

import cv2
import os
import matplotlib.pyplot as plt
import pandas as pd
import time
from pycocotools.coco import COCO

valid_coco = COCO('./data/dval_g0.json')
valid_img_infos = valid_coco.loadImgs(valid_coco.getImgIds())
image_list = [v['file_name'] for v in valid_img_infos]

image_path = './data/train'

start = time.time()

# inference_detector의 인자로 string(file경로), ndarray가 단일 또는 list형태로 입력 될 수 있음. 
result = []
for image in image_list:
    result.append(inference_detector(model, os.path.join(image_path, image)))
    
end = time.time()
print(str((end-start)/len(image_list)) + ' sec per image')

###
0.12361730003356934 sec per image
###

 

inference 결과 bbox를 원본 사진에 시각화하여 저장한다. result/yolox_x_kaggle 폴더 안에 저장한다. 원본 사진과 같이 보면 편하므로 같이 저장한다.

import shutil

for i in range(len(image_list)):
    image = os.path.join(image_path, image_list[i])
    result_path = 'result/yolox_x_kaggle'

    # inference box가 그려진 이미지 저장
    model.show_result(image, result[i],
                      out_file=os.path.join(result_path, f'img_{i}_result.jpg'), score_thr=0.4)
    # 원본 이미지 저장
    shutil.copy(image, os.path.join(result_path, f'img_{i}_src.jpg'))

 

결과 예시 4장 (왼쪽 : 원본, 오른쪽 : 결과 시각화)

 

결과 고찰

원래는 상위 솔루션을 최대한 이해하고 그에 대한 구현을 내 스타일대로 한 후 학습까지 진행하여 상위 솔루션에 근접한 성능을 내보고 싶었다. 하지만 불가능하다는 것을 깨달았다. 왜냐하면 나는 장비가 딸리기 때문이다. (코랩 프로 사용)

 

cell 데이터같은 경우 이미지에서 detection(혹은 segmentation)을 해야하는 부분들이 매우 작다. 위의 결과 예시를 보면 알 수 있다. 그래서 이미지 크기를 크게 가져가는 것이 중요하다. 고성능을 내려면 이미지 크기를 크게 하며 학습의 안정도를 위하여 배치 사이즈도 4 정도는 해야하는데 내 장비로는 이미지의 크기가 1000x1000만 넘어가도 배치 사이즈가 1밖에 되지 않았고 1300이 넘어가면 그 조차도 돌아가지 않았다. 상위 솔루션을 만드신 분들은 1500x1500 이상, 배치 사이즈 4이상으로 학습하였다.

 

따라서 이번 대회의 목표는 최대 성능을 내는 것보단, 장비가 주어지면 언제든 그 성능을 뽑을 수 있게 상위 솔루션을 내 것으로 만드는 것으로 한다.

Comments