Facebook AIが開発したPyTorchベースの物体検出ライブラリ Detectron2 でデータ水増し

スターやコメントしていただけると励みになります。 また、記事内で間違い等ありましたら教えてください。

Detectron2でのデータ水増し

デフォルトでは訓練時のログからもわかるように2つの手法が使用されています。(実行結果は前の記事のものです)
今回の記事ではいろんな水増し手法を使う方法を紹介します。

f:id:yamayou_1:20210228183010p:plain

実行

モジュールのインポート

import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

import torch
assert torch.__version__.startswith("1.7")

# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import matplotlib.pyplot as plt
import os, json, cv2, random

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog

# 画像を保存するディレクトリを作成
dirname_picture = "./pictures/"
os.makedirs(dirname_picture, exist_ok=True)

# 画像表示用の関数を定義
def cv2_imshow(img):
    plt.figure(figsize=(8, 8))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.show()

独自のTrainerの定義

調べた中で一番簡単な実装は以下のようにDefaultTrainerを継承して、build_train_loaderを上書きすることです。
DatasetMapperは画像に対して一括で関数(ここではtrain_augmentations)を実行するものです。

とても簡単ですね。独自の関数を使用したい場合はAPI refを読んでくだい。

import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader

# ここに使用したい水増し手法を追加
train_augmentations = [
    T.RandomBrightness(0.5, 2),
    T.RandomContrast(0.5, 2),
    T.RandomSaturation(0.5, 2),
    T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
    T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
]

# DefaultTrainerを継承します
class AddAugmentationsTrainer(DefaultTrainer):
    """
    We use the "DefaultTrainer" which contains a number pre-defined logic for
    standard training workflow. They may not work for you, especially if you
    are working on a new research project. In that case you can use the cleaner
    "SimpleTrainer", or write your own training loop.
    
    ref: https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
    """

    @classmethod
    def build_train_loader(cls, cfg):
        custom_mapper = DatasetMapper(cfg, is_train=True, augmentations=train_augmentations)
        return build_detection_train_loader(cfg, mapper=custom_mapper)

Config構築

あとは前の記事と同じです。今回は時間かかるので結果の検証までしません。

cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)
cfg.DATASETS.TRAIN = ("fruits_nuts",)
cfg.DATASETS.TEST = ()  # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.02
cfg.SOLVER.MAX_ITER = (
    300
)  # 300 iterations seems good enough, but you can certainly train longer
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = (
    128
)  # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3  # 3 classes (data, fig, hazelnut)

# defaultだとcfgのDEVICE=cudaになっているので、cudaない場合はcpuに変更
cfg.MODEL.DEVICE = "cpu"

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = AddAugmentationsTrainer(cfg)

データの確認

data loaderを作成する際のログでもちゃんと適用できているのが確認できます。
水増し後のデータを確認してみます。

# create data loader
# code from https://www.kaggle.com/julienbeaulieu/detectron2-wheat-detection-eda-training-eval
train_data_loader = trainer.build_train_loader(cfg)
data_iter = iter(train_data_loader)

f:id:yamayou_1:20210228182403p:plain

# visualization
from detectron2.data import detection_utils as utils

for batch in data_iter:
    # batchの画像数に応じて変更必要あり
    for i, per_image in enumerate(batch[:2]):

        # Pytorch tensor is in (C, H, W) format
        img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
       img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)

        visualizer = Visualizer(img, metadata=fruits_nuts_metadata, scale=0.5)

        target_fields = per_image["instances"].get_fields()
        labels = None
        vis = visualizer.overlay_instances(
            labels=labels,
            boxes=target_fields.get("gt_boxes", None),
            masks=target_fields.get("gt_masks", None),
            keypoints=target_fields.get("gt_keypoints", None),
        )
        cv2_imshow(vis.get_image()[:, :, ::-1])

f:id:yamayou_1:20210228182731p:plain

参考URL