"""
Class definition for the approximate ResNet network.
"""
import logging
import torch.optim as optim
from .approxnet import ApproxNet
from .base import resnet
logger = logging.getLogger(__name__)
# pylint: disable=too-many-ancestors
class ResNet(ApproxNet):
"""
Definition of training hyperparameters for
approximate ResNet
"""
    def __init__(self, resnet_size: str = "ResNet8", **kwargs):
        super().__init__(**kwargs)
        self.name = resnet_size
        # Instantiate the backbone that matches the requested depth
        if self.name == "ResNet8":
            self.model = resnet.resnet8()
        elif self.name == "ResNet14":
            self.model = resnet.resnet14()
        elif self.name == "ResNet20":
            self.model = resnet.resnet20()
        elif self.name == "ResNet32":
            self.model = resnet.resnet32()
        else:
            raise ValueError(f"Unsupported resnet_size: {self.name!r}")
        # Top-k values for accuracy reporting (top-1 only)
        self.topk: tuple = (1,)
        # Number of training epochs per training phase
        self.epochs = {
            "baseline": 180,
            "gradient_search": 30,
            "qat": 30,
            "approx": 10,
        }
        self.num_gpus: int = 1
        self.gather_noisy_modules()
    def _baseline_optimizers(self):
        """SGD with LR drops at epochs 90 and 135 for baseline training."""
        optimizer = optim.SGD(
            self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[90, 135], gamma=0.1
        )
        return [optimizer], [scheduler]
    def _qat_optimizers(self):
        """SGD with step-wise LR decay every 10 epochs for quantization-aware training."""
        optimizer = optim.SGD(self.parameters(), lr=1e-2, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return [optimizer], [scheduler]
    def _approx_optimizers(self):
        """SGD with step-wise LR decay every 4 epochs for approximate retraining."""
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4)
        return [optimizer], [scheduler]
    def _gs_optimizers(self):
        """SGD with step-wise LR decay every 10 epochs for gradient search."""
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return [optimizer], [scheduler]