Creating a Powerful COVID-19 Mask Detection Tool with PyTorch

Distillation

This was the golden step.


Distillation is a bleeding-edge technique for training smaller models that make faster predictions: it distills the knowledge of a large network into a compact one. This is perfect for our use case.

In distillation, you train a student model from a teacher model. Instead of training the student directly on the labels in your data, you train it to match the predictions of the teacher. As a result, a smaller network replicates the teacher's results.
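
To make that concrete, here is a minimal sketch of the classic distillation loss from Hinton et al.: the student minimizes a blend of ordinary cross-entropy on the hard labels and a KL divergence against the teacher's temperature-softened outputs. This is illustrative rather than this project's code, and the temperature T and mixing weight alpha are example values.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # Soften both distributions with temperature T; kl_div expects
    # log-probabilities for the input and probabilities for the target.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # rescale gradients, as in Hinton et al.
    # Standard cross-entropy against the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss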

Distillation can be challenging and resource-intensive to implement. Luckily, KD_Lib for PyTorch packages implementations of distillation research papers into an accessible library. The code snippet below was used for vanilla distillation.

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# Define models (placeholders for networks defined elsewhere;
# here the teacher is a ResNet and the student an Inception net)
teacher_model = resnet
student_model = inception

# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=0.01)
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)

# Perform distillation
distiller = VanillaKD(teacher_model, student_model,
                      train_dataset_loader, val_dataset_loader,
                      teacher_optimizer, student_optimizer, device='cuda')
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)  # train the teacher first
distiller.train_student(epochs=5, plot_losses=True, save_model=True)  # distill into the student
distiller.evaluate(teacher=False)  # evaluate the student on the validation set
distiller.get_parameters()         # print teacher and student parameter counts
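
VanillaKD also exposes a softmax temperature and a distillation weight (temp and distil_weight in KD_Lib) that balance the teacher's soft targets against the hard labels; the snippet above uses the defaults, but both are worth tuning.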

Using vanilla distillation with our DenseNet model as the teacher, our baseline CNN student reached 99.85% accuracy. The CNN also outperformed every state-of-the-art model's speed, at 775 inferences/second on a V100, with fewer than 15% as many parameters.
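
For context, throughput figures like this are typically measured by timing batched forward passes with the GPU synchronized. The snippet below is our assumption about how such a benchmark might look, not the project's actual harness; the throughput helper and the batch shape are hypothetical.

import time
import torch

@torch.no_grad()
def throughput(model, batch, n_iters=100):
    model.eval()
    # Warm up so kernel launches and caching don't skew the timing.
    for _ in range(10):
        model(batch)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        model(batch)
    torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    elapsed = time.time() - start
    return n_iters * batch.size(0) / elapsed  # images per second

# Example: throughput(student_model.cuda(), torch.randn(32, 3, 224, 224, device='cuda'))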

Even more impressively, running distillation again continued to improve the accuracy. For example, results improved with a combination of NoisyTeacher, SelfTraining, and MessyCollab (all implemented in KD_Lib; see the sketch below).

Side note: self-training uses your model as both the teacher and the student. How cool!
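
As a rough illustration of a second round, here is a sketch using KD_Lib's NoisyTeacher, which perturbs the teacher's outputs during distillation. We are assuming it follows the same constructor pattern as VanillaKD, with extra noise-specific knobs; check the KD_Lib documentation before relying on this exact signature.

from KD_Lib.KD import NoisyTeacher

# Assumed to mirror VanillaKD's arguments, plus parameters
# controlling the injected noise (alpha, noise_variance).
noisy_distiller = NoisyTeacher(teacher_model, student_model,
                               train_dataset_loader, val_dataset_loader,
                               teacher_optimizer, student_optimizer,
                               alpha=0.5, noise_variance=0.1, device='cuda')
noisy_distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)
noisy_distiller.train_student(epochs=5, plot_losses=True, save_model=True)
noisy_distiller.evaluate(teacher=False)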
