RT-DETR: A Faster Alternative to YOLO for Real-Time Object Detection (with Code)
Nov. 1, 2024, 5:26 p.m.
Object detection has always faced a major challenge: balancing speed and accuracy. Traditional models like YOLO are fast, but they rely on a post-processing step called Non-Maximum Suppression (NMS) to filter overlapping bounding boxes, and this extra computation slows down the end-to-end detection pipeline. This is where DETR (https://arxiv.org/pdf/2005.12872) (DEtection TRansformer) comes in.
RT-DETR is an end-to-end object detector based on the DETR architecture, which completely eliminates the need for NMS. By doing so, RT-DETR significantly reduces the latency seen in previous object detectors based on convolutional neural networks (CNNs), like the YOLO series. It combines a powerful backbone, a hybrid encoder, and a unique query selector to process features quickly and accurately.
Overview of the RT-DETR architecture, source: https://arxiv.org/pdf/2304.08069
Key Components of the RT-DETR Architecture
1. Backbone:
The backbone extracts features from the input image, with the most common configurations using ResNet-50 or ResNet-101. From the backbone, RT-DETR extracts features at three levels — S3, S4, and S5. These multi-scale features help the model understand both high-level and fine-grained details of the image.
Residual building block, source: https://arxiv.org/pdf/1512.03385
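To make this concrete, here is a minimal sketch (not taken from the RT-DETR repository) of extracting the S3, S4, and S5 feature maps from a torchvision ResNet-50; the layer names and shapes below assume a 640x640 input:
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

# Illustrative sketch: layer2/layer3/layer4 of ResNet-50 correspond to
# strides 8/16/32, i.e. the S3/S4/S5 feature maps used by RT-DETR.
backbone = resnet50(weights="IMAGENET1K_V2")
extractor = create_feature_extractor(
    backbone, return_nodes={"layer2": "s3", "layer3": "s4", "layer4": "s5"}
)

image = torch.randn(1, 3, 640, 640)
features = extractor(image)
print({k: v.shape for k, v in features.items()})
# s3: [1, 512, 80, 80], s4: [1, 1024, 40, 40], s5: [1, 2048, 20, 20]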
2. Hybrid Encoder:
The hybrid encoder has two main parts: the Attention-based Intra-scale Feature Interaction (AIFI) and the Cross-scale Feature Fusion (CCFF). Here’s how each part works:
- AIFI: This layer processes only the S5 feature map from the backbone. Since S5 represents the deepest layer, it contains the richest semantic information about complete objects and their context in the image — making it ideal for transformer-based attention to capture meaningful relationships between objects. While S3 and S4 contain useful low-level features like edges and object parts, applying attention to these layers would be computationally expensive and less effective since they don’t yet represent complete objects that need to be related to each other.
The following is the PyTorch implementation of AIFI. The code is structured into three main components:
- A feed-forward layer (FFLayer) that processes features after attention
- A transformer encoder layer that implements multi-head self-attention
- The main AIFI module that ties everything together with proper feature projection and positional encoding
import torch
import torch.nn as nn
class FFLayer(nn.Module):
'''
Feed-forward network: two linear layers with a non-linear activation
in between
'''
def __init__(self,
embedd_dim:int,
ffn_dim:int,
dropout:float,
activation:str) -> None:
super().__init__()
self.feed_forward = nn.Sequential(
nn.Linear(embedd_dim, ffn_dim),
getattr(nn, activation)(),
nn.Dropout(dropout),
nn.Linear(ffn_dim, embedd_dim)
)
self.norm = nn.LayerNorm(embedd_dim)
def forward(self,x:torch.Tensor)->torch.Tensor:
residual = x
x = self.feed_forward(x)
out = self.norm(residual + x)
return out
class TransformerEncoderLayer(nn.Module):
def __init__(self,
hidden_dim:int,
n_head:int,
ffn_dim:int,
dropout:float,
activation:str = "ReLU"):
super().__init__()
self.mh_self_attn = nn.MultiheadAttention(hidden_dim, n_head, dropout, batch_first=True)
self.feed_foward_nn = FFLayer(hidden_dim,ffn_dim,dropout,activation)
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(hidden_dim)
def forward(self,x:torch.Tensor,mask:torch.Tensor=None,pos_emb:torch.Tensor=None) -> torch.Tensor:
# Save the input as the residual for the skip connection
residual = x
# Add positional embeddings (if provided) to the input for spatial encoding
q = k = x + pos_emb if pos_emb is not None else x
# Apply multi-head self-attention
x, attn_score = self.mh_self_attn(q, k, value=x, attn_mask=mask)
# Add residual connection and apply dropout
x= residual + self.dropout(x)
# Normalize the result to avoid distribution shift during training
x = self.norm(x)
x = self.feed_foward_nn(x)
return x
class AIFI(nn.Module):
def __init__(self, input_dim:int = 2048): # input_dim: channels of the S5 feature map (2048 for ResNet-50)
super().__init__()
# Configuration
self.hidden_dim = 256
self.num_layers = 1
self.num_heads = 8
self.dropout = 0.0
self.eval_spatial_size = [640,640]
self.pe_temperature = 10000
# Projection layer
self.projection = nn.Linear(input_dim, self.hidden_dim)
#position embedding
pos_embed = self.build_2d_sincos_position_embedding(
self.eval_spatial_size[1] // 32, self.eval_spatial_size[0] // 32,
self.hidden_dim, self.pe_temperature)
setattr(self, 'pos_embed', pos_embed)
# Transformer encoder layer
self.transformer_encoder= TransformerEncoderLayer(
hidden_dim = self.hidden_dim,
n_head = self.num_heads,
ffn_dim = 1024,
dropout = self.dropout,
activation = 'GELU'
)
@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
'''
source code from: https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py
Build 2D sinusoidal positional embeddings.
Args:
w, h: spatial dimensions of the position embedding
embed_dim: dimension of the position embeddings
temperature: scaling factor for the position encoding
Returns:
pos_embed: [1, H*W, embed_dim] position embedding
'''
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
assert embed_dim % 4 == 0, \
'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1. / (temperature ** omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
def forward(self, s5:torch.Tensor):
# s5: Input tensor of shape [batch_size, C, H, W] from S5 feature map
B,_,H,W = s5.size()
# Flatten the feature map to [batch_size, H*W, C] for processing by the transformer
s5 = s5.flatten(2).permute(0, 2, 1)
# Project input to hidden dimension
x = self.projection(s5)
# Add positional embeddings to the input features
if self.training:
pos_embed = self.build_2d_sincos_position_embedding(
w=W,
h=H,
embed_dim=self.hidden_dim
).to(x.device)
else:
pos_embed = getattr(self, 'pos_embed', None).to(x.device)
# Apply transformer encoder
x = self.transformer_encoder(x, pos_emb = pos_embed)
return x
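As a quick sanity check, here is a hedged usage sketch of the AIFI module above, assuming ResNet-50's S5 map (2048 channels, 20x20 for a 640x640 input):
# Hedged usage sketch for the AIFI module defined above.
aifi = AIFI(input_dim=2048).eval()

s5 = torch.randn(1, 2048, 20, 20)            # [B, C, H, W] at stride 32 for a 640x640 image
with torch.no_grad():
    out = aifi(s5)                           # [B, H*W, hidden_dim] = [1, 400, 256]
print(out.shape)

# reshape back to a spatial map before passing it to CCFF
aifi_s5 = out.permute(0, 2, 1).reshape(1, 256, 20, 20)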
- CCFF: This part is based on CNN layers and combines the S3, S4, and AIFI(S5) feature maps. Its job is to fuse features from different scales, ensuring that the final feature maps carry information at multiple resolutions. The core of CCFF is a special block called RepBlock, which uses a structure known as RepConv.
RepConv (https://arxiv.org/pdf/2101.03697) uses a multi-branch convolutional structure during training that can be re-parameterized into a single convolution for inference, making the network more efficient at inference time without sacrificing accuracy.
CCFF on the left and Fusion Block in CCFF on the right (source: https://arxiv.org/pdf/2304.08069)
The following is the PyTorch implementation of CCFF. The code is structured into four main components:
- A RepVGG block based on the original implementation
- A FusionBlock responsible for merging features across different scales
- An FPNBlock and a PANBlock for top-down and bottom-up multi-scale feature processing, respectively
- The main CCFF (Cross-scale Feature Fusion) module, which combines the top-down FPN and bottom-up PAN paths to aggregate features
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self,
in_channels:int,
out_channels:int,
kernel_size:int,
stride:int):
super().__init__()
self.convblock = nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding=(kernel_size-1)//2,bias=False),
nn.BatchNorm2d(out_channels),
nn.SiLU()
)
def forward(self,x):
return self.convblock(x)
class RepVggBlock(nn.Module):
def __init__(self, ch_in, ch_out, activation='ReLU'):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = nn.Sequential(
nn.Conv2d(ch_in, ch_out, 3, 1, padding=1, bias=False),nn.BatchNorm2d(ch_out)
)
self.conv2 = nn.Sequential(
nn.Conv2d(ch_in, ch_out, 1, 1, padding=0, bias=False),nn.BatchNorm2d(ch_out)
)
self.act = nn.Identity() if activation is None else getattr(nn, activation)()
def forward(self, x):
if hasattr(self, 'conv'):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
return self.act(y)
def convert_to_deploy(self):
if not hasattr(self, 'conv'):
self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.data = kernel
self.conv.bias.data = bias
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return F.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch: nn.Sequential):
# branch is nn.Sequential(Conv2d, BatchNorm2d)
if branch is None:
return 0, 0
kernel = branch[0].weight
running_mean = branch[1].running_mean
running_var = branch[1].running_var
gamma = branch[1].weight
beta = branch[1].bias
eps = branch[1].eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std
class FusionBlock(nn.Module):
def __init__(self,
in_channels:int,
out_channels:int,
num_blocks:int,
expansion:float):
super().__init__()
hidden_channels = int(out_channels*expansion)
self.conv1 = ConvBlock(in_channels, hidden_channels, 1, 1)
self.conv2 = ConvBlock(in_channels, hidden_channels, 1, 1)
self.rep_blocks = nn.Sequential(*[
RepVggBlock(hidden_channels, hidden_channels, activation="SiLU") for _ in range(num_blocks)
])
if hidden_channels != out_channels:
self.conv3 = ConvBlock(hidden_channels, out_channels, 1, 1)
else:
self.conv3 = nn.Identity()
def forward(self,x):
out = self.conv1(x)
out = self.rep_blocks(out)
x = self.conv2(x)
return self.conv3(out + x)
class FPNBlock(nn.Module):
'''
FPNBlock performs top-down processing of features with upsampling and fusion
'''
def __init__(self,
hidden_dim:int,
in_channels:list[int],
depth_mult:float,
expansion:float):
super().__init__()
self.hidden_dim = hidden_dim
self.in_channels = in_channels
self.top_down_blocks = nn.ModuleList()
for _ in range(len(self.in_channels)-1,0,-1):
self.top_down_blocks.append(
ConvBlock(self.hidden_dim,self.hidden_dim,1,1)
)
self.top_down_blocks.append(
FusionBlock(hidden_dim * 2, hidden_dim, num_blocks=round(3 * depth_mult), expansion=expansion)
)
def forward(self,features:list[torch.Tensor]) -> list[torch.Tensor]:
# features = [aifi_s5,s4,s3]
outputs = [features[0]]
for n,m in enumerate(self.top_down_blocks):
if n % 2 != 0: # FusionBlock
# Upsample previous feature map and concatenate with the next layer's features
fusion_out = m(torch.concat([F.interpolate(out, scale_factor=2., mode='nearest'),features[n // 2 + 1]],dim=1))
outputs.insert(0,fusion_out)
else: # ConvBlock
out = m(outputs[0])
outputs[0] = out
'''
outputs = [
FusionBlock( concat[ Upsample(ConvBlock(FusionBlock( concat[ Upsample(ConvBlock(aifi_s5)), S4] ))), S3] ),
ConvBlock(FusionBlock( concat[ Upsample(ConvBlock(aifi_s5)), S4] )),
ConvBlock(aifi_s5)
]
'''
return outputs
class PANBlock(nn.Module):
'''
PANBlock processes features in a bottom-up manner to refine them
'''
def __init__(self,
hidden_dim:int,
in_channels:list[int],
depth_mult:float,
expansion:float):
super().__init__()
self.hidden_dim = hidden_dim
self.in_channels = in_channels
self.bottom_up_blocks = nn.ModuleList()
for _ in range(len(self.in_channels)-1,0,-1):
self.bottom_up_blocks.append(
ConvBlock(self.hidden_dim,self.hidden_dim,3,2)
)
self.bottom_up_blocks.append(
FusionBlock(hidden_dim * 2, hidden_dim, num_blocks=round(3 * depth_mult), expansion=expansion)
)
def forward(self,fpn_features:list[torch.Tensor]) -> list[torch.Tensor]:
outputs = [fpn_features[0]]
for n,m in enumerate(self.bottom_up_blocks):
if n % 2 != 0: # FusionBlock
fusion_out = m(torch.concat([out,fpn_features[n // 2 + 1]],dim=1))
outputs.append(fusion_out)
else: # ConvBlock
out = m(outputs[-1])
'''
outputs = [
fpn_feature[0],
FusionBlock( concat[ ConvBlock(fpn_feature[0]), fpn_feature[1] ] ),
FusionBlock( concat[ ConvBlock(FusionBlock( concat[ ConvBlock(fpn_feature[0]), fpn_feature[1] ] )), fpn_feature[2] ] ),
]
'''
return outputs
class CCFF(nn.Module):
def __init__(self,
hidden_dim:int,
in_channels:list[int],
depth_mult:float,
expansion:float):
super().__init__()
self.hidden_dim = hidden_dim
self.in_channels = in_channels #channels of AIFI(S5),S4,S3
self.depth_mult = depth_mult
self.expansion = expansion
# top-down fpn
self.top_down_fpn = FPNBlock(self.hidden_dim,self.in_channels,self.depth_mult,self.expansion)
# bottom-up pan
self.bottom_up_pan = PANBlock(self.hidden_dim,self.in_channels,self.depth_mult,self.expansion)
def forward(self,aifi_s5,s4_proj,s3_proj):
top_down_out = self.top_down_fpn([aifi_s5,s4_proj,s3_proj])
bottom_up_out = self.bottom_up_pan(top_down_out)
# return the fused multi-scale feature maps (S3, S4 and S5 resolutions);
# they are flattened and concatenated before being passed to the decoder
return bottom_up_out
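For reference, here is a hedged usage sketch of the classes above. Shapes assume a 640x640 input with ResNet-50-style strides, and the random s4_proj/s3_proj tensors stand in for the 1x1 projections of S4/S3 to the common hidden dimension:
# Hedged usage sketch for CCFF (illustrative shapes, not the repository's test code).
hidden_dim = 256
ccff = CCFF(hidden_dim=hidden_dim, in_channels=[hidden_dim] * 3, depth_mult=1.0, expansion=1.0)

aifi_s5 = torch.randn(1, hidden_dim, 20, 20)   # AIFI output reshaped back to [B, C, H, W]
s4_proj = torch.randn(1, hidden_dim, 40, 40)   # S4 projected to hidden_dim
s3_proj = torch.randn(1, hidden_dim, 80, 80)   # S3 projected to hidden_dim

outs = ccff(aifi_s5, s4_proj, s3_proj)
print([o.shape for o in outs])                 # fused maps at S3, S4 and S5 resolutions

# RepVggBlock can be re-parameterized into a single 3x3 convolution for inference:
block = RepVggBlock(hidden_dim, hidden_dim, activation='SiLU').eval()
x = torch.randn(1, hidden_dim, 40, 40)
with torch.no_grad():
    y_train = block(x)                         # two-branch (3x3 + 1x1) form
    block.convert_to_deploy()
    y_deploy = block(x)                        # fused single-conv form
print(torch.allclose(y_train, y_deploy, atol=1e-5))   # expected: True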
3. Uncertainty-Minimal Query Selector:
In traditional DETR models, object queries are selected based on confidence scores to determine potential foreground objects.
The RT-DETR paper proposes an Uncertainty-Minimal Query Selector that explicitly models and minimizes the epistemic uncertainty arising from the discrepancy between the classification and localization predictions.
The paper defines this uncertainty as the discrepancy between the predicted distributions of localization P and classification C of an encoder feature X̂, i.e. U(X̂) = ||P(X̂) − C(X̂)||, and integrates it into the classification term of the training loss.
However, the actual implementation takes a more practical approach:
Query Selection:
- During both training and inference, queries are initially selected based on classification confidence scores.
- The selection uses a top-k operation over the per-query classification scores to choose the most promising queries:
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)
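For illustration, here is a hedged sketch of how the selected indices from the line above could be used to gather the corresponding encoder features and box predictions for the decoder; the tensor names encoder_memory and enc_outputs_coord are illustrative, not necessarily those used in the repository:
# Hedged sketch (illustrative names):
# encoder_memory:    [B, N, C] flattened encoder features
# enc_outputs_coord: [B, N, 4] box predictions from the encoder head
topk_features = encoder_memory.gather(
    1, topk_ind.unsqueeze(-1).expand(-1, -1, encoder_memory.shape[-1]))
topk_boxes = enc_outputs_coord.gather(
    1, topk_ind.unsqueeze(-1).expand(-1, -1, 4))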
Uncertainty-Aware Training:
The uncertainty minimization is achieved through the Varifocal Loss during training, which:
- Calculates IoU between predicted and ground truth boxes to measure localization accuracy.
- Weights classification targets by their IoU scores
- Applies dynamic weighting that considers both prediction confidence and localization quality.
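A rough, simplified sketch of this IoU-weighted classification objective, in the spirit of Varifocal Loss (illustrative only, not the exact loss used in the RT-DETR repository):
import torch
import torch.nn.functional as F

def varifocal_loss(pred_logits: torch.Tensor,
                   iou_targets: torch.Tensor,
                   fg_mask: torch.Tensor,
                   alpha: float = 0.75,
                   gamma: float = 2.0) -> torch.Tensor:
    '''
    Illustrative sketch, not the repository's implementation.
    pred_logits: [N, num_classes] raw class logits
    iou_targets: [N, num_classes] IoU with the matched ground-truth box,
                 placed at the target class (0 elsewhere)
    fg_mask:     [N, num_classes] 1 at the matched (foreground) class, else 0
    '''
    pred_score = pred_logits.sigmoid()
    # negatives are down-weighted by their own confidence (focal term),
    # positives are weighted by their localization quality (IoU)
    weight = alpha * pred_score.pow(gamma) * (1 - fg_mask) + iou_targets * fg_mask
    loss = F.binary_cross_entropy_with_logits(
        pred_logits, iou_targets, weight=weight, reduction='none')
    return loss.sum()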
This design creates an implicit uncertainty-minimal selection where:
During Training:
The model learns to assign high confidence scores only when:
- The classification is accurate
- The localization IoU is high
- The overall prediction uncertainty is low
During Inference:
The confidence-based selection becomes inherently uncertainty-aware because:
- The model has learned to incorporate localization quality into its confidence predictions
- High confidence scores naturally indicate both good classification and accurate localization
- No explicit uncertainty calculation is needed, making inference more efficient
This implementation differs from both traditional DETR and the theoretical approach in the RT-DETR paper by creating an end-to-end learned uncertainty awareness through training, rather than using explicit uncertainty calculations during query selection.
Classification and IoU scores of the selected encoder features. Purple and Green dots represent the selected features from model trained with uncertainty-minimal query selection and vanilla query selection, respectively. (source: https://arxiv.org/pdf/2304.08069)
The scatter plot shows how uncertainty-minimal query selection (purple) picks higher quality features compared to vanilla selection (green). Purple dots cluster towards the top-right, indicating features with both accurate classification and localization, while green dots spread more towards the bottom-right. This visual evidence demonstrates why uncertainty-minimal selection leads to better detection performance.
4. Decoder:
After selecting the most confident queries, RT-DETR uses a transformer decoder with Multi-Scale Deformable Attention to predict object locations and classes. The decoder architecture is specifically designed for efficient processing of multi-scale feature maps while maintaining high accuracy.
By default, the decoder is composed of six identical layers. Each layer is carefully structured to process and refine object detections through three main operations:
- Regular self-attention between queries
- Multi-scale deformable cross-attention
- Feed-forward network (FFN)
To ensure stable training and effective feature transformation, each of these operations is complemented with dropout and layer normalization.
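To make this structure concrete, here is an illustrative skeleton of one decoder layer (a sketch, not the reference implementation); the deformable cross-attention module is passed in as an argument and is sketched in the next subsection:
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    '''
    Illustrative skeleton of one RT-DETR decoder layer:
    self-attention -> deformable cross-attention -> FFN,
    each followed by dropout and layer normalization.
    '''
    def __init__(self, cross_attn: nn.Module, hidden_dim: int = 256,
                 n_heads: int = 8, ffn_dim: int = 1024, dropout: float = 0.0):
        super().__init__()
        # 1) self-attention between the object queries
        self.self_attn = nn.MultiheadAttention(hidden_dim, n_heads, dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        # 2) multi-scale deformable cross-attention onto the encoder features
        self.cross_attn = cross_attn
        self.norm2 = nn.LayerNorm(hidden_dim)
        # 3) feed-forward network
        self.ffn = nn.Sequential(nn.Linear(hidden_dim, ffn_dim), nn.ReLU(),
                                 nn.Linear(ffn_dim, hidden_dim))
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, ref_points, memory, spatial_shape):
        # queries: [B, Q, C], memory: flattened encoder features [B, H*W, C]
        attn_out, _ = self.self_attn(queries, queries, queries)
        queries = self.norm1(queries + self.dropout(attn_out))
        cross_out = self.cross_attn(queries, ref_points, memory, spatial_shape)
        queries = self.norm2(queries + self.dropout(cross_out))
        queries = self.norm3(queries + self.dropout(self.ffn(queries)))
        return queries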
Multi-Scale Deformable Attention
Vision Transformer with Deformable Attention (https://arxiv.org/pdf/2201.00520)
RT-DETR adopts Multi-Scale Deformable Attention in its decoder, which significantly improves upon the dense attention mechanism found in standard transformer architectures. Dense cross-attention computes relationships between every object query and every spatial location of the feature map. While this approach can capture global dependencies, it comes with a high computational cost, especially when processing high-resolution images with large feature maps. This leads to inefficiencies, as the model must attend to every location, even areas irrelevant to the task at hand.
In contrast, Multi-Scale Deformable Attention focuses on a sparse set of key points across multiple feature scales, rather than attending to all spatial locations. This selective attention mechanism allows the model to concentrate on the most relevant regions for object detection, dramatically reducing the computational complexity. By sampling a few important points at different scales, the model maintains its ability to detect objects of varying sizes, while skipping unnecessary calculations in less critical regions.
source: https://arxiv.org/pdf/2010.04159
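To make the idea concrete, here is a minimal, single-scale sketch of deformable attention. It illustrates the sampling-based mechanism (offsets, attention weights, bilinear sampling) rather than the multi-scale implementation used in RT-DETR, and all names are illustrative:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDeformableAttention(nn.Module):
    '''
    Minimal single-scale sketch: each query predicts a few sampling offsets
    around its reference point, samples the value map at those locations,
    and combines the samples with learned attention weights.
    '''
    def __init__(self, embed_dim: int = 256, num_heads: int = 8, num_points: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.num_points = num_points
        self.head_dim = embed_dim // num_heads
        self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_points * 2)
        self.attention_weights = nn.Linear(embed_dim, num_heads * num_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, ref_points, value, spatial_shape):
        # query: [B, Q, C], ref_points: [B, Q, 2] in [0, 1], value: [B, H*W, C]
        B, Q, C = query.shape
        H, W = spatial_shape
        v = self.value_proj(value).view(B, H, W, self.num_heads, self.head_dim)
        v = v.permute(0, 3, 4, 1, 2).reshape(B * self.num_heads, self.head_dim, H, W)

        offsets = self.sampling_offsets(query).view(B, Q, self.num_heads, self.num_points, 2)
        weights = self.attention_weights(query).view(B, Q, self.num_heads, self.num_points)
        weights = weights.softmax(-1)

        # sampling locations in normalized [0, 1] coords, mapped to [-1, 1] for grid_sample
        scale = torch.tensor([W, H], dtype=query.dtype, device=query.device)
        locs = ref_points[:, :, None, None, :] + offsets / scale
        grid = (2 * locs - 1).permute(0, 2, 1, 3, 4).reshape(B * self.num_heads, Q, self.num_points, 2)

        sampled = F.grid_sample(v, grid, mode='bilinear', align_corners=False)   # [B*h, head_dim, Q, points]
        weights = weights.permute(0, 2, 1, 3).reshape(B * self.num_heads, 1, Q, self.num_points)
        out = (sampled * weights).sum(-1)                                        # [B*h, head_dim, Q]
        out = out.view(B, self.num_heads, self.head_dim, Q).permute(0, 3, 1, 2).reshape(B, Q, C)
        return self.output_proj(out)

# quick shape check
attn = SimpleDeformableAttention()
q = torch.randn(1, 300, 256)               # 300 object queries
ref = torch.rand(1, 300, 2)                # normalized (x, y) reference points
mem = torch.randn(1, 20 * 20, 256)         # flattened 20x20 encoder feature map
print(attn(q, ref, mem, (20, 20)).shape)   # torch.Size([1, 300, 256])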
Results and Conclusion
source: https://arxiv.org/pdf/2304.08069
RT-DETR excels because it combines transformer-based attention for rich high-level features with CNN-based fusion for handling multi-scale information. This combination not only makes RT-DETR fast but also highly accurate. In benchmarks, RT-DETR achieves 53.1% Average Precision (AP) with ResNet-50 and 54.3% AP with ResNet-101 on COCO, significantly outperforming previous real-time detectors like YOLO.
source: https://arxiv.org/pdf/2304.08069
Additionally, RT-DETR is highly customizable. By adjusting the number of decoder layers at inference time, you can trade accuracy for speed to meet different requirements without retraining. This flexibility makes RT-DETR suitable for various real-time applications, from surveillance to autonomous driving.
source: https://arxiv.org/pdf/2304.08069
In conclusion, RT-DETR is a breakthrough in real-time object detection. It eliminates the need for NMS, reduces latency, and achieves high accuracy by using an innovative hybrid architecture. With its fast and efficient design, RT-DETR sets a new standard for real-time object detection.
References
RepConv: https://arxiv.org/pdf/2101.03697
Deformable Attention: https://arxiv.org/pdf/2201.00520
RT-DETR paper: https://arxiv.org/pdf/2304.08069
RT-DETR Github: https://github.com/lyuwenyu/RT-DETR