
Semantic and Instance Segmentation

CMPUT 328 - Assignment 6 Study Guide

CMPUT 328 - Visual Recognition - Assignment 6

TABLE OF CONTENTS

  1. Introduction to Image Segmentation
  2. Semantic Segmentation
  3. Fully Convolutional Networks (FCN)
  4. Transposed Convolution (Upsampling)
  5. U-Net Architecture
  6. Instance Segmentation
  7. Mask R-CNN
  8. Training Strategies
  9. Loss Functions
  10. Evaluation Metrics
  11. Data Preparation and Labels
  12. Downsampling and Upsampling
  13. Skip Connections
  14. State-of-the-Art Models
  15. Implementation Guide
  16. Common Pitfalls and Solutions

1. Introduction to Image Segmentation

What is Image Segmentation?

Image segmentation is the task of partitioning an image into multiple segments or regions, where each pixel is assigned to a specific class or instance.

Why Image Segmentation?

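Applications: autonomous driving (labeling road, cars, and pedestrians in street scenes), medical imaging (outlining organs and lesions), and photo editing (separating foreground from background).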

Types of Segmentation

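1. Semantic Segmentation: Assign a class label to every pixel. Objects of the same class (e.g., two cars) share one label.

2. Instance Segmentation: Detect each object and give it its own pixel mask, so car₁ and car₂ are separated.

3. Panoptic Segmentation: Combine both. Every pixel gets a class label, and pixels of countable objects ("things" such as cars or people) also get an instance ID, while amorphous regions ("stuff" such as road or sky) do not.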


2. Semantic Segmentation

Definition

Semantic segmentation is classifying each pixel of an image into a category or class.

How It Works

Input: RGB Image (H × W × 3)
Output: Label Map (H × W)

Each pixel in the output corresponds to a class label.

Example

Input Image: Street scene
Classes: road, car, pedestrian, building, sky, tree
Output: Each pixel labeled with one of these classes

Challenges

  1. High computational cost: Processing every pixel
  2. Context understanding: Need both local and global information
  3. Boundary precision: Accurate segmentation at object edges
  4. Scale variation: Objects appear at different sizes
  5. Occlusion: Objects partially hidden by others

Differences from Classification

Aspect | Classification | Semantic Segmentation
Input | Image | Image
Output | Single label | Label per pixel
Spatial info | Lost (after pooling) | Preserved
Complexity | Lower | Higher

3. Fully Convolutional Networks (FCN)

Motivation

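Traditional CNNs for classification flatten the final feature map and pass it through fully connected layers. This produces a single vector of class scores for the whole image and discards the spatial layout.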

Problem: We need spatial output (same size as input)!

FCN Architecture

Key idea: Replace fully connected layers with convolutional layers

Input Image (H × W × 3)
    ↓
[Conv + Pool] × N  ← Downsampling (encoder)
    ↓
[Conv Transpose] × M  ← Upsampling (decoder)
    ↓
Output Map (H × W × num_classes)

Why Use Fully Convolutional Architecture?

Advantages:

  1. Accepts any input size: No fixed-size requirement
  2. Preserves spatial information: Output has spatial structure
  3. Efficient: Shares computation across overlapping regions
  4. End-to-end training: Learn features and segmentation together

Converting FC to Conv

# Traditional classification network
x = x.view(x.size(0), -1)  # Flatten: [batch, C*H*W]
x = self.fc(x)              # FC layer

# Fully convolutional network
# No flattening!
x = self.conv(x)  # [batch, channels, h, w] preserved
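A small NumPy sketch makes the FC-to-conv equivalence concrete. The sizes and the helper `conv_valid` are chosen for illustration and are not taken from the assignment. Reshaping the FC weight matrix into K kernels of shape C×7×7 gives a convolution that reproduces the FC output on a 7×7 map. On a larger map, the same kernels slide and produce a grid of class scores instead of one vector:

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, K = 8, 7, 7, 10  # channels, feature-map height/width, number of classes

features = rng.standard_normal((C, H, W))
fc_weight = rng.standard_normal((K, C * H * W))  # FC layer weights: [K, C*H*W]
fc_bias = rng.standard_normal(K)

# Classification head: flatten, then a matrix-vector product -> [K]
fc_out = fc_weight @ features.reshape(-1) + fc_bias

# The same weights reinterpreted as K convolution kernels of shape [C, H, W]
conv_kernels = fc_weight.reshape(K, C, H, W)


def conv_valid(x, kernels, bias):
    """Naive 'valid' convolution (cross-correlation), stride 1.

    x: [C, Hin, Win], kernels: [K, C, kh, kw] -> [K, Hin-kh+1, Win-kw+1]
    """
    k, _, kh, kw = kernels.shape
    _, hin, win = x.shape
    out = np.empty((k, hin - kh + 1, win - kw + 1))
    for i in range(hin - kh + 1):
        for j in range(win - kw + 1):
            patch = x[:, i:i + kh, j:j + kw]
            out[:, i, j] = np.einsum("kchw,chw->k", kernels, patch) + bias
    return out


# Same-size input: the conv produces a 1×1 map equal to the FC output
conv_out = conv_valid(features, conv_kernels, fc_bias)    # [K, 1, 1]

# Larger input (9×9): the same kernels now yield a 3×3 grid of class scores
big_input = rng.standard_normal((C, 9, 9))
score_map = conv_valid(big_input, conv_kernels, fc_bias)  # [K, 3, 3]
```

This is why a fully convolutional network accepts inputs of any size. Each position of the 3×3 score map is the original classifier applied to one 7×7 window.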

Output Size Calculation

After multiple conv and pool operations:

Input: 32×32
After Conv (stride=1, padding=1): 32×32
After MaxPool (2): 16×16
After Conv (stride=1, padding=1): 16×16
After MaxPool (2): 8×8

Problem: Output is smaller than input! Solution: Upsampling operations


4. Transposed Convolution (Upsampling)

What is Transposed Convolution?

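A learnable upsampling layer that increases spatial resolution. It is often called "deconvolution", but it is not the mathematical inverse of convolution, so "transposed convolution" is the more accurate name.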

How It Works

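Regular convolution (stride > 1): several input pixels contribute to each output pixel, so the spatial size shrinks.

Transposed convolution (stride > 1): each input pixel spreads a scaled copy of the kernel over several output pixels, so the spatial size grows.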

Mathematical Intuition

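For a convolution and a transposed convolution with the same kernel size, stride, and padding:

conv(kernel, stride):           A × A → B × B
conv_transpose(kernel, stride): B × B → A × A

Only the shape is reversed, not the values. Transposed convolution is the transpose (adjoint) of the convolution's matrix form, not its inverse. When several input sizes map to the same B, output_padding selects which A is produced.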

Example

Input: 2×2
Kernel: 3×3
Stride: 2
Output: 5×5  (larger!)

Process:

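  1. Insert (stride - 1) zeros between neighboring input pixels
  2. Pad the border with (kernel_size - 1 - padding) zeros
  3. Apply a regular stride-1 convolution (with the kernel rotated by 180°)
  4. Result is the upsampled output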

In PyTorch

nn.ConvTranspose2d(
    in_channels=64,
    out_channels=32,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1
)

# Input: [batch, 64, 16, 16]
# Output: [batch, 32, 32, 32]

Output Size Formula

output_size = (input_size - 1) × stride - 2 × padding + kernel_size + output_padding
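The process steps and the size formula can be checked numerically. The sketch below uses single-channel square inputs and NumPy only; the function names are our own, not a PyTorch API. It implements transposed convolution in two ways:

- directly, where each input pixel adds a scaled copy of the kernel to the output;
- via the zero-insertion recipe.

Both versions agree with each other and produce the sizes the formula predicts. That includes the 2×2 → 5×5 example and the 16 → 32 configuration used with nn.ConvTranspose2d above.

```python
import numpy as np


def out_size(n, kernel, stride, padding=0, output_padding=0):
    """Output size of a transposed convolution (per spatial dimension, dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


def conv_transpose2d_scatter(x, k, stride, padding=0, output_padding=0):
    """Direct definition: each input pixel adds a scaled copy of the kernel."""
    n, kk = x.shape[0], k.shape[0]  # square input and kernel, single channel
    full = (n - 1) * stride + kk + output_padding
    out = np.zeros((full, full))
    for i in range(n):
        for j in range(n):
            out[i * stride:i * stride + kk, j * stride:j * stride + kk] += x[i, j] * k
    size = out_size(n, kk, stride, padding, output_padding)
    # `padding` trims that many rows/columns from the top/left of the full result
    return out[padding:padding + size, padding:padding + size]


def conv_transpose2d_zero_insert(x, k, stride, padding=0, output_padding=0):
    """The zero-insertion recipe (assumes padding <= kernel - 1)."""
    n, kk = x.shape[0], k.shape[0]
    # 1. Insert (stride - 1) zeros between neighboring input pixels
    dilated = np.zeros(((n - 1) * stride + 1,) * 2)
    dilated[::stride, ::stride] = x
    # 2. Pad with (kernel - 1 - padding) zeros; output_padding adds rows/cols at bottom/right
    pad = kk - 1 - padding
    padded = np.pad(dilated, ((pad, pad + output_padding), (pad, pad + output_padding)))
    # 3. Stride-1 'valid' cross-correlation with the kernel rotated by 180°
    flipped = k[::-1, ::-1]
    size = padded.shape[0] - kk + 1
    out = np.empty((size, size))
    for i in range(size):
        for j in range(size):
            out[i, j] = np.sum(padded[i:i + kk, j:j + kk] * flipped)
    return out


# Overlapping kernel copies are uneven: the center of this output receives 4 copies
# and the corners only 1, which is the source of checkerboard artifacts
demo = conv_transpose2d_scatter(np.ones((2, 2)), np.ones((3, 3)), stride=2)
```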

Transposed Conv vs Bilinear Interpolation

Aspect | Transposed Conv | Bilinear Interpolation
Learnable | Yes | No
Parameters | in_ch × out_ch × k × k weights (+ bias) | None
Output quality | Learned, can adapt to the data | Fixed interpolation
Artifacts | Checkerboard patterns possible | Smooth

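Best practice: Bilinear upsampling followed by a regular convolution is a common compromise. It stays learnable while avoiding the checkerboard artifacts of transposed convolution.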


5. U-Net Architecture

What is U-Net?

U-Net is an important architecture for semantic segmentation, especially popular in medical imaging.

Architecture Structure

       INPUT (572×572)
           ↓
    ┌──────────────┐
    │   Encoder    │ ← Contracting path (downsampling)
    │   (Down)     │
    └──────┬───────┘
           ↓
    ┌──────────────┐
    │  Bottleneck  │ ← Lowest resolution
    └──────┬───────┘
           ↓
    ┌──────────────┐
    │   Decoder    │ ← Expanding path (upsampling)
    │    (Up)      │   + Skip connections →
    └──────┬───────┘
           ↓
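      OUTPUT (388×388)

The 572 → 388 shrinkage comes from the original paper's unpadded (valid) 3×3 convolutions. With padding=1, as in the Section 15 implementation, the output has the same size as the input.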

Key Features

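1. Encoder-Decoder Structure: The encoder (contracting path) repeatedly convolves and downsamples to capture context. The decoder (expanding path) upsamples back to full resolution.

2. Skip Connections: Encoder feature maps are concatenated with decoder feature maps of the same resolution, restoring fine spatial detail.

3. Symmetric: Each decoder stage mirrors an encoder stage in resolution and channel count, which gives the "U" shape.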

U-Net Block Structure

Encoder Block:

# Each encoder block
Conv(3×3) → BatchNorm → ReLU
Conv(3×3) → BatchNorm → ReLU
MaxPool(2×2)  # Downsample

Decoder Block:

# Each decoder block
ConvTranspose(2×2)  # Upsample
Concatenate with skip connection
Conv(3×3) → BatchNorm → ReLU
Conv(3×3) → BatchNorm → ReLU

Why U-Net Works

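  1. Context + Localization: Encoder captures what, decoder determines where
  2. Skip connections: Preserve spatial information lost during downsampling
  3. Data efficiency: Works well even with small datasets (the original paper trained on only a few dozen annotated images)
  4. Data augmentation friendly: Heavy augmentation (elastic deformations in the original paper) compensates for limited labeled data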

U-Net vs FCN

Feature | U-Net | FCN
Skip connections | Yes (concatenation) | Yes (addition)
Structure | Symmetric U | Asymmetric
Best for | Medical imaging, small data | General segmentation
Parameters | Moderate | Varies

6. Instance Segmentation

What is Instance Segmentation?

Instance segmentation = Semantic segmentation + Object detection

Goal: Detect and segment each object instance separately

Difference from Semantic Segmentation

Semantic Segmentation:
- All cars labeled as "car"
- Cannot distinguish car₁ from car₂

Instance Segmentation:
- car₁, car₂, car₃ labeled separately
- Each instance has unique ID

Challenges

  1. Varying number of instances: Unknown how many objects in image
  2. Overlapping objects: Objects may occlude each other
  3. Different scales: Objects appear at different sizes
  4. Computational cost: More complex than semantic segmentation

Approaches

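1. Detect-then-Segment: Detect bounding boxes first, then predict a mask inside each box (e.g., Mask R-CNN).

2. Segment-then-Detect: Generate class-agnostic segment proposals first, then classify each segment.

3. Bottom-up: Predict per-pixel cues (e.g., embeddings or offsets to instance centers), then group pixels into instances.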


7. Mask R-CNN

Overview

Mask R-CNN extends Faster R-CNN by adding a branch for predicting segmentation masks on each Region of Interest (RoI).

Architecture

Input Image
    ↓
┌─────────────────────┐
│  Backbone (ResNet)  │ ← Feature extraction
└─────────┬───────────┘
          ↓
┌─────────────────────┐
│  Region Proposal    │ ← Propose object locations
│  Network (RPN)      │
└─────────┬───────────┘
          ↓
┌─────────────────────┐
│   RoI Align         │ ← Extract features for each proposal
└─────────┬───────────┘
          ↓
    ┌─────┴────┬──────────┐
    ↓          ↓          ↓
┌────────┐ ┌────────┐ ┌────────┐
│  Box   │ │ Class  │ │  Mask  │ ← Three parallel heads
│Regress │ │Predict │ │ Predict│
└────────┘ └────────┘ └────────┘

Key Components

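1. Backbone: A CNN (e.g., ResNet, often with a Feature Pyramid Network) that computes feature maps for the whole image.

2. Region Proposal Network (RPN): Scores anchor boxes for objectness and refines them into candidate object boxes.

3. RoI Align: Extracts a fixed-size feature map for each proposal (e.g., 7×7 for the box/class heads, 14×14 for the mask head) without rounding coordinates.

4. Three Heads: Box regression, classification, and mask prediction run in parallel on each RoI.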

Mask Head

# Mask head architecture
# For each RoI (14×14 RoI Align features in the FPN variant):
Conv(3×3) × 4        # Feature extraction
ConvTranspose(2×2)   # Upsample 14×14 → 28×28
Conv(1×1)            # num_classes masks

# Output: [num_classes, 28, 28]
# One binary mask per class

RoI Align vs RoI Pool

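RoI Pool (Faster R-CNN): Rounds the RoI boundaries and bin edges to integer feature-map coordinates, then max-pools each bin. The rounding misaligns features with the image by up to one feature-map cell, which hurts pixel-accurate masks.

RoI Align (Mask R-CNN): Keeps the real-valued coordinates, samples a few regularly spaced points in each bin with bilinear interpolation, and aggregates them (max or average).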
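The sampling step can be sketched in NumPy. The sketch assumes a single-channel feature map and coordinates already scaled to feature-map units. The helper names are illustrative; real implementations such as torchvision.ops.roi_align also handle batches, channels, and splitting each box into bins.

```python
import numpy as np


def bilinear_sample(feat, y, x):
    """Bilinearly interpolate a 2D feature map at real-valued (y, x).

    Assumes 0 <= y <= H - 1 and 0 <= x <= W - 1.
    """
    h, w = feat.shape
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = (1 - dx) * feat[y0, x0] + dx * feat[y0, x1]
    bottom = (1 - dx) * feat[y1, x0] + dx * feat[y1, x1]
    return (1 - dy) * top + dy * bottom


def roi_align_bin(feat, y_start, x_start, bin_h, bin_w, samples=2):
    """Average of samples × samples bilinear samples spread evenly over one bin."""
    values = []
    for iy in range(samples):
        for ix in range(samples):
            y = y_start + (iy + 0.5) * bin_h / samples
            x = x_start + (ix + 0.5) * bin_w / samples
            values.append(bilinear_sample(feat, y, x))
    return float(np.mean(values))


feat = np.arange(16, dtype=float).reshape(4, 4)  # feat[y, x] = 4 * y + x
exact = bilinear_sample(feat, 1.5, 2.25)         # uses the true location
quantized = feat[int(1.5), int(2.25)]            # RoI Pool-style rounding snaps to feat[1, 2]
```

On this linear ramp, the bilinear sample at (1.5, 2.25) returns 8.25, the true value of 4y + x at that point. Snapping to the grid returns 6.0 instead.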

Training Mask R-CNN

Multi-task loss:

L_total = L_cls + L_box + L_mask

Where:
- L_cls: Classification loss (cross-entropy)
- L_box: Bounding box regression loss (smooth L1)
- L_mask: Mask loss (binary cross-entropy per pixel)

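Important: For each RoI, the mask loss is computed only on the mask channel of its ground-truth class (not all classes), using a per-pixel sigmoid. Masks for different classes therefore do not compete, which decouples mask prediction from classification.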
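The per-class mask loss can be sketched for a single RoI in NumPy. The function name is hypothetical, and the 28×28 size follows the mask head above. Only the ground-truth class's channel enters the loss, and each pixel uses an independent sigmoid, so changing the other classes' logits leaves the loss unchanged:

```python
import numpy as np


def mask_rcnn_mask_loss(mask_logits, gt_class, gt_mask):
    """Average per-pixel binary cross-entropy on the ground-truth class's channel.

    mask_logits: [num_classes, 28, 28] raw scores from the mask head
    gt_class:    integer class index of the RoI's ground-truth object
    gt_mask:     [28, 28] binary target (0/1)
    """
    z = mask_logits[gt_class]  # the other classes' channels are ignored
    # Numerically stable BCE-with-logits: per-pixel sigmoid, no softmax across classes
    bce = np.maximum(z, 0) - z * gt_mask + np.log1p(np.exp(-np.abs(z)))
    return float(bce.mean())
```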

Mask R-CNN Performance

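At publication (2017), Mask R-CNN outperformed all previous single-model entries on the COCO instance segmentation, box detection, and keypoint detection tasks, including the COCO 2016 challenge winners.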


8. Training Strategies

Data Augmentation

Essential augmentations for segmentation:

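  1. Random crop and resize (one crop box for both; nearest-neighbor resizing for the mask)
i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.5, 1.0), ratio=(3/4, 4/3))
image = transforms.functional.resized_crop(image, i, j, h, w, size)
mask = transforms.functional.resized_crop(mask, i, j, h, w, size,
                                          interpolation=transforms.InterpolationMode.NEAREST)
  2. Horizontal flip (must flip both image and mask!)
if random.random() > 0.5:
    image = transforms.functional.hflip(image)
    mask = transforms.functional.hflip(mask)
  3. Color jitter (image only)
image = transforms.ColorJitter(brightness=0.2, contrast=0.2)(image)
  4. Random rotation (both image and mask; rotate() defaults to nearest-neighbor interpolation, which keeps mask values valid)
angle = random.uniform(-10, 10)
image = transforms.functional.rotate(image, angle)
mask = transforms.functional.rotate(mask, angle)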

CRITICAL: Always apply same geometric transformation to both image and mask!
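The key rule is to draw the random parameters once and apply them to both arrays. It can be shown without any library-specific API. The NumPy sketch below assumes an H×W×3 image array and an H×W mask of class indices; the function name is our own:

```python
import numpy as np


def joint_random_flip_crop(image, mask, crop_size, rng):
    """Apply the SAME random horizontal flip and crop to an image and its mask.

    image: [H, W, 3] array, mask: [H, W] array of class indices.
    Random parameters are drawn once and reused for both arrays.
    """
    if rng.random() < 0.5:                     # one coin flip shared by both
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    h, w = mask.shape
    top = rng.integers(0, h - crop_size + 1)   # one crop box shared by both
    left = rng.integers(0, w - crop_size + 1)
    image = image[top:top + crop_size, left:left + crop_size]
    mask = mask[top:top + crop_size, left:left + crop_size]
    return image, mask


rng = np.random.default_rng(0)
mask = rng.integers(0, 4, size=(32, 48))            # toy label map
image = np.stack([mask * 10, mask, mask], axis=-1)  # toy image tied to the mask
img_aug, mask_aug = joint_random_flip_crop(image, mask, crop_size=16, rng=rng)
```

Because flipping and cropping only move pixels, the mask keeps valid class indices. Any operation that resamples pixels (resizing, rotation) must use nearest-neighbor interpolation on the mask.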

Training Pipeline

for epoch in range(num_epochs):
    model.train()
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)

        # Compute loss
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    validate(model, val_loader)

Learning Rate Strategies

1. Warm-up then decay:

# Start small, increase, then decay
epochs: 0-5: linear increase
epochs: 5-50: cosine decay

2. Poly learning rate:

lr = base_lr × (1 - iter/max_iter)^power

3. Step decay:

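# Multiply the LR by gamma at each milestone epoch
milestones = [30, 60, 90]
gamma = 0.1
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)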
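The three schedules can be written as plain functions of the epoch or iteration. This is a pure-Python sketch with hypothetical function names. The defaults mirror the numbers above, and power=0.9 is a commonly used value for the poly schedule rather than a requirement:

```python
import math


def warmup_cosine_lr(epoch, base_lr, warmup_epochs=5, total_epochs=50, min_lr=0.0):
    """Linear warm-up for `warmup_epochs`, then cosine decay to `min_lr`."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


def poly_lr(iteration, base_lr, max_iter, power=0.9):
    """Poly schedule: base_lr * (1 - iter / max_iter) ** power."""
    return base_lr * (1 - iteration / max_iter) ** power


def step_lr(epoch, base_lr, milestones=(30, 60, 90), gamma=0.1):
    """Multiply by gamma once for every milestone already reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)
```

To drive a PyTorch optimizer with one of these, one option is torch.optim.lr_scheduler.LambdaLR. It multiplies the optimizer's initial learning rate by the returned factor. For example, LambdaLR(optimizer, lambda e: warmup_cosine_lr(e, base_lr=1.0)) applies the warm-up/cosine schedule when stepped once per epoch.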

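Batch Size Considerations

High-resolution inputs often force small batches (a few images per GPU). BatchNorm statistics become noisy at very small batch sizes. Common remedies are:
- smaller crops, so more images fit in a batch;
- synchronized BatchNorm across GPUs;
- GroupNorm instead of BatchNorm.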

Class Imbalance

Problem: Some classes appear much more than others (e.g., background vs rare object)

Solutions:

  1. Weighted loss: Weight rare classes higher
  2. Focal loss: Focus on hard examples
  3. Data sampling: Oversample rare classes

9. Loss Functions

Cross-Entropy Loss

Standard for semantic segmentation:

criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, targets)

# outputs: [batch, num_classes, H, W]
# targets: [batch, H, W] (class indices)

Formula:

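L_CE = -(1/N) Σ_i Σ_c y_{i,c} × log(p_{i,c})

Here i runs over the N pixels, y_{i,c} is 1 if pixel i belongs to class c (else 0), and p_{i,c} is the predicted softmax probability. nn.CrossEntropyLoss averages over pixels by default (reduction='mean').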
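The next sketch connects the formula to what nn.CrossEntropyLoss computes with its default pixel averaging. It uses NumPy on a single image, with the batch dimension dropped for simplicity, and the function name is our own. Because y is one-hot, the sum over classes just picks out the log-probability of each pixel's true class:

```python
import numpy as np


def pixel_cross_entropy(logits, targets):
    """Mean per-pixel cross-entropy.

    logits:  [num_classes, H, W] raw scores
    targets: [H, W] integer class indices
    """
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    # One-hot y selects log p of the true class at every pixel
    h_idx, w_idx = np.indices(targets.shape)
    picked = log_probs[targets, h_idx, w_idx]  # [H, W]
    return float(-picked.mean())
```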

Weighted Cross-Entropy

For class imbalance:

# Compute class weights (inverse frequency)
class_weights = torch.tensor([0.5, 2.0, 1.0, ...])
criterion = nn.CrossEntropyLoss(weight=class_weights)
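One common way to obtain such weights is inverse pixel frequency over the training labels. The NumPy sketch below assumes label maps are integer arrays of class indices; the function name is hypothetical. Other schemes, such as median-frequency balancing, also exist.

```python
import numpy as np


def inverse_frequency_weights(label_maps, num_classes):
    """Class weights proportional to 1 / pixel frequency, normalized to mean 1.

    label_maps: iterable of [H, W] integer arrays (class indices).
    Values >= num_classes (e.g., an ignore label) are not counted.
    Classes that never occur get weight 0 to avoid division by zero.
    """
    counts = np.zeros(num_classes, dtype=np.float64)
    for labels in label_maps:
        counts += np.bincount(labels.ravel(), minlength=num_classes)[:num_classes]
    freq = counts / counts.sum()
    weights = np.zeros(num_classes)
    present = freq > 0
    weights[present] = 1.0 / freq[present]
    return weights / weights[present].mean()
```

The result can be passed to the loss as torch.tensor(weights, dtype=torch.float32). Normalizing to mean 1 keeps the overall loss scale comparable to the unweighted version.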

Dice Loss

Popular in medical imaging:

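def dice_loss(pred, target):
    # pred: probabilities in [0, 1] (apply sigmoid/softmax to logits first)
    # target: binary mask with the same shape as pred
    smooth = 1.0  # avoids 0/0 when both masks are empty
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1).float()
    intersection = (pred_flat * target_flat).sum()

    dice = (2. * intersection + smooth) / (
        pred_flat.sum() + target_flat.sum() + smooth
    )

    return 1 - dice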

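Advantages: Dice measures region overlap directly. Because it is normalized by the sizes of the predicted and true regions, a small foreground object is not swamped by a large background. This makes it less sensitive to class imbalance than plain cross-entropy.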
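The dice_loss above handles a single binary mask and expects probabilities, not logits. For multi-class segmentation, a common variant does three things: applies softmax, one-hot encodes the targets, and averages the Dice score over classes. A NumPy sketch with a hypothetical function name:

```python
import numpy as np


def multiclass_soft_dice_loss(logits, targets, smooth=1.0):
    """1 - mean soft Dice over classes.

    logits:  [num_classes, H, W] raw scores (softmax applied here)
    targets: [H, W] integer class indices
    """
    num_classes = logits.shape[0]
    exp = np.exp(logits - logits.max(axis=0, keepdims=True))   # stable softmax
    probs = exp / exp.sum(axis=0, keepdims=True)               # [C, H, W]
    one_hot = np.eye(num_classes)[targets].transpose(2, 0, 1)  # [C, H, W]
    intersection = (probs * one_hot).sum(axis=(1, 2))          # per class
    denom = probs.sum(axis=(1, 2)) + one_hot.sum(axis=(1, 2))
    dice_per_class = (2 * intersection + smooth) / (denom + smooth)
    return float(1 - dice_per_class.mean())
```

The smooth term avoids 0/0 for classes that are missing from an image.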

Focal Loss

Focuses on hard examples:

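import torch
import torch.nn.functional as F

def focal_loss(pred, target, alpha=0.25, gamma=2.0):
    # pred: logits [batch, num_classes, H, W]; target: class indices [batch, H, W]
    ce_loss = F.cross_entropy(pred, target, reduction='none')  # per-pixel CE
    p_t = torch.exp(-ce_loss)  # predicted probability of the true class
    focal = alpha * (1 - p_t) ** gamma * ce_loss
    return focal.mean()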

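Why it works: The factor (1 - p_t)^γ is close to 0 for well-classified pixels (p_t near 1) and close to 1 for misclassified ones. Easy pixels (often the abundant background) therefore contribute little, and training focuses on hard pixels. With γ = 0 it reduces to α-scaled cross-entropy.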
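A two-pixel calculation shows the effect of the modulating factor. This pure-Python sketch uses the defaults alpha=0.25 and gamma=2.0 from focal_loss above. It compares one easy pixel (p_t = 0.9) with one hard pixel (p_t = 0.1); the helper names are our own.

```python
import math


def ce_term(p_t):
    """Cross-entropy for one pixel whose true-class probability is p_t."""
    return -math.log(p_t)


def focal_term(p_t, alpha=0.25, gamma=2.0):
    """Focal loss for the same pixel: alpha * (1 - p_t) ** gamma * CE."""
    return alpha * (1 - p_t) ** gamma * ce_term(p_t)


for p_t in (0.9, 0.1):  # an easy pixel and a hard pixel
    ratio = focal_term(p_t) / ce_term(p_t)
    print(f"p_t={p_t}: CE={ce_term(p_t):.4f}, focal={focal_term(p_t):.4f}, focal/CE={ratio:.4f}")
```

The factor (1 - p_t)^γ is 0.01 for the easy pixel and 0.81 for the hard one. Relative to cross-entropy, the easy pixel is therefore down-weighted 81 times more than the hard pixel.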

Combined Loss

Combining losses often works well in practice: cross-entropy for per-pixel accuracy plus Dice for region overlap.

total_loss = ce_loss + dice_loss
# or
total_loss = alpha * ce_loss + beta * dice_loss

10. Evaluation Metrics

Pixel Accuracy

Simplest metric:

accuracy = correct_pixels / total_pixels

Problem: Misleading under class imbalance. A model that predicts only the dominant class (e.g., background) can still score high.

Mean IoU (Intersection over Union)

Most common metric for segmentation:

IoU = |Prediction ∩ Ground Truth| / |Prediction ∪ Ground Truth|

Per class:

def compute_iou(pred, target, class_id):
    pred_mask = (pred == class_id)
    target_mask = (target == class_id)

    intersection = (pred_mask & target_mask).sum()
    union = (pred_mask | target_mask).sum()

    iou = intersection / (union + 1e-6)
    return iou

Mean IoU:

mean_iou = mean([iou_class_1, iou_class_2, ..., iou_class_n])
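Per-class IoU is often computed from a confusion matrix accumulated over the whole validation set. This gives dataset-level mIoU rather than an average of per-image scores. The NumPy sketch below assumes all label values lie in [0, num_classes), so any ignore label must be filtered out first; the function names are our own. The example also shows why pixel accuracy misleads under imbalance:

```python
import numpy as np


def confusion_matrix(pred, target, num_classes):
    """conf[t, p] = number of pixels whose ground truth is t and prediction is p."""
    idx = target.ravel() * num_classes + pred.ravel()
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def metrics_from_confusion(conf):
    """Pixel accuracy, per-class IoU, and mean IoU (NaN classes skipped)."""
    tp = np.diag(conf).astype(float)
    pixel_acc = tp.sum() / conf.sum()
    # |pred ∪ gt| = |pred| + |gt| - |pred ∩ gt|
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / union  # NaN for classes absent from both prediction and ground truth
    return pixel_acc, iou, float(np.nanmean(iou))


# Imbalanced toy example: 10% of pixels are class 1, the model always predicts 0
target = np.zeros((10, 10), dtype=int)
target[0] = 1
pred = np.zeros_like(target)

conf = confusion_matrix(pred, target, num_classes=2)  # accumulate across images with +=
pixel_acc, iou, miou = metrics_from_confusion(conf)
```

A model that always predicts background reaches 90% pixel accuracy here but only 0.45 mIoU, because the object class has IoU 0.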

Dice Coefficient

Alternative to IoU:

Dice = 2 × |Prediction ∩ Ground Truth| / (|Prediction| + |Ground Truth|)

Relationship to IoU:

Dice = 2 × IoU / (1 + IoU)
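The relationship follows from |Prediction| + |Ground Truth| = |union| + |intersection|. Dividing the numerator and denominator of Dice by |union| gives 2 × IoU / (1 + IoU). A quick NumPy check on random masks (for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
pred = rng.random((64, 64)) > 0.5  # random binary masks, for illustration only
gt = rng.random((64, 64)) > 0.4

inter = np.logical_and(pred, gt).sum()
union = np.logical_or(pred, gt).sum()
iou = inter / union
dice = 2 * inter / (pred.sum() + gt.sum())
```

For example, IoU = 0.5 corresponds to Dice ≈ 0.667. Dice is never lower than IoU, so scores reported in the two metrics are not directly comparable.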

For Instance Segmentation

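Average Precision (AP): Predicted instance masks are ranked by confidence. A prediction is a true positive if its mask IoU with a not-yet-matched ground-truth instance of the same class reaches a threshold. AP summarizes the resulting precision-recall curve.

COCO metrics:
- AP: averaged over IoU thresholds 0.50:0.05:0.95
- AP50 and AP75: single thresholds of 0.50 and 0.75
- AP_S / AP_M / AP_L: AP for small, medium, and large objects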
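The COCO IoU thresholds can be illustrated for a single prediction/ground-truth pair. This sketch only shows the threshold sweep; the names are illustrative. Computing AP itself also requires ranking all detections by confidence and integrating precision over recall, which is omitted here.

```python
import numpy as np


def mask_iou(a, b):
    """IoU between two binary instance masks."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union > 0 else 0.0


# COCO averages AP over 10 IoU thresholds: 0.50, 0.55, ..., 0.95
coco_thresholds = np.linspace(0.5, 0.95, 10)

gt = np.zeros((10, 10), dtype=bool)
gt[2:8, 2:8] = True    # 36-pixel ground-truth instance
pred = np.zeros_like(gt)
pred[2:8, 3:9] = True  # same size, shifted one column to the right

iou = mask_iou(pred, gt)                           # 30 / 42
matched_at = coco_thresholds[iou >= coco_thresholds]
```

The prediction shifted by one column reaches IoU ≈ 0.714. It counts as a true positive at the 0.50 to 0.70 thresholds and a false positive at 0.75 and above, so averaging over thresholds rewards precise masks.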


11. Data Preparation and Labels

Label Format for Semantic Segmentation

Image-like format:

Training image: RGB (H × W × 3)
Label image: Grayscale (H × W)

Each pixel value = class index
Example:
- 0: background
- 1: car
- 2: person
- 3: road

Creating Label Images

import numpy as np
from PIL import Image

# Create label image (same size as input; H and W chosen here for illustration)
H, W = 256, 256
label = np.zeros((H, W), dtype=np.uint8)

# Fill in labels (example: rectangular regions)
label[100:200, 100:200] = 1  # Class 1 (e.g., car)
label[50:150, 50:150] = 2    # Class 2 (e.g., person); overwrites class 1 where they overlap

# Save as image
Image.fromarray(label).save('label.png')

Dataset Structure

dataset/
├── images/
│   ├── train/
│   │   ├── img1.jpg
│   │   ├── img2.jpg
│   └── val/
│       ├── img3.jpg
└── labels/
    ├── train/
    │   ├── img1.png
    │   ├── img2.png
    └── val/
        ├── img3.png

PyTorch Dataset

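import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # transform(image, label) must apply the same geometric ops to both
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # Labels are PNGs even when images are JPEGs (img1.jpg -> img1.png)
        label_name = os.path.splitext(img_name)[0] + ".png"
        label_path = os.path.join(self.label_dir, label_name)

        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path)

        if self.transform:
            image, label = self.transform(image, label)

        # CrossEntropyLoss expects an [H, W] tensor of class indices (int64)
        if isinstance(label, Image.Image):
            label = torch.from_numpy(np.array(label, dtype=np.int64))

        return image, label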

Important Considerations

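  1. Label images must be single-channel (grayscale, or palette-mode PNG as in PASCAL VOC), with pixel values equal to class indices
  2. No compression artifacts: Save as PNG, not JPEG (lossy compression creates invalid in-between class values)
  3. Same dimensions: Label must match image size
  4. Class indices start from 0 and run to num_classes - 1
  5. Background typically class 0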

12. Downsampling and Upsampling

Why Downsample?

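Computational efficiency: Each 2× downsampling cuts the number of spatial positions by 4, so deeper layers can use more channels at an affordable cost.

Larger receptive field: After downsampling, a 3×3 convolution covers a larger region of the original image, so deeper layers see global context.

Downsampling discards fine spatial detail, but upsampling combined with skip connections recovers much of it.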
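The receptive-field effect can be computed with the standard recurrence for stacked layers. In this pure-Python sketch (layer configurations are illustrative), each layer adds (kernel - 1) × jump pixels, where jump is the product of all earlier strides:

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output position after stacked layers.

    layers: list of (kernel_size, stride). Standard recurrence:
        r_out = r_in + (kernel_size - 1) * jump_in,   jump_out = jump_in * stride
    where `jump` is the distance, in input pixels, between adjacent positions.
    """
    r, jump = 1, 1
    for kernel_size, stride in layers:
        r += (kernel_size - 1) * jump
        jump *= stride
    return r


conv3, pool2 = (3, 1), (2, 2)
no_pool = [conv3] * 8                                   # eight 3×3 convs, no downsampling
with_pool = [conv3, conv3, pool2] * 3 + [conv3, conv3]  # same convs, three 2×2 poolings
```

With the same eight 3×3 convolutions, inserting three 2×2 poolings raises the receptive field from 17 to 68 input pixels.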

Downsampling Methods

1. Max Pooling:

nn.MaxPool2d(kernel_size=2, stride=2)
# 32×32 → 16×16

2. Strided Convolution:

nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
# 32×32 → 16×16

Upsampling Methods

1. Transposed Convolution (see Section 4):

nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
# 16×16 → 32×32 (without output_padding=1 this gives 31×31)

2. Bilinear Interpolation:

nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

3. Bilinear + Conv (a common choice that avoids checkerboard artifacts while staying learnable):

nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
)

Downsampling-Upsampling Pipeline

Input: 256×256
    ↓ Conv + Pool
    128×128
    ↓ Conv + Pool
    64×64
    ↓ Conv + Pool
    32×32  ← Bottleneck (most compressed)
    ↓ Upsample + Conv
    64×64
    ↓ Upsample + Conv
    128×128
    ↓ Upsample + Conv
Output: 256×256

13. Skip Connections

What are Skip Connections?

Direct connections from encoder to decoder at same resolution:

Encoder              Decoder

32×32 ─────────────→ 32×32 (concat or add)
  ↓                     ↑
16×16 ─────────────→ 16×16
  ↓                     ↑
 8×8  ─────────────→  8×8
  ↓                     ↑
 4×4  ←─ bottleneck

Why Skip Connections?

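1. Preserve spatial details: High-resolution encoder features carry edges and boundaries that pooling discards.

2. Better gradient flow: Shorter paths from the loss to early layers make deep encoder-decoders easier to train.

3. Combine features: Fuse low-level "where" information from the encoder with high-level "what" information from the decoder.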

Types of Skip Connections

1. Concatenation (U-Net style):

# Encoder output: [batch, 64, 32, 32]
# Decoder output: [batch, 64, 32, 32]
# Concatenate along channel dimension
skip = torch.cat([encoder_out, decoder_out], dim=1)
# Result: [batch, 128, 32, 32]

2. Addition (ResNet style):

# Both must have the same shape (channels and spatial size)
skip = encoder_out + decoder_out
# Result: [batch, 64, 32, 32]

Implementation

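import torch
import torch.nn as nn

class UNetWithSkips(nn.Module):
    # Forward pass only: enc*, bottleneck, up*, dec*, and final are assumed to be
    # defined in __init__ (see the full UNet in Section 15). Here enc2..enc4 and
    # the bottleneck each halve the resolution, and each up* doubles it.
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)      # 256×256
        e2 = self.enc2(e1)     # 128×128
        e3 = self.enc3(e2)     # 64×64
        e4 = self.enc4(e3)     # 32×32

        # Bottleneck
        b = self.bottleneck(e4)  # 16×16

        # Decoder: upsample first so spatial sizes match, then concatenate the skip
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))    # 32×32
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))   # 64×64
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))   # 128×128
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))   # 256×256

        return self.final(d1)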

Skip Connections: With vs Without

Aspect | Without Skips | With Skips
Boundary quality | Blurry | Sharp
Training speed | Slower | Faster
Gradient flow | Difficult | Easy
Fine details | Lost | Preserved

Rule of thumb: Always use skip connections in encoder-decoder segmentation networks!


14. State-of-the-Art Models

DeepLab v3+

Key innovations:

  1. Atrous (dilated) convolution: Enlarges receptive field without pooling
  2. Atrous Spatial Pyramid Pooling (ASPP): Multi-scale context
  3. Encoder-decoder with skip connections

Performance: State-of-the-art on PASCAL VOC and Cityscapes at the time of publication (2018)
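Atrous convolution inserts (rate - 1) gaps between kernel taps. A k-tap kernel therefore spans k + (k - 1)(rate - 1) inputs with no extra weights. A 1D NumPy sketch follows; the helper names are our own, and real layers are 2D, e.g. nn.Conv2d with dilation=rate.

```python
import numpy as np


def effective_kernel_size(k, dilation):
    """Span of a k-tap kernel with (dilation - 1) gaps between taps."""
    return k + (k - 1) * (dilation - 1)


def dilated_conv1d(x, w, dilation):
    """'Valid' 1D cross-correlation with a dilated kernel (stride 1)."""
    k = len(w)
    span = effective_kernel_size(k, dilation)
    out = np.empty(len(x) - span + 1)
    for i in range(len(out)):
        taps = x[i:i + span:dilation]  # the same k weights, applied `dilation` apart
        out[i] = np.dot(taps, w)
    return out


signal = np.arange(10.0)
edge = dilated_conv1d(signal, np.array([1.0, 0.0, -1.0]), dilation=2)  # x[i] - x[i + 4]
```

A 3×3 kernel at rate 6 spans 13×13 pixels while keeping 9 weights. ASPP runs several such convolutions with different rates in parallel to gather multi-scale context.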

PSPNet (Pyramid Scene Parsing Network)

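Key innovation: A pyramid pooling module. It average-pools the final feature map into grids of several sizes (1×1, 2×2, 3×3, 6×6 bins), upsamples each result, and concatenates them with the original features to add global scene context.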
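The pyramid pooling idea can be sketched in NumPy for a single feature map. This is a simplification that assumes H and W are divisible by every bin size. PSPNet itself also applies a 1×1 convolution to each pooled branch and uses bilinear rather than nearest-neighbor upsampling.

```python
import numpy as np


def avg_pool_to_bins(feat, bins):
    """Average-pool a [C, H, W] map into a [C, bins, bins] grid.

    Assumes H and W are divisible by `bins` (adaptive pooling handles the general case).
    """
    c, h, w = feat.shape
    return feat.reshape(c, bins, h // bins, bins, w // bins).mean(axis=(2, 4))


def pyramid_pooling(feat, bin_sizes=(1, 2, 3, 6)):
    """Concatenate the input with upsampled pooled context at several scales."""
    _, h, w = feat.shape
    branches = [feat]
    for b in bin_sizes:
        pooled = avg_pool_to_bins(feat, b)  # [C, b, b]
        # Nearest-neighbor upsampling back to H×W by repeating each cell
        up = pooled.repeat(h // b, axis=1).repeat(w // b, axis=2)
        branches.append(up)
    return np.concatenate(branches, axis=0)  # [C * (1 + len(bin_sizes)), H, W]


rng = np.random.default_rng(0)
feat = rng.standard_normal((4, 6, 6))
out = pyramid_pooling(feat)  # shape (20, 6, 6)
```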

HRNet (High-Resolution Network)

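Key innovation: Keeps a high-resolution representation throughout the network. It runs parallel branches at several resolutions and repeatedly exchanges information between them, instead of recovering resolution from a low-resolution bottleneck.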

Advantage: Better for tasks requiring fine details

Swin-UNETR

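Recent advancement: A U-shaped network with a Swin Transformer encoder and a CNN decoder linked by skip connections, designed for 3D medical image segmentation (e.g., brain tumor segmentation).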

SAM (Segment Anything Model)

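Foundation model for segmentation: SAM produces class-agnostic masks from prompts (points, boxes, or rough masks). It was trained on the SA-1B dataset of over 1 billion masks, so it can segment objects it was never explicitly trained on.

Use cases:
- interactive annotation;
- generating candidate masks when labeling new datasets;
- zero-shot segmentation of new object types.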


15. Implementation Guide

Complete U-Net Implementation

import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=21):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self.conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # Final layer
        self.final = nn.Conv2d(64, num_classes, 1)

    def conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # H and W must be divisible by 16 (four 2× poolings) so skip shapes match
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder
        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.final(d1)

Training Loop

import numpy as np
import torch

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0

    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

def evaluate(model, dataloader, device, num_classes):
    model.eval()
    total_iou = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            preds = outputs.argmax(dim=1)

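            # Per-batch mean IoU; averaging these over batches only approximates
            # dataset-level mIoU (which accumulates intersections and unions)
            iou = compute_mean_iou(preds, masks, num_classes)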
            total_iou += iou

    return total_iou / len(dataloader)

def compute_mean_iou(pred, target, num_classes):
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred == cls)
        target_cls = (target == cls)

        intersection = (pred_cls & target_cls).sum().float()
        union = (pred_cls | target_cls).sum().float()

        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((intersection / union).item())

    # Mean IoU (skipping classes absent from both prediction and ground truth)
    ious = [iou for iou in ious if not np.isnan(iou)]
    return np.mean(ious) if ious else 0.0

Complete Training Script

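import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_loader / val_loader are assumed to be DataLoaders over SegmentationDataset
# (Section 11), e.g. DataLoader(train_dataset, batch_size=8, shuffle=True)

# Model
model = UNet(in_channels=3, num_classes=21).to(device)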

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=5, factor=0.5
)

# Training
num_epochs = 50
best_iou = 0

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_iou = evaluate(model, val_loader, device, num_classes=21)

    # Update learning rate
    scheduler.step(val_iou)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Val IoU: {val_iou:.4f}")

    # Save best model
    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(model.state_dict(), 'best_unet.pt')
        print(f"Saved best model with IoU: {best_iou:.4f}")

16. Common Pitfalls and Solutions

Pitfall 1: Wrong Label Format

Wrong:

# Labels as one-hot encoded (H, W, num_classes)
labels = torch.zeros(H, W, num_classes)

Correct:

# Labels as class indices (H, W)
labels = torch.zeros(H, W, dtype=torch.long)
labels[...] = class_id  # 0 to num_classes-1

Pitfall 2: Augmentation Mismatch

Wrong:

# Augment image and mask separately
image = transform(image)  # Random crop/flip
mask = transform(mask)    # Different random crop/flip!

Correct:

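# Apply SAME transformation to both by replaying the same random seed
# (works for torchvision transforms that draw from Python's `random` and torch's RNG)
import random
import torch

seed = random.randint(0, 2**31 - 1)
random.seed(seed)
torch.manual_seed(seed)
image = transform(image)

random.seed(seed)
torch.manual_seed(seed)
mask = transform(mask)  # `transform` must not contain image-only steps or bilinear resizing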

Pitfall 3: Size Mismatch

Problem: Output size doesn't match input size

Debug:

print(f"Input: {x.shape}")
x = model(x)
print(f"Output: {x.shape}")

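Solution: Adjust padding, stride, and output_padding in the upsampling layers. Also make sure the input height and width are divisible by 2^N for a network with N downsampling steps, and pad the input if needed.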
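A frequent cause of mismatch is an input size that is not divisible by 2^N. The pure-Python sketch below, with hypothetical helper names, traces the sizes through four poolings. It also computes how much bottom/right padding fixes the problem.

```python
def sizes_through_pools(n, num_pools=4):
    """Spatial size after each 2×2 max-pool (floor division, as in MaxPool2d)."""
    sizes = [n]
    for _ in range(num_pools):
        sizes.append(sizes[-1] // 2)
    return sizes


def pad_to_multiple(h, w, multiple=16):
    """Padding (bottom, right) needed so H and W are divisible by `multiple`.

    A network with N stride-2 poolings needs multiple = 2 ** N so every
    decoder feature map matches its encoder skip connection exactly.
    """
    return (-h) % multiple, (-w) % multiple
```

A 100×100 input shrinks to 6×6 (100 → 50 → 25 → 12 → 6). Upsampling then gives 12 → 24, which no longer matches the 25×25 skip connection. Padding to 112×112 keeps every level even. In PyTorch the padding can be applied with torch.nn.functional.pad(x, (0, pad_w, 0, pad_h)), and the prediction cropped back to H×W afterwards.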

Pitfall 4: Class Imbalance Ignored

Wrong:

# Treat all classes equally
criterion = nn.CrossEntropyLoss()

Correct:

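# Weight rare classes higher; compute_class_weights is a helper you write yourself
# (e.g., inverse class frequency over the training labels, as a float tensor)
class_weights = compute_class_weights(train_dataset)
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))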

Pitfall 5: Not Using Skip Connections

Wrong:

# Encoder-decoder without skips
x = encoder(x)
x = decoder(x)  # Blurry boundaries!

Correct:

# With skip connections
e1, e2, e3, e4 = encoder(x)
x = decoder(e1, e2, e3, e4)  # Sharp boundaries!

Pitfall 6: Wrong Evaluation Mode

Wrong:

# Not setting eval mode
model.train()  # or not calling eval()
with torch.no_grad():
    outputs = model(images)

Correct:

model.eval()  # CRITICAL!
with torch.no_grad():
    outputs = model(images)

Pitfall 7: Memory Issues

Problem: Out of memory with high-resolution images

Solutions:

  1. Reduce batch size
  2. Use gradient checkpointing
  3. Crop images into patches
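  4. Use mixed precision training
# Newer PyTorch releases also provide torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda")
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

optimizer.zero_grad()
with autocast():
    outputs = model(images)
    loss = criterion(outputs, masks)

scaler.scale(loss).backward()  # scale the loss so float16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
scaler.update()                # adjusts the scale factor for the next iteration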

Pitfall 8: Ignoring Boundary Pixels

Problem: Predictions at image boundaries are poor

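Solution: Zero padding gives border pixels less real context. Two options help:
- reflection padding (padding_mode='reflect' in nn.Conv2d);
- the original U-Net's overlap-tile strategy: predict on a larger, mirror-padded input and crop the output back to the original region.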

Pitfall 9: Oversmoothing

Problem: Predictions are blurry

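Causes:
- upsampling from a very low-resolution bottleneck;
- decoders without skip connections;
- relying only on fixed bilinear upsampling of coarse predictions.

Solutions:
- add skip connections;
- use fewer downsampling steps (or dilated convolutions, as in DeepLab);
- consider adding Dice loss to reward accurate region overlap.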

Pitfall 10: Not Visualizing Predictions

Always visualize during training:

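import matplotlib.pyplot as plt
import torch

def visualize_predictions(model, image, mask):
    # image: [3, H, W] tensor, mask: [H, W] tensor of class indices
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        pred = model(image.unsqueeze(0).to(device))
        pred = pred.argmax(dim=1).squeeze(0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # If the image was normalized, un-normalize it first for correct colors
    axes[0].imshow(image.permute(1, 2, 0).cpu())
    axes[0].set_title('Input')

    axes[1].imshow(mask.cpu())
    axes[1].set_title('Ground Truth')

    axes[2].imshow(pred.cpu())
    axes[2].set_title('Prediction')

    plt.show()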

Summary

Key Takeaways

  1. Semantic segmentation classifies each pixel into a class
  2. Instance segmentation distinguishes individual object instances
  3. Fully convolutional networks preserve spatial information
  4. U-Net architecture with skip connections is highly effective
  5. Transposed convolution or bilinear interpolation for upsampling
  6. Skip connections preserve fine-grained details
  7. Mask R-CNN extends Faster R-CNN for instance segmentation
  8. Data augmentation must be applied to both image and mask
  9. Mean IoU is the standard evaluation metric for semantic segmentation (mask AP for instance segmentation)
  10. Class imbalance requires weighted loss or focal loss

Typical Results

Task | Dataset | Metric | Good Result
Semantic Seg | PASCAL VOC | mIoU | 75-85%
Semantic Seg | Cityscapes | mIoU | 75-80%
Instance Seg | COCO | AP | 35-40%

Next Steps

