Before diving into pruning, let’s take a brief look at YOLOv3. It is a powerful but fairly computationally heavy model for edge devices. To put it simply, architectures like YOLO follow a backbone–neck–head design, where the backbone (e.g., Darknet-53 in YOLOv3) extracts hierarchical image features, the neck (often feature pyramids or path aggregation networks) fuses multi-scale features, and the detection head applies convolutional layers to predict bounding boxes, objectness scores, and class probabilities.
In this article we are going to follow the open-source pruning tooling from the PyTorch ecosystem. There is also the Vitis AI Optimizer (pruning tool), which is available under a free license in the Vitis AI 3.5 GPU docker. For more information on the Vitis AI Optimizer, see https://xilinx.github.io/Vitis-AI/3.5/html/docs/workflow-model-development.html#model-optimization and its accompanying example.
This article can be followed for ML model deployments on AMD-Xilinx Kria, MPSoC FPGA, and Versal Adaptive SoC platforms.
For quantizing/compiling the float model (without pruning) with Vitis AI, one can also refer to this tutorial - https://www.hackster.io/LogicTronix/yolov3-pytorch-quantization-compilation-and-inference-8ce23c
Below is the architecture of YOLOv3:
This article will not cover the mathematical and conceptual intricacies of single-shot object detectors like YOLOv3. For an in-depth explanation of the YOLOv3 architecture, please refer to this: https://www.youtube.com/watch?v=9fhAbvPWzKs
However, there are a few topics related to the model's characteristics that we need to review before we move on to pruning:
- Number of parameters and Compute Complexity
- Compute capacity of devices and its relation to the model's latency
The baseline YOLOv3 network has 61.6 million parameters. The number of parameters roughly gives us an idea about the resource requirements of our model. For a deep convolutional network, the number of parameters can be estimated with the following formula:
n_params = (Kernel_H x Kernel_W x InputChannels) x OutputChannels, summed over all convolution layers
This number is independent of the input size of the model. In most cases, over-parameterized models yield better predictive capability than under-parameterized ones: the higher the number of parameters, the more intricate the patterns the model can capture, but it is also more power- and compute-hungry.
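As a quick sanity check, the per-layer formula above can be applied directly to a PyTorch Conv2d module. The small sketch below is purely illustrative and is not part of the article's training code:

import torch.nn as nn

def conv_param_count(conv: nn.Conv2d) -> int:
    # (K_h * K_w * C_in) * C_out, plus C_out if a bias term is present
    k_h, k_w = conv.kernel_size
    params = k_h * k_w * conv.in_channels * conv.out_channels
    if conv.bias is not None:
        params += conv.out_channels
    return params

layer = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3)
print(conv_param_count(layer))                     # 1,180,160
print(sum(p.numel() for p in layer.parameters()))  # matches the formula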
Compute Complexity: The model's workload is expressed in terms of GFLOPs (Giga, i.e. billion, Floating Point Operations). It measures how many floating-point operations a single forward pass requires. A single multiply-accumulate operation (MAC) is equivalent to two FLOPs (one multiplication and one addition):
GFLOPs = 2 x GMACs
To calculate the MACs of a single convolution layer, we use the following formula (summing it over all convolution layers gives the network total):
MACs = Kernel_W x Kernel_H x InputChannels x OutputChannels x Output_H x Output_W
So in the case of YOLOv3 with a 416x416 input, about 32.8 GMACs (billions of multiply-accumulate operations) are required per forward pass. In terms of GFLOPs, the compute complexity of the model is 65.6 GFLOPs.
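The same back-of-the-envelope check works for MACs. The sketch below applies the per-layer formula to one hypothetical convolution; the 32.8 GMACs figure for YOLOv3 comes from summing this over every layer of the full network:

def conv_macs(k: int, c_in: int, c_out: int, out_h: int, out_w: int) -> int:
    # MACs for one convolution layer = K*K * C_in * C_out * H_out * W_out
    return k * k * c_in * c_out * out_h * out_w

# Example: a 3x3 conv, 256 -> 512 channels, on a 13x13 feature map
macs = conv_macs(k=3, c_in=256, c_out=512, out_h=13, out_w=13)
print(f"{macs/1e9:.3f} GMACs, {2*macs/1e9:.3f} GFLOPs")  # FLOPs = 2 * MACs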
Pruning Literature and Dependency Graphs
Pruning is one of the most widely used inference optimization techniques. It aims to reduce model size and compute complexity by removing less important parameters while preserving accuracy. The idea gained renewed attention with the paper The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks, in which the authors suggested that within a large, randomly initialized neural network there exist much smaller subnetworks (called winning tickets) that, when trained in isolation with their original initialization, can match the performance of the full network. This implies that overparameterized networks are not strictly necessary for accuracy, but help optimization by embedding these trainable sparse subnetworks. The hypothesis sparked major interest in pruning, showing that effective models can be both sparse and performant without needing massive architectures.
In general, pruning can be framed as an optimization problem:
minimize over (θ, M) the loss L(θ ⊙ M), subject to ||M||_0 ≤ k
where θ are the network weights, M is a binary mask that selects which weights or channels to keep, ⊙ denotes element-wise multiplication, and k is the pruning budget (sparsity). ||M||_0 is the L0 "norm" and counts the number of non-zero entries in M.
Types of pruning:
- Unstructured pruning: M acts at the element level, setting individual weights to zero (e.g. magnitude pruning: prune weight w_i if |w_i| < ε). This reduces the parameter count but produces sparse weight matrices that are inefficient on most hardware without dedicated sparse kernels.
- Structured pruning: M removes whole filters, channels, or blocks. For example, in channel pruning the output feature map F_j from filter j is dropped if its importance score falls below a threshold, directly reducing FLOPs.
Usually for pruning, we have to evaluate the importance of the parameters using some criterion and remove the parameters with the smallest importance (a small sketch of the magnitude-based criterion follows the list below). Common choices are:
- Magnitude-based: prune by ||W_j||_p (e.g. the L1 norm of the filter weights)
- Activation-based: prune filters with a low average response, e.g. a small mean absolute activation of the filter's output feature map over a calibration set
- Gradient/Taylor-based: approximate the change in loss if filter j is removed, e.g. with a first-order Taylor expansion of the loss around the current parameters
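As a concrete illustration of the magnitude-based criterion above, the sketch below scores every output filter of a Conv2d layer by the L1 norm of its weights and picks the lowest-scoring filters as pruning candidates. This is purely illustrative and is not the importance code used later in the article (Torch-Pruning handles scoring for us):

import torch
import torch.nn as nn

def l1_filter_importance(conv: nn.Conv2d) -> torch.Tensor:
    # Weight shape: (C_out, C_in, K_h, K_w); one L1 score per output filter
    return conv.weight.detach().abs().sum(dim=(1, 2, 3))

conv = nn.Conv2d(64, 128, kernel_size=3)
scores = l1_filter_importance(conv)
# Indices of the 20% least important filters (candidates for channel pruning)
num_prune = int(0.2 * conv.out_channels)
prune_idxs = torch.argsort(scores)[:num_prune].tolist()
print(prune_idxs)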
For an in-depth explanation of the pruning literature, I suggest Lei Mao's blog posts on pruning: "Pruning for Neural Networks" and "Parameter Importance Approximation Via Taylor Expansion In Neural Network Pruning".
DepGraph:
In this article we use the DepGraph algorithm (via the Torch-Pruning library) for pruning YOLO. It groups and removes coupled parameters together. Modern neural networks are not simple linear stacks of layers; they often include residual connections, concatenations, splits, and multi-branch modules. This creates interdependencies between layers: if you naively prune a channel in one layer without considering these connections, you can break the entire computation graph. This is why a dependency-graph algorithm is required. It addresses the issue by automatically identifying dependencies and collecting groups of parameters for pruning, explicitly tracking how each tensor flows through the network. In practice, we use the torch_pruning framework, which builds this graph automatically by tracing a forward pass. When you request to prune a channel in one layer, the library consults the dependency graph to identify and prune the corresponding channels in all connected layers. The heavy lifting of determining parameter importance and executing the pruning operation is all done by the torch_pruning framework.
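To make this concrete, here is a minimal sketch of Torch-Pruning's low-level dependency-graph API, shown on a torchvision ResNet-18 for simplicity rather than on our YOLOv3. It follows the library's documented usage, but check the Torch-Pruning docs for the exact version you have installed:

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)

# Build the dependency graph by tracing one forward pass
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Ask to remove three output channels of the first conv layer;
# the returned group also contains the coupled BN and downstream conv channels
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])
if DG.check_pruning_group(group):  # e.g. avoid pruning a layer down to zero channels
    group.prune()

print(model.conv1)  # out_channels reduced from 64 to 61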
FLOPs Reduction with Channel Pruning
Now, let's look at how channel pruning affects the FLOPs.
For a standard convolution layer, the number of FLOPs is:
FLOPs = H x W x Cin x Cout x k²
- (H, W): spatial dimensions of the output feature map
- Cin: number of input channels
- Cout: number of output channels (filters)
- k: kernel size (e.g., 3 for a 3×3 conv)
If we prune a ratio r of the output channels, then the new cost becomes:
FLOPs_pruned = H x W x Cin x (1 - r)Cout x k²
If both input and output channels are pruned by ratio r, then:
FLOPs_pruned = H x W x (1 - r)Cin x (1 - r)Cout x k² = (1 - r)² x FLOPs
which leads to a roughly quadratic reduction in FLOPs as pruning propagates through layers.
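A quick numeric check of the (1 - r)² relationship (purely illustrative):

def pruned_flops_fraction(r: float) -> float:
    # Fraction of the original conv FLOPs that remain when both the input
    # and output channels of a layer are pruned by ratio r
    return (1.0 - r) ** 2

for r in (0.1, 0.2, 0.45):
    remaining = pruned_flops_fraction(r)
    print(f"prune ratio {r:.2f} -> {remaining:.2f}x FLOPs ({(1 - remaining) * 100:.0f}% reduction)")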
Pruning Implementation:
Now that we have some idea of the theoretical aspects of model compute complexity, device compute capacity, and pruning and its effect on FLOPs, we can move on to the implementation. Here, we will test the base YOLOv3 model and a pruned YOLOv3 model at 20% sparsity on CPU, GPU, and DPU, and compare the results. Again, most of the heavy lifting for pruning is done by the Torch-Pruning framework, so I highly suggest going through its documentation before moving forward: Torch-Pruning
Before moving on to the pruning code details, let's understand the standard pruning dynamics we follow.
In this article, we perform pruning on a converged model. The usual practice is to prune the network in multiple iterative steps: fine-tune (retrain) for a few epochs after each pruning step, evaluate the model, re-prune, and repeat for a few iterations (a minimal sketch of this loop follows below).
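To show the shape of that loop in isolation, here is a minimal, runnable sketch on a tiny stand-in model using Torch-Pruning's iterative pruning; the model, ratios, and step counts are illustrative only, and the fine-tune/evaluate stages are left as a comment because they are application-specific (the full YOLOv3 versions are defined next):

import torch
import torch.nn as nn
import torch_pruning as tp

# Tiny stand-in model; the real workflow applies the same loop to YOLOv3
model = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(64, 10, 1))
example_inputs = torch.randn(1, 3, 64, 64)

pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=1),
    iterative_steps=3,          # prune in 3 small steps instead of all at once
    pruning_ratio=0.3,          # total channel ratio removed by the final step
    ignored_layers=[model[4]],  # never prune the output layer
)

for step in range(3):
    pruner.step()               # one pruning step
    # ... fine-tune for a few epochs and evaluate here before the next step ...
    macs, params = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"step {step + 1}: {macs/1e6:.1f} MMACs, {params/1e3:.1f}K params")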
We are going to need a few helper functions:
1) Train one epoch: When this function is called, it trains the network for a single epoch.
def train_one_epoch(args, model, loader, ema, criterion, optimizer, device, scaler, epochs):
    model.train()
    multipart_loss_meter = AverageMeter()
    obj_loss_meter = AverageMeter()
    noobj_loss_meter = AverageMeter()
    txty_loss_meter = AverageMeter()
    twth_loss_meter = AverageMeter()
    cls_loss_meter = AverageMeter()
    optimizer.zero_grad()
    losses = defaultdict(float)
    loss_type = ['multipart', 'obj', 'noobj', 'txty', 'twth', 'cls']

    for i, (_, images, targets, _) in enumerate(tqdm(loader, desc="Training", leave=False)):
        ni = i + len(loader) * (epochs - 1)
        # Warm-up: ramp up the gradient-accumulation steps and the learning rate
        if ni <= args.nw:
            args.grad_accumulate = max(1, np.interp(ni, [0, args.nw],
                                                    [1, args.nominal_batch_size /
                                                     args.batch_size]).round())
            set_lr(optimizer, args.base_lr * pow(ni / args.nw, 4))

        images = images.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        # Mixed-precision forward pass
        with amp.autocast(enabled=not args.no_amp):
            outputs = model(images)
            loss = criterion(outputs, targets)

        scaler.scale((loss[0] / args.grad_accumulate) * args.world_size).backward()

        # Optimizer step only after enough gradients have been accumulated
        if ni - args.last_opt_step >= args.grad_accumulate:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if ema is not None:
                ema.update(model)
            args.last_opt_step = ni

        multipart_loss_meter.update(loss[0].item(), images.size(0))
        obj_loss_meter.update(loss[1].item(), images.size(0))
        noobj_loss_meter.update(loss[2].item(), images.size(0))
        txty_loss_meter.update(loss[3].item(), images.size(0))
        twth_loss_meter.update(loss[4].item(), images.size(0))
        cls_loss_meter.update(loss[5].item(), images.size(0))

        del images, outputs
        torch.cuda.empty_cache()

    loss_avg = [multipart_loss_meter.avg,
                obj_loss_meter.avg,
                noobj_loss_meter.avg,
                txty_loss_meter.avg,
                twth_loss_meter.avg,
                cls_loss_meter.avg]

    loss_str = f"[Train-Epoch:{epochs:03d}] "
    for loss_name, loss_value in zip(loss_type, loss_avg):
        losses[loss_name] = loss_value
        loss_str += f"{loss_name}: {losses[loss_name]:.4f} "
        print(f"{loss_name}: {losses[loss_name]:.4f}")
    return loss_str
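The AverageMeter used above is a small running-average helper that is not shown in the article; a minimal version (an assumption on our part, the original helper may differ slightly) could look like this:

class AverageMeter:
    """Tracks a running average of a scalar (e.g. one loss term)."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value: float, n: int = 1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count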
Bear in mind that gradient accumulation and mixed-precision training are used here. Gradient accumulation simulates a larger effective batch size, reducing the gradient noise introduced by small per-step batches, while automatic mixed precision speeds up training and saves memory. For in-depth information, refer to: Gradient Accumulation and Automatic Mixed Precision.
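For readers who want the two ideas in isolation, here is a minimal, generic sketch of gradient accumulation combined with PyTorch AMP. It is not the article's training loop, just the bare pattern that train_one_epoch builds on:

import torch
from torch.cuda import amp

def train_with_accumulation(model, loader, criterion, optimizer, device, accum_steps=4):
    scaler = amp.GradScaler()
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for i, (images, targets) in enumerate(loader):
        images, targets = images.to(device), targets.to(device)
        with amp.autocast():                         # mixed-precision forward pass
            loss = criterion(model(images), targets)
        scaler.scale(loss / accum_steps).backward()  # scale so accumulated grads average out
        if (i + 1) % accum_steps == 0:               # step only every accum_steps mini-batches
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)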
2) The next helper functions we will be using are the mAP calculation function and the evaluation function.
@torch.no_grad()
def calc_mAP(args, model, val_loader, anchors, dpu: bool = True):
    model.eval()
    args.mAP_filepath = Path(val_loader.dataset.dataset.mAP_filepath)
    args.exp_path = Path(args.exp_path)
    os.makedirs(args.exp_path, exist_ok=True)
    evaluator = Evaluator(args.mAP_filepath)
    mAP_dict, eval_text = validate(args, anchors=anchors,
                                   dataloader=val_loader,
                                   model=model, evaluator=evaluator,
                                   save_result=True, dpu=dpu,  # honour the caller's dpu flag
                                   save_filename="Pruned_map.txt")
    return mAP_dict, eval_text


@torch.no_grad()
def evaluate(args, model, loader, criterion, anchors, device, desc="eval"):
    model.eval()
    multipart_loss_meter = AverageMeter()
    obj_loss_meter = AverageMeter()
    noobj_loss_meter = AverageMeter()
    txty_loss_meter = AverageMeter()
    twth_loss_meter = AverageMeter()
    cls_loss_meter = AverageMeter()
    for _, images, targets, _ in tqdm(loader, desc=f"Evaluating-{desc}", leave=False):
        images = images.to(device, non_blocking=True)
        out = model(images)
        # Apply sigmoid to the raw outputs of the three detection scales
        preds0 = do_sigmoid(out[0])
        preds1 = do_sigmoid(out[1])
        preds2 = do_sigmoid(out[2])
        outputs = (preds0, preds1, preds2)
        loss = criterion(outputs, targets)
        multipart_loss_meter.update(loss[0].item(), images.size(0))
        obj_loss_meter.update(loss[1].item(), images.size(0))
        noobj_loss_meter.update(loss[2].item(), images.size(0))
        txty_loss_meter.update(loss[3].item(), images.size(0))
        twth_loss_meter.update(loss[4].item(), images.size(0))
        cls_loss_meter.update(loss[5].item(), images.size(0))
    mAP_dict, eval_text = calc_mAP(args,
                                   model,
                                   val_loader=loader,
                                   anchors=anchors,
                                   dpu=True)
    print(f"[{desc}]\nMultipart_Loss: {multipart_loss_meter.avg:.4f} | "
          f"Object Loss: {obj_loss_meter.avg:.4f} | No Object Loss: {noobj_loss_meter.avg:.4f} | "
          f"txty Loss: {txty_loss_meter.avg:.4f} | twth Loss: {twth_loss_meter.avg:.4f} | "
          f"cls Loss: {cls_loss_meter.avg:.4f}\nmAP: {eval_text}")
    return [multipart_loss_meter.avg,
            obj_loss_meter.avg,
            noobj_loss_meter.avg,
            txty_loss_meter.avg,
            twth_loss_meter.avg,
            cls_loss_meter.avg], mAP_dict
The evaluate function is built on top of the validate function and the Evaluator class from this repo: Object-Detection-Metrics
3) Taylor Prune function.
def taylor_prune(args,
                 model: nn.Module,
                 example_inputs: torch.Tensor,
                 train_loader: DataLoader,
                 test_loader: DataLoader,
                 criterion: nn.Module,
                 device: torch.device,
                 pruning_ratio: float = 0.45,
                 iter_steps: int = 5,
                 round_to: int = 16,
                 ignored_layers: Optional[list] = None,
                 finetune_epochs: int = 0,
                 lr: float = 1e-3) -> nn.Module:
    assert tp is not None, "torch-pruning is not installed. pip install torch-pruning"
    model.to(device)
    model.train()

    # Importance: Taylor expansion (requires gradients)
    imp = tp.importance.TaylorImportance()
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs.to(device),
        importance=imp,
        iterative_steps=iter_steps,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        round_to=round_to
    )

    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs.to(device))
    print(f"[Pruning] Baseline: MACs={base_macs/1e6:.2f}M | Params={base_params/1e6:.2f}M")

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    for i in range(iter_steps):
        # --- collect Taylor grads on a small batch
        _, images, targets, _ = next(iter(train_loader))
        images = images.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss[0].backward()

        # --- prune
        pruner.step()
        base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs.to(device))
        print(f"[Pruning] After step {i + 1}: MACs={base_macs/1e6:.2f}M | Params={base_params/1e6:.2f}M")

        # --- rebuild optimizer after structural change
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
        optimizer.zero_grad(set_to_none=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=finetune_epochs,
            eta_min=0.001 * 0.1)

    # --- optional BN recalibration
    with torch.no_grad():
        model.train()
        for _b, (_, imgs, _, _) in zip(range(100), train_loader):
            imgs = imgs.to(device, non_blocking=True)
            _ = model(imgs)

    # --- short fine-tune: NO warm-up here
    if finetune_epochs > 0:
        # fix accumulation and disable warm-up in your train loop
        steady_acc = max(round(args.nominal_batch_size / args.batch_size), 1)
        args.grad_accumulate = steady_acc
        args.nw = -1
        args.last_opt_step = -1
        scaler = amp.GradScaler(enabled=not args.no_amp)
        ema = ModelEMA(model=model)
        best_map = 0
        for e in range(finetune_epochs):
            tr_str = train_one_epoch(args, model, train_loader, ema, criterion,
                                     optimizer, device, scaler,
                                     epochs=e + 1)  # train loop should respect args.nw == -1
            scheduler.step()
            print()
            print(f" ↳ fine-tune {e+1}/{finetune_epochs} | {tr_str}")
            loss, mAP_dict = evaluate(args=args, model=model, loader=test_loader,
                                      criterion=criterion, anchors=model.anchors,
                                      device=device, desc="eval")
            print(loss)
            if mAP_dict['all']['mAP_50'] > best_map:
                best_map = mAP_dict['all']['mAP_50']
                torch.save(model, './yolov3-pruned-best.pth')
    return model
In this article we follow the Taylor approximation for determining parameter importance. It is a gradient-based importance criterion. For each pruning iteration we reset the gradients before computing the loss, so that no unintended gradient accumulation skews the importance scores.
In the code above:
1) We define a Taylor-based importance (gradient x activation) for pruning.
2) Initialize a MagnitudePruner object from torch-pruning with given pruning ratio, iterative steps, and ignored layers.
3) Acquire the model stats using tp.utils.count_ops_and_params to compute the original MACs and parameters before pruning.
4) Initialize the SGD optimizer with lr = 1e-3, momentum = 0.9, weight_decay = 1e-4.
5) Iterative pruning loop (runs for iter_steps):
- Samples a batch from train_loader
- Runs a forward pass, computes loss, and backpropagates.
- Calls pruner.step() to prune based on Taylor Importance.
- Prints updated MACs and Params.
- Rebuilds the optimizer since the model structure has changed.
- Sets up a cosine annealing LR scheduler.
6) Runs a few batches in forward-only mode to re-estimate BN statistics after pruning.
7) Optional fine-tuning (if finetune_epochs > 0)
- Adjusts gradient accumulation args.
- Uses AMP scaler and EMA for stable training.
- Runs short fine-tuning epochs with train_one_epoch.
- Evaluates model on test_loader each epoch.
- Saves the best pruned model as yolov3-pruned-best.pth if mAP improves.
8) Returns the pruned (and optionally fine-tuned) model.
While pruning object detectors, we must be careful not to prune the detection heads. And to better preserve the original accuracy, we can also exclude the entry and exit stages of the residual blocks. For this, we create another helper function:
def collect_ignored_convs(model, keep_stem=False,
                          keep_stage_entry=False,
                          keep_exit_stage=False):
    ignored = set()
    for name, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        # Always ignore the detection heads
        if name.endswith(".detect"):
            ignored.add(m)
            continue
        # Optionally keep the exit conv of every residual block
        if keep_exit_stage and ".res_block" in name and ".conv2.conv.0" in name:
            ignored.add(m)
            continue
        # Optionally keep the stem convolution
        if keep_stem and name == "backbone.conv1.conv.0":
            ignored.add(m)
            continue
        # Optionally keep the entry conv of each backbone stage
        if keep_stage_entry and (name.endswith(".res_block1.conv.conv.0")
                                 or name.endswith(".res_block2.conv.conv.0")
                                 or name.endswith(".res_block3.conv.conv.0")
                                 or name.endswith(".res_block4.conv.conv.0")
                                 or name.endswith(".res_block5.conv.conv.0")):
            ignored.add(m)
            continue
    return list(ignored)
This returns a list of layers to be ignored while pruning, which is then passed to the MagnitudePruner object. In our case, we only ignored the final detection heads and pruned all the remaining layers (see the usage sketch below).
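Putting the helpers together, a call into the pruning entry point might look roughly like this. The argument values are illustrative only, and model, args, the data loaders, and criterion are assumed to have been created earlier by the script:

# Illustrative wiring of the helpers defined above; values are examples only
example_inputs = torch.randn(1, 3, 416, 416)
ignored_layers = collect_ignored_convs(model)   # defaults: only the detection heads are kept

pruned_model = taylor_prune(args,
                            model=model,
                            example_inputs=example_inputs,
                            train_loader=train_loader,
                            test_loader=val_loader,
                            criterion=criterion,
                            device=device,
                            pruning_ratio=0.2,   # 20% channel sparsity, as evaluated later
                            iter_steps=5,
                            round_to=16,
                            ignored_layers=ignored_layers,
                            finetune_epochs=10,
                            lr=1e-3)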
By running:
python3 yolov3pruning.py --model <Path to .pt file> --data <path to .yaml file> \
    --train-base \
    --prune \
    --prune-steps <no. of pruning steps> \
    --finetune-steps <no. of finetuning steps> \
    --post-train-pruned <path to pruned model>
After pruning the network, we obtain the following result:
This verifies that pruning the network leads to a roughly quadratic reduction in GFLOPs.
Deployment and Inference
On Kria K26 (Kria KV260): For inferencing the model on the Kria KV260 board, we need to quantize and compile the model in the Vitis AI docker environment. In this article, we won't go through setting up the Vitis AI docker environment.
This tutorial on quantizing/compiling a float model with Vitis AI can also be followed - https://www.hackster.io/LogicTronix/yolov3-pytorch-quantization-compilation-and-inference-8ce23c
For quantizing the model to INT8, we will follow this code:
def quantization(title='optimize',
                 model_name='',
                 file_path=''):
    quant_mode = args.quant_mode
    deploy = args.deploy
    batch_size = args.batch_size
    config_file = args.config_file
    finetune = args.fast_finetune
    subset_length = args.subset_len
    data = args.data
    device = torch.device(args.device)

    if quant_mode != 'test' and deploy:
        deploy = False
        print(r'Warning: Exporting xmodel needs to be done in quantization test')
    if deploy and (batch_size != 1):
        print(r'Warning: Exporting xmodel needs batch size to be 1 and only 1 iteration of inference, change them automatically!')
        batch_size = 1

    # Load the model
    anchors = [
        [0.248,      0.7237237],
        [0.36144578, 0.53],
        [0.42,       0.9306667],
        [0.456,      0.6858006],
        [0.488,      0.8168168],
        [0.6636637,  0.274],
        [0.806,      0.648],
        [0.8605263,  0.8736842],
        [0.944,      0.5733333]
    ]
    model = load_model(mode="dpu",
                       device=args.device,
                       input_size=416,
                       num_classes=20,
                       model_type="base",
                       anchors=anchors,
                       model_path=args.model_path)

    input_sig = torch.randn([batch_size, 3, 416, 416]).to(device)

    if quant_mode == 'float':
        quant_model = model
    else:
        quantizer = torch_quantizer(quant_mode,
                                    model,
                                    (input_sig,),
                                    device=device,
                                    quant_config_file=config_file)
        quant_model = quantizer.quant_model.eval()

    if quant_mode == 'calib':
        calib_loader = get_dataloader(voc_path=data,
                                      batch_size=batch_size,
                                      subset_length=subset_length,
                                      train=False)
        quant_model.eval()
        with torch.no_grad():
            for _, imgs, _, _ in tqdm(calib_loader, desc='Calibrating'):
                imgs = imgs.to(device, non_blocking=True)
                _ = quant_model(imgs)
        criterion = YOLOv3Loss(input_size=416,
                               num_classes=20,
                               anchors=model.anchors)
        if finetune == True:
            ft_loader = get_dataloader(voc_path=data,
                                       batch_size=batch_size,
                                       subset_length=subset_length,
                                       train=False)
            quantizer.fast_finetune(evaluate, (quant_model, ft_loader, criterion))
        quantizer.export_quant_config()
    elif quant_mode == 'test':
        if finetune == True:
            quantizer.load_ft_param()  # only exists if calib + fast_finetune ran earlier
            print("Loaded fast-finetune params.")
        with torch.no_grad():
            _ = quant_model(input_sig)
        if deploy:
            quantizer.export_torch_script()
            quantizer.export_onnx_model()
            quantizer.export_xmodel(deploy_check=True, dynamic_batch=True)
From this we get the .xmodel, which will be used to generate the compiled model:
vai_c_xir --xmodel /path/to/.xmodel \
          --arch /opt/vitis_ai/compiler/arch/DPUCZDX8G/KV260/arch.json \
          --net_name <name of the output model> \
          --output_dir ./Compiled
By running this command inside the Vitis AI docker container, we get a compiled output .xmodel which will be used for creating the XIR graph and executing the model on the DPU.
For inference code, please refer to the repo: [Link to repo(private right now)]
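In the meantime, here is a minimal, hedged sketch of how a compiled .xmodel is typically loaded and dispatched to the DPU with the XIR/VART Python APIs on the target board. The file name is a placeholder, the buffer dtype/scaling depends on the quantized model's fix-point attributes, and the YOLOv3-specific pre-processing, anchor decoding, and NMS are omitted:

import numpy as np
import xir
import vart

# Deserialize the compiled model and pick the DPU subgraph
graph = xir.Graph.deserialize("yolov3_pruned.xmodel")  # placeholder file name
subgraphs = graph.get_root_subgraph().toposort_child_subgraph()
dpu_subgraph = [s for s in subgraphs
                if s.has_attr("device") and s.get_attr("device").upper() == "DPU"][0]

runner = vart.Runner.create_runner(dpu_subgraph, "run")
input_tensors = runner.get_input_tensors()
output_tensors = runner.get_output_tensors()

# Allocate host buffers matching the DPU tensor shapes (batch of 1);
# int8 in/out is common for DPU models, check your model's fix_point scaling
input_data = [np.zeros(tuple(t.dims), dtype=np.int8) for t in input_tensors]
output_data = [np.zeros(tuple(t.dims), dtype=np.int8) for t in output_tensors]

# ... fill input_data[0] with the pre-processed, quantized 416x416 image ...
job_id = runner.execute_async(input_data, output_data)
runner.wait(job_id)
# output_data now holds the three YOLO heads for post-processing (decode + NMS)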
Results:
The following results were obtained on the Kria K26 (KV260) board and on an RTX 2070 GPU.
Inference Time Comparison for DPU:
Pruning YOLOv3 with Torch-Pruning showed that significant reductions in MACs and parameters translate directly into lower inference latency and better energy efficiency, especially on edge hardware like the Kria K26. By strategically removing redundant channels while preserving accuracy, we achieved a ~42% reduction in MACs and a ~39% reduction in parameters, resulting in faster model execution with a lower power cost per frame. These results highlight pruning as a practical approach to bridging the gap between high-performance vision models and the constraints of real-world edge deployments.
Thanks for going through this tutorial!
Kudos to Saurav Raj Paudel for creating this in-depth pruning tutorial. Thanks Dikesh Shakya Banda for the support.
LogicTronix is AMD-Xilinx Partner for FPGA Design and ML Acceleration.