This was the golden step.
Distillation is a bleeding-edge technique for training smaller models that make faster predictions: it distills the knowledge of a large network into a compact one. This is perfect for our use case.
In distillation, you train a student model from a teacher model. Instead of training the student directly on your data's labels, you train it on the predictions of the teacher. As a result, you replicate the teacher's results with a smaller network.
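The teacher-student objective described above is commonly implemented as a weighted sum of the usual hard-label loss and a soft-label loss on temperature-scaled logits. Here is a minimal sketch in plain PyTorch; the temperature and weighting values are illustrative assumptions, not numbers from this post:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    """Blend cross-entropy on hard labels with a KL term that pulls the
    student's softened distribution toward the teacher's."""
    # Soft targets: both distributions are smoothed with temperature T.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradients match the hard-label term
    # Hard targets: standard cross-entropy against the true labels.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

A higher temperature exposes more of the teacher's "dark knowledge" (the relative probabilities it assigns to wrong classes), which is the signal the student learns from.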
Distillation can be challenging and resource-intensive to implement. Luckily, KD_Lib for PyTorch provides implementations of distillation research papers, accessible as a library. The code snippet below performs vanilla distillation.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# Define models (the networks and the train/val data loaders
# are assumed to be defined elsewhere)
teacher_model = resnet
student_model = inception

# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# Perform distillation
distiller = VanillaKD(teacher_model, student_model,
                      train_dataset_loader, val_dataset_loader,
                      teacher_optimizer, student_optimizer, device='cuda')
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)
distiller.train_student(epochs=5, plot_losses=True, save_model=True)
Using vanilla distillation with our DenseNet model as the teacher, our baseline CNN achieved 99.85% accuracy. The CNN outperformed every state-of-the-art model we tested in speed, running 775 inferences/second on a V100, and did so with fewer than 15% as many parameters.
Even more impressively, running distillation again continued to improve the accuracy. For example, results improved further with a combination of NoisyTeacher, SelfTraining, and MessyCollab.
Side note: self-training uses your model as both the teacher and the student. How cool!