Source code for agnapprox.nets.lenet5

"""
Class definition for LeNet5 Approximate NN
"""
import logging

import torch.optim as optim

from agnapprox.nets.approxnet import ApproxNet
from agnapprox.nets.base import lenet5

logger = logging.getLogger(__name__)

# pylint: disable=too-many-ancestors
class LeNet5(ApproxNet):
    """
    Definition of training hyperparameters for approximate LeNet5
    """

    def __init__(self):
        super().__init__()
        self.model = lenet5.LeNet5(10)
        self.name = "LeNet5"
        self.topk = (1,)
        # Number of epochs to run in each training phase
        self.epochs = {
            "baseline": 10,
            "gradient_search": 3,
            "qat": 1,
            "approx": 5,
        }
        self.num_gpus = 1
        self.gather_noisy_modules()
    def _baseline_optimizers(self):
        """Optimizer and LR schedule for baseline training"""
        optimizer = optim.SGD(
            self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
        )
        scheduler = optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.75)
        return [optimizer], [scheduler]
    def _qat_optimizers(self):
        """Optimizer and LR schedule for quantization-aware training"""
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 5)
        return [optimizer], [scheduler]
    def _approx_optimizers(self):
        """Optimizer and LR schedule for approximate retraining"""
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 2)
        return [optimizer], [scheduler]
    def _gs_optimizers(self):
        """Optimizer and LR schedule for gradient search"""
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 4)
        return [optimizer], [scheduler]
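
A minimal usage sketch (not part of the module above). It relies only on the constructor defined here; how the individual training phases (baseline, gradient search, QAT, approximate retraining) are launched is determined by the ApproxNet base class, which is not shown on this page.

    # Sketch: instantiate the network and inspect the configured hyperparameters.
    from agnapprox.nets.lenet5 import LeNet5

    net = LeNet5()
    print(net.name)                # "LeNet5"
    print(net.epochs["baseline"])  # 10 epochs configured for baseline training
    print(net.topk)                # (1,)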