Instance Segmentation with YOLOv7 in Python

Introduction

Object detection is a large field in computer vision, and one of the more important applications of computer vision "in the wild". Instance segmentation grew out of it: the model is tasked with predicting not only the label and bounding box of an object, but also the "area" it covers - classifying each pixel that belongs to that object.

Semantic Segmentation assigns every pixel in an image a semantic label (car, pavement, building). Instance Segmentation classifies the pixels of each detected object individually, so Car1 is differentiated from Car2.
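To make the difference concrete, here's a minimal sketch on a made-up 4x6 "image" with two cars on pavement. The grids, class IDs and helper names are all hypothetical and only illustrate the two kinds of output:

```python
# Toy 4x6 "image": two cars (class 1) on pavement (class 0).
# All values are made up for illustration.
semantic_map = [
    [0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 1, 1],
    [1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 0],
]

# Instance map: 0 is background, 1 and 2 are two separate car instances.
instance_map = [
    [0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 2, 2],
    [1, 1, 0, 0, 2, 2],
    [0, 0, 0, 0, 0, 0],
]

CLASS_NAMES = {0: "pavement", 1: "car"}


def pixels_per_class(sem_map):
    """Count pixels per semantic class; both cars end up in one bucket."""
    counts = {}
    for row in sem_map:
        for class_id in row:
            name = CLASS_NAMES[class_id]
            counts[name] = counts.get(name, 0) + 1
    return counts


def pixels_per_instance(inst_map):
    """Count pixels per object instance; each car gets its own bucket."""
    counts = {}
    for row in inst_map:
        for instance_id in row:
            if instance_id != 0:
                counts[instance_id] = counts.get(instance_id, 0) + 1
    return counts


print(pixels_per_class(semantic_map))
print(pixels_per_instance(instance_map))
```

The semantic map can only tell us that 8 pixels are "car", while the instance map separates them into two objects of 4 pixels each.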

Conceptually - they're similar, but instance segmentation combines semantic segmentation and object detection. Thankfully, object detection, semantic segmentation and, by extension, instance segmentation can be done with a common back-end and different network heads, since the tasks are conceptually similar and can share the same learned representations.

Object detection, semantic segmentation, instance segmentation and key point detection aren't as standardized as image classification, mainly because most of the new developments are typically done by individual researchers, maintainers and developers, rather than large libraries and frameworks. It's difficult to package the necessary utility scripts in a framework like TensorFlow or PyTorch and maintain the API guidelines that guided the development so far.

Fortunately for the masses - Ultralytics has developed a simple, very powerful and beautiful object detection API around their YOLOv5 which has been extended by other research and development teams into newer versions, such as YOLOv7.

In this short guide, we'll be performing Instance Segmentation in Python, with state-of-the-art YOLOv7.

YOLO and Instance Segmentation

YOLO (You Only Look Once) is a methodology, as well as a family of models built for object detection. Since its inception in 2015, YOLOv1, YOLOv2 (YOLO9000) and YOLOv3 have been proposed by the same author(s) - and the deep learning community continued with open-sourced advancements in the following years.

Ultralytics' YOLOv5 is a massive repository, and the first production-level implementation of YOLO in PyTorch, which has seen major usage in the industry. The PyTorch implementation made it more accessible than earlier versions, which were usually implemented in C++, but the main reason it became so popular is the beautifully simple and powerful API built around it, which allows anyone who can run a few lines of Python code to build object detectors.

YOLOv5 has become such a staple that most repositories that aim to advance the YOLO method use it as a basis and offer a similar API inherited from Ultralytics. YOLOR (You Only Learn One Representation) did exactly this, and YOLOv7 was built on top of YOLOR by the same authors.

YOLOv7 is the first YOLO model that ships with new model heads, allowing for key points, instance segmentation and object detection, which was a very sensible addition. Hopefully, going forward, we'll see an increasing number of YOLO-based models that offer similar capabilities out of the box.

This makes instance segmentation and key point detection faster to perform than ever before, with a simpler architecture than two-stage detectors.

YOLOv7 was released alongside a paper named "YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors", and the source code is available on GitHub.

The model itself was created through architectural changes, as well as optimizing aspects of training, dubbed "bag-of-freebies", which increased accuracy without increasing inference cost.

Instance Segmentation with YOLOv7

A standard library used for instance segmentation, object detection and key point estimation in Python is Detectron2, built by Meta AI.

The library offers various convenience methods and classes to help visualize results beautifully, but its default instance segmentation model is a Mask R-CNN. YOLO models are generally reported to run considerably faster than R-CNN-based models at competitive accuracy. The YOLOv7 repository is Detectron2-compatible and is compliant with its API and visualization tools, making it easier to run fast, accurate instance segmentation without having to learn a new API. You can, in effect, swap out the Mask R-CNN model and replace it with YOLOv7.

Advice: If you'd like to read more about Detectron2 - read our "Object Detection and Instance Segmentation in Python with Detectron2"!

Installing Dependencies - YOLOv7 and Detectron2

Let's first go ahead and install the dependencies. We'll clone the GitHub repo for the YOLOv7 project, and install the latest Detectron2 version via pip:

! git clone -b mask https://github.com/WongKinYiu/yolov7.git
! pip install pyyaml==5.1
! pip install 'git+https://github.com/facebookresearch/detectron2.git'

Detectron2 requires pyyaml as well. To ensure compatibility, you'll also want to specify the running torch version:

! pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
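If you'd like to confirm which versions the runtime actually picked up after these installs, a small check along these lines can help. It's only a sketch: the `installed_version()` helper is a hypothetical name, and the package list simply mirrors the packages installed above:

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if it's missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Distribution names as used by pip
for name in ("torch", "torchvision", "detectron2", "PyYAML"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

In a notebook, you may need to restart the runtime after changing the torch version before the new version is picked up.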

The main branch of YOLOv7 doesn't support instance segmentation, as it has a dependency on a third-party project. However, the mask branch was made exactly for this support, which is why we cloned the mask branch of the project with the -b mask flag. Finally, you'll want to download the pre-trained weights for the instance segmentation model either manually or with:

%cd yolov7
! curl -L https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-mask.pt -o yolov7-mask.pt

We've first moved into the yolov7 directory (the downloaded directory containing the project) and then downloaded the weights file there. With that - our dependencies are set up! Let's import the packages and classes we'll be using:

import matplotlib.pyplot as plt
import torch
import cv2
import yaml
from torchvision import transforms
import numpy as np

from utils.datasets import letterbox
from utils.general import non_max_suppression_mask_conf

from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.layers import paste_masks_in_image

Instance Segmentation Inference with YOLOv7

Let's first take a look at the image we'll be segmenting:

street_img = cv2.imread('../street.png')
street_img = cv2.cvtColor(street_img, cv2.COLOR_BGR2RGB)

fig = plt.figure(figsize=(12, 6))
plt.imshow(street_img)

It's a screenshot from the live view of Google Maps! Since the model isn't pre-trained on many classes, we'll likely only see segmentation masks for classes like 'person', 'car', etc. without "fine-grained" classes like 'traffic light'.

We can now get to loading the model and preparing it for inference. The hyp.scratch.mask.yaml file contains configurations for hyperparameters, so we'll initially load it in, check for the active device (GPU or CPU), and load the model from the weights file we just downloaded:

with open('data/hyp.scratch.mask.yaml') as f:
    hyp = yaml.load(f, Loader=yaml.FullLoader)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_model():
    model = torch.load('yolov7-mask.pt', map_location=device)['model']
    # Put in inference mode
    model.eval()

    if torch.cuda.is_available():
        # half() converts the model's weights to float16,
        # which significantly lowers inference time on the GPU
        model.half().to(device)
    return model

model = load_model()
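Half precision trades some numeric precision for speed and memory. As a rough illustration of what float16 does to individual values, here's a sketch that only uses Python's standard library (the `to_half()` helper is a hypothetical name, not part of PyTorch):

```python
import struct


def to_half(value):
    """Round a Python float to the nearest IEEE 754 float16 value."""
    # The "e" format character packs a number as a 16-bit half-precision float
    return struct.unpack("e", struct.pack("e", value))[0]


for value in (1.0, 0.1, 1 / 3):
    print(value, "->", to_half(value))
```

Values like 1.0 survive the round trip exactly, while 0.1 picks up a small rounding error. For inference, this loss is usually acceptable.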

Next, let's create a helper function to run inference! We'll want it to read an image, resize and pad it to the expected input size, apply transforms, batch it and pass it into the model:

def run_inference(path):
    image = cv2.imread(path) # shape: (height, width, 3)
    # Resize and pad image so that both sides are multiples of the stride
    image = letterbox(image, 640, stride=64, auto=True)[0]
    # Apply transforms
    image = transforms.ToTensor()(image) # torch.Size([3, height, width])
    image = image.to(device)
    # Match tensor type with the model (float16 on the GPU, float32 on the CPU)
    if torch.cuda.is_available():
        image = image.half()
    # Turn image into batch
    image = image.unsqueeze(0) # torch.Size([1, 3, height, width])
    # Gradients aren't needed for inference
    with torch.no_grad():
        output = model(image)
    return output, image

output, image = run_inference('../street.png')
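The resize-and-pad step is worth a closer look. The sketch below only computes the output shape that a YOLO-style letterbox would typically produce. It assumes the usual logic (scale the longer side to the target size, then pad each side up to the next multiple of the stride) and isn't the repository's actual implementation. `letterbox_shape()` is a hypothetical helper:

```python
def letterbox_shape(shape, new_size=640, stride=64, auto=True):
    """Estimate the (height, width) of an image after a YOLO-style letterbox.

    shape: (height, width) of the input image.
    Assumes aspect-ratio-preserving scaling followed by padding. With
    auto=True, only the minimal padding needed to reach a multiple of
    `stride` is added. Otherwise the image is padded to new_size x new_size.
    """
    height, width = shape
    ratio = min(new_size / height, new_size / width)
    resized_w, resized_h = round(width * ratio), round(height * ratio)
    pad_w, pad_h = new_size - resized_w, new_size - resized_h
    if auto:
        pad_w, pad_h = pad_w % stride, pad_h % stride
    return (resized_h + pad_h, resized_w + pad_w), ratio


padded_shape, ratio = letterbox_shape((1080, 1920))
print(padded_shape, round(ratio, 3))
```

Under these assumptions, a 1080x1920 frame is scaled to 360x640 and padded to 384x640, so that both sides are divisible by 64.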

The function returns the output of the model, as well as the image itself (loaded, padded and otherwise processed). The output is a dictionary:

output.keys()
# dict_keys(['mask_iou', 'test', 'attn', 'bbox_and_cls', 'bases', 'sem'])

The predictions that the model made are raw - we'll need to pass them through non_max_suppression_mask_conf(), and utilize the ROIPooler from Detectron2.

Note: "ROI Pooling" is short for "Region of Interest Pooling" and is used to extract small feature maps for object detection and segmentation tasks, in regions that may contain objects.
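To give some intuition for the conf_thres and iou_thres arguments used in the suppression step, here's a minimal, class-agnostic sketch of greedy non-max suppression in plain Python. It isn't the repository's non_max_suppression_mask_conf(), which also handles masks and classes. The `iou()` and `nms()` helpers and the sample boxes are hypothetical:

```python
def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


def nms(detections, conf_thres=0.25, iou_thres=0.65):
    """Greedy NMS over (x1, y1, x2, y2, conf) tuples."""
    # Drop low-confidence detections, then visit the rest from best to worst
    candidates = sorted(
        (det for det in detections if det[4] >= conf_thres),
        key=lambda det: det[4],
        reverse=True,
    )
    kept = []
    for det in candidates:
        # Keep a box only if it doesn't overlap too much with a better one
        if all(iou(det[:4], other[:4]) <= iou_thres for other in kept):
            kept.append(det)
    return kept


detections = [
    (10, 10, 110, 110, 0.90),    # strong detection
    (12, 12, 112, 112, 0.80),    # near-duplicate of the first box
    (200, 200, 300, 300, 0.70),  # separate object
    (0, 0, 50, 50, 0.10),        # below the confidence threshold
]
print(nms(detections))
```

Only the first and third boxes survive: the second overlaps the first with an IoU above 0.65, and the fourth falls below the confidence threshold.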

inf_out = output['test']
attn = output['attn']
bases = output['bases']
sem_output = output['sem']

bases = torch.cat([bases, sem_output], dim=1)
nb, _, height, width = image.shape
names = model.names
pooler_scale = model.pooler_scale

pooler = ROIPooler(output_size=hyp['mask_resolution'], 
                   scales=(pooler_scale,), 
                   sampling_ratio=1, 
                   pooler_type='ROIAlignV2', 
                   canonical_level=2)
                   
# output, output_mask, output_mask_score, output_ac, output_ab
output, output_mask, _, _, _ = non_max_suppression_mask_conf(inf_out, 
                                                             attn, 
                                                             bases, 
                                                             pooler, 
                                                             hyp, 
                                                             conf_thres=0.25, 
                                                             iou_thres=0.65, 
                                                             merge=False, 
                                                             mask_iou=None)                 

Here - we've obtained the predictions for objects and their labels in output and the masks that should cover those objects in output_mask:

output[0].shape # torch.Size([30, 6])
output_mask[0].shape # torch.Size([30, 3136])
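Each row of output[0] holds six values. Judging by how the columns are sliced in this guide, they're the box corners, the confidence and the class index. As a sketch of how you might summarize such rows, here's a per-class count on made-up data. The sample rows, the truncated class list and the `count_by_class()` helper are all hypothetical:

```python
from collections import Counter

# Hypothetical rows in the layout: x1, y1, x2, y2, confidence, class index
detections = [
    (34.0, 120.0, 98.0, 310.0, 0.91, 0),
    (210.0, 140.0, 400.0, 260.0, 0.88, 2),
    (420.0, 150.0, 600.0, 270.0, 0.64, 2),
    (15.0, 30.0, 40.0, 70.0, 0.12, 0),
]

# Illustrative, truncated class list. The real one comes from model.names
class_names = ["person", "bicycle", "car"]


def count_by_class(rows, names, min_conf=0.25):
    """Count detections per class label, skipping low-confidence rows."""
    return Counter(names[int(row[5])] for row in rows if row[4] >= min_conf)


print(count_by_class(detections, class_names))
```

With real predictions, you'd first move the tensor to the CPU and convert it to a list or NumPy array before iterating over it.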

The model found 30 instances in the image, each with a label associated with it. Let's create boxes for our instances with the help of Detectron2's Boxes class and turn the pred_masks into full-size boolean masks that we can apply over the original image:

pred, pred_masks = output[0], output_mask[0]
base = bases[0]
bboxes = Boxes(pred[:, :4])

original_pred_masks = pred_masks.view(-1, 
                                      hyp['mask_resolution'], 
                                      hyp['mask_resolution'])

pred_masks = retry_if_cuda_oom(paste_masks_in_image)(original_pred_masks, 
                                                     bboxes, 
                                                     (height, width), 
                                                     threshold=0.5)
                                                     
# Detach Tensors from the device, send to the CPU and turn into NumPy arrays
pred_masks_np = pred_masks.detach().cpu().numpy()
pred_cls = pred[:, 5].detach().cpu().numpy()
pred_conf = pred[:, 4].detach().cpu().numpy()
nimg = image[0].permute(1, 2, 0) * 255
nimg = nimg.cpu().numpy().astype(np.uint8)
nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
# np.int was removed in newer NumPy versions - the built-in int works everywhere
nbboxes = bboxes.tensor.detach().cpu().numpy().astype(int)

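The original_pred_masks tensor holds the low-resolution masks as predicted by the model, one 56x56 mask (3136 values) per instance, before they're pasted into the full-size image: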

original_pred_masks.shape # torch.Size([30, 56, 56])
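To illustrate what pasting a low-resolution mask into its bounding box means, here's a simplified sketch in plain Python. Note that it uses nearest-neighbor sampling for brevity, which isn't necessarily how Detectron2's paste_masks_in_image() interpolates. The `paste_mask()` helper and the toy values are hypothetical:

```python
def paste_mask(mask, box, image_size, threshold=0.5):
    """Paste a low-res score mask into a full-size boolean mask.

    mask: 2D list of scores in [0, 1], e.g. 56x56 in this guide.
    box: integer (x1, y1, x2, y2) in image coordinates.
    image_size: (height, width) of the target image.
    Uses nearest-neighbor sampling, a simplification for illustration.
    """
    height, width = image_size
    x1, y1, x2, y2 = box
    mask_h, mask_w = len(mask), len(mask[0])
    full = [[False] * width for _ in range(height)]
    for y in range(max(0, y1), min(height, y2)):
        # Map the image row back to a row of the low-res mask
        src_y = min(mask_h - 1, (y - y1) * mask_h // (y2 - y1))
        for x in range(max(0, x1), min(width, x2)):
            src_x = min(mask_w - 1, (x - x1) * mask_w // (x2 - x1))
            full[y][x] = mask[src_y][src_x] >= threshold
    return full


# A toy 2x2 mask pasted into a 4x4 box inside a 6x6 image
tiny_mask = [[0.9, 0.1],
             [0.2, 0.8]]
full_mask = paste_mask(tiny_mask, (1, 1, 5, 5), (6, 6))
for row in full_mask:
    print("".join("#" if value else "." for value in row))
```

Each mask cell is stretched to cover its share of the box and everything outside the box stays False, so the result can be used directly as a boolean index into the image.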

And finally, we can plot the results with:

def plot_results(original_image, pred_img, pred_masks_np, nbboxes, pred_cls, pred_conf, plot_labels=True):
    # Work on a copy so the caller's image isn't modified in-place
    pred_img = pred_img.copy()

    for one_mask, bbox, cls, conf in zip(pred_masks_np, nbboxes, pred_cls, pred_conf):
        if conf < 0.25:
            continue
        color = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]
        # Plain Python ints keep OpenCV's point arguments unambiguous
        x1, y1, x2, y2 = (int(v) for v in bbox)

        # Apply mask over image in color, at 50% opacity
        pred_img[one_mask] = pred_img[one_mask] * 0.5 + np.array(color, dtype=np.uint8) * 0.5
        # Draw a rectangle around the found object
        pred_img = cv2.rectangle(pred_img, (x1, y1), (x2, y2), color, 2)

        if plot_labels:
            label = '%s %.3f' % (names[int(cls)], conf)
            # Measure the label at the same font scale it's drawn with
            t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
            c2 = x1 + t_size[0], y1 - t_size[1] - 3
            pred_img = cv2.rectangle(pred_img, (x1, y1), c2, color, -1, cv2.LINE_AA)
            pred_img = cv2.putText(pred_img, label, (x1, y1 - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)

    fig, ax = plt.subplots(1, 2, figsize=(pred_img.shape[0]/10, pred_img.shape[1]/10), dpi=150)

    # Use the image passed in (a 1 x C x H x W tensor) rather than a global variable
    original_image = np.moveaxis(original_image.cpu().numpy().squeeze(), 0, 2).astype('float32')
    original_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR)

    ax[0].imshow(original_image)
    ax[0].axis("off")
    ax[1].imshow(pred_img)
    ax[1].axis("off")

The image is copied so we don't apply transformations to it in-place, but on a copy. For each pixel covered by a predicted mask, we apply a color with an opacity of 0.5, and for each object, we draw a rectangle with cv2.rectangle() that encompasses it, based on its bounding box (bbox). If you wish to plot labels, for which there might be significant overlap, there's a plot_labels flag in the plot_results() function signature. Let's try plotting the image we started working with earlier, with and without labels:

%matplotlib inline
plot_results(image, nimg, pred_masks_np, nbboxes, pred_cls, pred_conf, plot_labels=False)
%matplotlib inline
plot_results(image, nimg, pred_masks_np, nbboxes, pred_cls, pred_conf, plot_labels=True)

We've plotted both images - the original and the segmented image in one plot. For higher resolution, adjust the dpi (dots per inch) argument in the subplots() call, and plot just the image with the predicted segmentation map/labels to occupy the figure in its entirety.
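The mask overlay in plot_results() is a plain alpha blend between the image and a random color. Here's a per-pixel sketch of that arithmetic in plain Python. The `blend()` helper is a hypothetical name, and it truncates the result the same way assigning floats into a uint8 array does:

```python
def blend(pixel, color, alpha=0.5):
    """Alpha-blend an overlay color onto an RGB pixel, truncating to integers."""
    return tuple(int(p * (1 - alpha) + c * alpha) for p, c in zip(pixel, color))


pixel = (200, 100, 50)   # made-up image pixel
overlay = (0, 255, 0)    # made-up mask color
print(blend(pixel, overlay))             # (100, 177, 25)
print(blend(pixel, overlay, alpha=0.8))  # stronger tint
```

Raising alpha makes the masks more opaque, while lowering it lets more of the underlying image show through.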

Author: David Landup

© 2013-2024 Stack Abuse. All rights reserved.