관리 메뉴

나는 문어~ 꿈을 꾸는 문어~

[Image Retrieval] 캐글 - Shopee (1) 개요~EDA 본문

대회 & 프로젝트

[Image Retrieval] 캐글 - Shopee (1) 개요~EDA

harrykur139 2022. 4. 19. 19:14

대회 링크 : https://www.kaggle.com/c/shopee-product-matching

코드 링크 : https://github.com/euiraekim/kaggle-shopee/blob/main/EDA.ipynb

 

 

 

개요

Shopee는 동남아시아와 대만에서 서비스하는 전자 상거래 플랫폼이다. 우리나라로 치면 쿠팡이나 G마켓, 11번가 쯤을 생각하면 될 것 같다.

 

대회의 목적은 동일 상품 찾기다. 특정 제품에 대하여 판매자는 여러명일 수 있으므로, 서비스 측면에서 판매자가 올린 상품과 동일한 상품이 기존에 올라와 있는지 알 수 있어야한다. 그래야 소비자들이 동일 상품에 대하여 최저가를 잘 파악할 수 있고 그에 따라 서비스의 효용이 올라가기 때문이다.

 

본 대회에는 서비스에 올라온 상품들의 이미지와 상품의 제목이 데이터로 주어진다. 이 두 정보로 일치하는 상품을 찾기 위해서는 유사도를 통한 검색이 필요하다. 이미지와 상품의 제목과 같은 비정형 데이터의 유사도를 비교하기 위해서는 유사도 측정에 아주 적합한 특징을 뽑아내는 것이 매우 중요하다. 이를 아주 효율적으로 잘 할 수 있는 것이 딥러닝이다.

 

이 특징을 일반적으로 임베딩이라 부르고, 상위 솔루션들은 실제로 이미지와 제목으로 딥러닝 네트워크를 학습하여 임베딩을 얻고 이를 통해 유사도가 큰 상품을 찾는 방식을 사용했다.

 

실제로 높은 성능을 얻기 위해서는 이미지와 제목을 둘 다 사용해야 하겠지만, 나의 목표는 Image Retrieval(이미지를 통한 검색)이기 때문에 이미지만을 사용하여 학습을 진행하고자 한다.

 

 

EDA

데이터를 다운로드한 후 data 폴더에 넣고 분석해보자.

 

먼저 패키지를 임포트하자.

import os
import pandas as pd
import cv2

import matplotlib.pyplot as plt
import seaborn as sns

# pandas에서 progress bar 보는 설정
from tqdm.auto import tqdm as tqdmp
tqdmp.pandas()

# warning 무시
import warnings
warnings.simplefilter("ignore")

 

 

데이터 폴더 안에는 저 다섯 개의 파일 및 폴더가 들어있다. 학습에 필요한 train.csv와 train_images를 위주로 살펴보자.

data_path = 'data'
os.listdir(data_path)

"""
['test.csv',
 'train.csv',
 'train_images',
 'sample_submission.csv',
 'test_images']
 """

 

 

train.csv의 상위 5개 행을 출력해보자.

train_df = pd.read_csv(os.path.join(data_path, 'train.csv'))
train_df.head()

각 feature의 의미는 다음과 같다.

  • posting_id : 해당 포스팅의 고유 id
  • image : train_images 폴더에 들어있는 포스팅 이미지 파일 이름
  • image_phash : 이미지를 특정 길이의 16진수로 변환하여 넣어 놓은 것 같음 (사용하지 않을 예정)
  • title : 포스팅의 타이틀 (이미지 관련 프로젝트로 진행할 것이기 때문에 사용하지 않을 예정)
  • label_group : 상품의 label. 동일한 제품을 올린 포스팅인 경우 이 값이 같음

 

 

총 행 수는 34250으로 posting_id의 unique 개수와 같다. label_group은 총 11014개다. 전체 개수의 3분의1 정도인 것으로 보아 동일 제품은 보통 2~4개에 있을 것으로 예상해볼 수 있다.

print(f'train df shape {train_df.shape}')
print('train df posting_id unique values {}'.format(train_df['posting_id'].nunique()))
print('train df label_group unique values {}'.format(train_df['label_group'].nunique()))

"""
train df shape (34250, 5)
train df posting_id unique values 34250
train df label_group unique values 11014
"""

 

 

그렇다면 동일 제품이 개수가 어떻게 분포해 있는지 확인해보자.

sns.countplot(train_df['label_group'].value_counts())
plt.show()

대부분이 2~4개 정도이고 눈에 보이는 수치는 10개 이하에 몰려있다. 그 이상인 것도 있고 최대는 51개의 동일 제품이 있는 label도 있다.

 

 

이미지의 경로를 담는 path 컬럼을 추가하자.

train_df['path'] = data_path + "/train_images/" + train_df['image']

train_df.head()

 

 

이미지를 쭉 둘러보니 각각 다르지만 보통 정사각형의 이미지로 확인했다. 실제로 그런지 보자.

모든 행에 대하여 이미지를 가져와야하므로 시간이 좀 걸린다. 앞서 tqdm을 pandas에서 사용할 수 있게 설정해놓았으므로 progress_apply 함수를 이용하여 진행 상황을 지켜볼 수 있다.

train_df['img_shape'] = train_df['path'].progress_apply(lambda x: cv2.imread(x).shape)

약 5분만에 끝났다.

 

 

정사각형이 아닌 이미지도 있는지 확인해보자.

shapes = pd.DataFrame().from_records(train_df['img_shape'])
shapes.columns = ['Width', 'Height', 'Colors']

print('정사각형이 아닌 이미지 개수 : {}'.format(len(shapes[shapes['Width'] != shapes['Height']])))

"""
정사각형이 아닌 이미지 개수 : 118
"""
sns.set_style("white")
sns.jointplot(x = shapes.iloc[:, 0].astype('float32'), 
              y = shapes.iloc[:, 1].astype('float32'),
              height = 6, color = '#f15335')
plt.show()​

118개의 이미지만 정사각형이 아니다. 학습할 때 모두 같은 크기의 정사각형으로 resize해도 크게 문제 없겠다는 생각이 들었다.

 

 

동일 label_group에 있는 이미지를 함께 출력해서 살펴보자. label_group에서 약 5개를 랜덤하게 뽑아 출력했다.

import random

# 출력할 label 개수
num_sample = 5
label_samples = random.sample(list(train_df['label_group'].unique()), num_sample)

plt.figure(figsize = (15, 5*num_sample))
for i, label in enumerate(label_samples):
    for j, image_path in enumerate(train_df[train_df['label_group'] == label]['path']):
        # 동일 label의 개수가 3개가 넘더라도 3개까지만 출력
        if j == 3:
            break

        plt.subplot(len(label_samples), 3, 3*i+j+1)
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    
        plt.imshow(img)
        plt.axis('off')

plt.show()

같은 행에 같은 제품을 최대 3개씩 출력했다. 아예 같아보이는 이미지도 있고 같은 제품인데 다른 형태로 올라와 있는 것들도 존재한다.

 

이들을 학습해 같은 제품이면 유사도가 큰 특징을 출력하는 모델을 만들어야 한다.

Comments