훈련된 Mask R-CNN으로 예측하기 — 파이썬 치트코드

2) Mask R-CNN real inference

훈련된 Mask R-CNN 모델로 예측

  • 훈련된 모델이 저장되어 있다면, 이 모델을 불러와 실제 테스트 이미지의 Segmentation을 예측하는 과정
In [1]:
import os 
import sys
import random
import math
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd 
import glob
from sklearn.model_selection import KFold
from PIL import Image
import os.path
import glob

import skimage

from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

import tensorflow as tf

import traceback
Using TensorFlow backend.
In [ ]:
class ShapesConfig(Config):
    """Training configuration for the "shapes" Mask R-CNN run.

    Overrides selected defaults of mrcnn's ``Config``. The values are the
    ones used to train the saved checkpoints, so keep them unchanged.
    """

    # Run name; presumably used in checkpoint naming (cf. the
    # "shapes.../mask_rcnn_shapes_*.h5" weights under ./model/).
    NAME = "shapes"

    # Hardware / batching: one GPU, 8 images per GPU.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Class count. In matterport's Config this includes the background class
    # -- presumably background + 350 object classes; TODO confirm.
    NUM_CLASSES = 351

    # Image sizing: min and max side are both 256.
    IMAGE_MAX_DIM = 256
    IMAGE_MIN_DIM = 256

    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    TRAIN_ROIS_PER_IMAGE = 32

    # Short epochs: 100 training steps, 5 validation steps.
    VALIDATION_STEPS = 5
    STEPS_PER_EPOCH = 100


config = ShapesConfig()
config.display()  # print the resolved configuration as a sanity check
In [3]:
class InferenceConfig(ShapesConfig):
    """Configuration used when the trained model runs in inference mode.

    Keeps one GPU with 8 images per GPU, so each detect() call is expected
    to receive a batch of GPU_COUNT * IMAGES_PER_GPU = 8 images.
    """

    IMAGES_PER_GPU = 8
    GPU_COUNT = 1


inference_config = InferenceConfig()

# Root directory of the training runs; the checkpoint loaded later lives
# under it.
MODEL_DIR = "./model/"

추론 모드로 모델 로드

  • 모델 가중치는 보통 훈련 중 epoch마다 자동으로 저장된다 (예: ./model/shapes20190914T2042/mask_rcnn_shapes_0002.h5).
In [ ]:
# Build the Mask R-CNN graph in inference mode. The trained weights are
# loaded in a separate step.
model_inference = modellib.MaskRCNN(
    mode="inference",
    config=inference_config,
    model_dir=MODEL_DIR,
)
In [ ]:
# Checkpoint to evaluate. Judging by the names, this is presumably epoch 2 of
# the run started 2019-09-14 20:42 -- confirm before reuse.
weights_path = "./model/shapes20190914T2042/mask_rcnn_shapes_0002.h5"
model_inference.load_weights(weights_path, by_name=True)
In [6]:
# os.listdir() returns entries in arbitrary order. The predictions are pickled
# without file names, so sort the list to make the image order reproducible
# when the results are matched back to files later.
tests = sorted(os.listdir('../test'))
results = []

예측

  • 8개씩 나눠서 예측
  • 결과 list의 각 원소가 한 번의 detect 호출(이미지 8개)에 대한 예측 묶음이므로, 나중에 flatten이 필요
In [ ]:
def get_image_array(num):
    """Return the ``num``-th test image as a numpy RGB array.

    ``num`` indexes into the module-level ``tests`` file list. The image is
    converted to 3-channel RGB because mrcnn's preprocessing expects 3
    channels; grayscale, RGBA or palette files would otherwise fail. The file
    handle is closed once the pixels are decoded.
    """
    with Image.open(os.path.join('../test', tests[num])) as img:
        return np.array(img.convert('RGB'))


# matterport's MaskRCNN.detect() requires exactly GPU_COUNT * IMAGES_PER_GPU
# images per call (8 here). Iterate over the real number of test files
# instead of a hard-coded 100000. Pad the final partial batch by repeating
# its last image, then drop the padded detections, so that ``results`` holds
# exactly one detection per test image, in ``tests`` order.
batch_size = inference_config.GPU_COUNT * inference_config.IMAGES_PER_GPU
for start in tqdm(range(0, len(tests), batch_size)):
    idx = list(range(start, min(start + batch_size, len(tests))))
    n_real = len(idx)
    idx += [idx[-1]] * (batch_size - n_real)  # pad only the last batch

    batch = [get_image_array(k) for k in idx]
    results.append(model_inference.detect(batch, verbose=1)[:n_real])
In [8]:
import pickle

저장

  • 예측 결과를 pickle 파일로 저장해 두고, 이후 competition용 submission 파일을 만드는 데 사용
In [9]:
# Serialize the prediction results to disk for building the submission.
# ``results`` is a list with one sub-list of detections per detect() call,
# in the same order as ``tests``. ``with`` guarantees the file is closed
# even if pickling raises.
with open('result_002_head_weight.pkl', 'wb') as output:
    pickle.dump(results, output)

댓글 남기기