Поэкспериментируйте с проектом facebook maskrcnn-benchmark 2

глубокое обучение PyTorch
Поэкспериментируйте с проектом facebook maskrcnn-benchmark 2

maskrcnn-benchmarkЭто проект алгоритма эталонного тестирования Facebook с открытым исходным кодом, который включаетобнаружить,сегментацияиключевые точки человеческого телаалгоритм.

В этой серии две статьи:


тренироваться

Используйте maskrcnn-benchmark для обучения модели, вы можетеСсылаться на.

набор данных:

Выберите шаблон обучения:e2e_mask_rcnn_R_50_FPN_1x.yaml,в:

WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"  # 预训练权重
DATASETS:  # 数据集
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
MAX_ITER: 90000  # 最大训练轮次

Положение установки других параметров:maskrcnn_benchmark/config/defaults.py

как:

  • _C.SOLVER.CHECKPOINT_PERIOD = 2500, сохранить раунд;
  • _C.SOLVER.IMS_PER_BATCH = 16, обученныйbatch_size;
  • _C.OUTPUT_DIR = "./models", выходной путь модели;

Укажите количество графических процессоров:

export NGPUS=4

Обучите модель:

python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"

nohup python -u -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" &

Выходная модель находится в./models, последняя модельmodel_0090000.pth.


контрольная работа

файл конфигурацииe2e_mask_rcnn_R_50_FPN_1x.my.yamlИзмените ВЕС на путь к модели, например ВЕС: "model_0090000.pth".

Логика теста следующая:

def main():
    img_path = os.path.join(DATA_DIR, 'aoa-mina.jpeg')

    img = cv2.imread(img_path)
    print('[Info] img size: {}'.format(img.shape))
    
    config_file = "../configs/e2e_mask_rcnn_R_50_FPN_1x.my.yaml"

    cfg.merge_from_file(config_file)  # 设置配置文件
    cfg.merge_from_list(["MODEL.MASK_ON", True])
    cfg.merge_from_list(["MODEL.DEVICE", "cpu"])  # 指定为CPU

    coco_demo = COCODemo(  # 创建模型文件
        cfg,
        # show_mask_heatmaps=True,
        min_image_size=800,
        confidence_threshold=0.7,
    )

    predictions = coco_demo.compute_prediction(img)
    top_predictions = coco_demo.select_top_predictions(predictions)
    show_mask(img, top_predictions, coco_demo)

    print('执行完成!')

show_maskЧтобы отобразить эффект сегментации, нарисуйте полигональную рамку на основе изображения matplotlib:

def show_mask(img, predictions, coco_demo):
    """
    显示分割效果
    :param img: numpy的图像 
    :param predictions: 分割结果
    :param coco_demo: 函数集
    :return: 显示分割效果
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pylab.rcParams['figure.figsize'] = (8.0, 10.0)  # 图片尺寸
    plt.imshow(img)  # 需要提前填充图像
    # plt.show()

    extra_fields = predictions.extra_fields
    masks = extra_fields['mask']
    labels = extra_fields['labels']
    name_list = [coco_demo.CATEGORIES[l] for l in labels]

    seg_list = []
    for mask in masks:
        mask = torch.squeeze(mask)
        segmentation = binary_mask_to_polygon(mask, tolerance=2)[0]
        seg_list.append(segmentation)

    ax = plt.gca()
    ax.set_autoscale_on(False)

    polygons, color = [], []

    np.random.seed(37)

    for name, seg in zip(name_list, seg_list):
        c = (np.random.random((1, 3)) * 0.8 + 0.2).tolist()[0]

        poly = np.array(seg).reshape((int(len(seg) / 2), 2))
        c_x, c_y = get_center_of_polygon(poly)  # 计算多边形的中心点

        # 0~26是大类别, 其余是小类别 同时 每个标签只绘制一次
        tc = c - np.array([0.5, 0.5, 0.5])  # 降低颜色
        tc = np.maximum(tc, 0.0)  # 最小值0

        plt.text(c_x, c_y, name, ha='left', wrap=True, color=tc,
                 bbox=dict(facecolor='white', alpha=0.5))  # 绘制标签

        polygons.append(pylab.Polygon(poly))  # 绘制多边形
        color.append(c)

    p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)  # 添加多边形
    ax.add_collection(p)
    p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)  # 添加多边形的框
    ax.add_collection(p)

    plt.axis('off')
    ax.get_xaxis().set_visible(False)  # this removes the ticks and numbers for x axis
    ax.get_yaxis().set_visible(False)  # this removes the ticks and numbers for y axis

    out_folder = os.path.join(ROOT_DIR, 'demo', 'out')
    mkdir_if_not_exist(out_folder)
    out_file = os.path.join(out_folder, 'test.png'.format())
    plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=200)

    plt.close()  # 避免所有图像绘制在一起

Эффект следующий:

Test


OK, that's all!