나는 문어~ 꿈을 꾸는 문어~
[Image Retrieval] 캐글 - Shopee (3) 유사도 검색 시각화 본문
대회 링크 : https://www.kaggle.com/c/shopee-product-matching
코드 링크 : https://github.com/euiraekim/kaggle-shopee/blob/main/visualization.ipynb
목표
이전 포스팅에서 ResNet18을 ArcFace Loss로 학습했다. 학습 결과 5 epoch에서 mean f1 score가 약 0.75가 나오며 가장 높은 성능을 보였고 이 때 threshold는 0.5였다.
이번 포스팅에서는 학습에 사용하지 않은 validation data로 실제로 검색을 해보고 입력 이미지와 검색 결과 이미지들을 시각화하여 살펴볼 것이다. faiss나 annoy와 같은 이러한 유사도 검색 과정을 빠르고 간편하게 도와주는 라이브러리들이 있지만 데이터 수가 그리 많지 않으므로 사용하지 않고 for문을 사용하여 구현해보겠다.
구현 코드
먼저 다음 단계를 거쳐서 모든 valid 이미지에 대한 임베딩을 구해야한다.
- 데이터셋을 Group K Fold로 나눠 valid set을 구한다.
- 모델을 정의하고 학습된 가중치를 로드하여 적용한다.
- data loader을 이용하여 모든 이미지에 대한 임베딩을 구한다.
이는 이전 포스팅에서 대부분 구현해놓았기 때문에 필요한 부분들만 가져와서 구현했다.
import pandas as pd
import numpy as np
import os
import cv2
import random
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import albumentations
import timm
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import GroupKFold
from warnings import filterwarnings
filterwarnings("ignore")
device = torch.device('cuda')
image_size = 512
batch_size = 16
n_worker = 4
init_lr = 3e-4
n_epochs = 6
fold_id = 0
thres = 0.5
# 0.3부터 1까지 0.1간격으로 threshold를 검증하기 위해 만든 배열
search_space = np.arange(0.3, 1, 0.1)
backbone_name = 'resnet18'
weight_dir = './weights/resnet18_512_epoch5.pth'
data_dir = './data/'
df_train_all = pd.read_csv(os.path.join(data_dir, 'train.csv'))
df_train_all['file_path'] = df_train_all.image.apply(lambda x: os.path.join(data_dir, 'train_images', x))
gkf = GroupKFold(n_splits=5)
df_train_all['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(gkf.split(df_train_all, None, df_train_all.label_group)):
df_train_all.loc[valid_idx, 'fold'] = fold
df_train = df_train_all[df_train_all['fold'] != fold_id]
df_valid = df_train_all[df_train_all['fold'] == fold_id]
transforms_valid = albumentations.Compose([
albumentations.Resize(image_size, image_size),
albumentations.Normalize()
])
class SHOPEEDataset(Dataset):
def __init__(self, df, mode, transform=None):
self.df = df.reset_index(drop=True)
self.mode = mode
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.loc[index]
img = cv2.imread(row.file_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if self.transform is not None:
res = self.transform(image=img)
img = res['image'].transpose(2,0,1)
if self.mode == 'test':
return torch.tensor(img).float()
else:
return torch.tensor(img).float(), torch.tensor(row.label_group)
class ArcFaceClassifier(nn.Module):
def __init__(self, in_features, output_classes):
super().__init__()
self.W = nn.Parameter(torch.Tensor(in_features, output_classes))
nn.init.kaiming_uniform_(self.W)
def forward(self, x):
x_norm = F.normalize(x)
W_norm = F.normalize(self.W, dim=0)
return x_norm @ W_norm
class ResnetArcFace(nn.Module):
def __init__(self):
super().__init__()
self.backbone = timm.create_model(backbone_name, pretrained=True)
embedding_size = self.backbone.get_classifier().in_features
self.after_conv=nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.BatchNorm1d(embedding_size))
self.classifier = ArcFaceClassifier(embedding_size, df_train.label_group.nunique())
def forward(self, x, output_embs=False):
embeddings = self.after_conv(self.backbone.forward_features(x))
if output_embs:
return F.normalize(embeddings)
return self.classifier(embeddings)
model = ResnetArcFace()
model.to(device);
# 학습된 가중치 적용
model.load_state_dict(torch.load(weight_dir))
def get_embeddings(data_loader):
model.eval()
embs = []
with torch.no_grad():
for batch_idx, (images) in enumerate(tqdm(data_loader)):
images = images.to(device)
features = model(images, output_embs=True)
embs += [features.detach().cpu()]
embs = torch.cat(embs).cpu().numpy()
return embs
dataset_valid = SHOPEEDataset(df_valid, 'test', transform = transforms_valid)
valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers = n_worker)
embs = get_embeddings(valid_loader)
embs.shape
"""
(6851, 512)
"""
위 코드들을 열심히 돌려 마침내 6851개의 이미지에 대하여 각각 512 크기의 임베딩 벡터를 얻었다.
유사도 검색을 하기 앞서 이렇게 구한 embedding들을 dataframe에 새로운 컬럼을 만들어 넣어주자.
df_valid = df_valid.reset_index(drop=True)
df_valid['embs'] = embs.tolist()
아래 코드를 통해 유사도 검색을 하고 시각화까지 해본다.
출력 결과의 맨 왼쪽 열은 검색을 하기 위해 입력한 이미지고 나머지 열은 입력 이미지와 유사하다고 출력된 이미지다. 실제로는 가시성을 위해 아래 코드를 여러번 돌려 출력된 유사 이미지가 많은 것 위주로 가져왔다. 잘된 것만 가져온 것이 아닌데 저 정도면 상당히 좋고 실무에도 써먹을만 한 것 같다. (나도 놀람)
def show_image(file_path, title):
plt.title(title)
plt.axis('off')
img = cv2.imread(search_row.file_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
for i in random.sample(range(len(embs)), 10):
search_row = df_valid.iloc[i]
plt.figure(figsize=(18,3))
plt.subplot(1, 6, 1)
show_image(search_row.file_path, 'Input Image')
pred = []
for df_i, row in df_valid.iterrows():
# 검색 이미지와 동일한 id의 이미지는 제외한다.
if search_row['posting_id'] == row['posting_id']:
continue
cosine_sim = np.array(search_row['embs'])@np.array(row['embs']).T
if cosine_sim > thres:
pred.append((df_i, cosine_sim))
# 코사인 유사도를 기준으로 내림차순 정렬한다.
pred = sorted(pred, key=lambda x: x[1], reverse=True)
for j, (df_i, cosine_sim) in enumerate(pred):
# 하나의 이미지에 대하여 5개까지만 시각화한다.
if j == 5:
break
plt.subplot(1, 6, j+2)
show_image(df_valid.iloc[df_i].file_path, 'Searched Image')
plt.show()
'대회 & 프로젝트' 카테고리의 다른 글
[Image Retrieval] 캐글 - Shopee (2) 학습~검증 (0) | 2022.04.24 |
---|---|
[Image Retrieval] 캐글 - Shopee (1) 개요~EDA (0) | 2022.04.19 |
[Instance Segmentation] 캐글 - Sartorius (3) UPerNet 전처리~학습~시각화 (0) | 2022.04.06 |
[Instance Segmentation] 캐글 - Sartorius (2) YOLOX 학습 (0) | 2022.04.04 |
[Instance Segmentation] 캐글 - Sartorius (1) 데이터 전처리 (0) | 2022.04.03 |