"""
Class definition for VGG Approximate NN
"""
import logging
from typing import Optional
import torch
import torchvision
from .approxnet import ApproxNet
logger = logging.getLogger(__name__)
# pylint: disable=too-many-ancestors
[docs]class VGG(ApproxNet):
"""
Definition of training hyperparameters for
approximate VGG
"""
def __init__(
self,
vgg_size: Optional[str] = "VGG11",
num_classes: int = 200,
pretrained: bool = True,
):
super().__init__()
self.name = vgg_size
if self.name == "VGG11":
self.model = torchvision.models.vgg11_bn(pretrained=pretrained)
if self.name == "VGG13":
self.model = torchvision.models.vgg13_bn(pretrained=pretrained)
if self.name == "VGG16":
self.model = torchvision.models.vgg16_bn(pretrained=pretrained)
if self.name == "VGG19":
self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
# Replace last layer with randomly initialized layer of correct size
if num_classes != 1000:
self.model.classifier[6] = torch.nn.Linear(4096, num_classes)
self.topk = (1, 5)
self.epochs: dict = {
"baseline": 30,
"qat": 8,
"gradient_search": 3,
"approx": 2,
}
self.num_gpus = 1
self.gather_noisy_modules()
[docs] def _baseline_optimizers(self):
optimizer = torch.optim.SGD(
self.parameters(), lr=5e-2, momentum=0.9, weight_decay=5e-4
)
scheduler = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, factor=0.2, patience=3, mode="max"
),
"monitor": "val_acc_top5",
"interval": "epoch",
"name": "lr",
}
return [optimizer], [scheduler]
[docs] def _qat_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 4)
return [optimizer], [scheduler]
[docs] def _approx_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1)
return [optimizer], [scheduler]
[docs] def _gs_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=5e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2)
return [optimizer], [scheduler]