Source code for model_archs.img_models

"""
    Class file that enlists models for image classification
"""
import torch
import torch.nn as nn
from utils.config import MODELS


[docs]class ResNetFeaturesFlatten(nn.Module):
[docs] def __init__(self, model_key): """ Initialize the model architecture Args: model_key (str): Key value corresponding to the model architecture Returns: None """ super(ResNetFeaturesFlatten, self).__init__() self.resnet_model = MODELS[model_key] self.conv1 = self.resnet_model.conv1 self.bn1 = self.resnet_model.bn1 self.relu = self.resnet_model.relu self.maxpool = self.resnet_model.maxpool self.layer1 = self.resnet_model.layer1 self.layer2 = self.resnet_model.layer2 self.layer3 = self.resnet_model.layer3 self.layer4 = self.resnet_model.layer4 self.avgpool = self.resnet_model.avgpool if model_key == "resnet18": self.fc = nn.Linear(512, 15) else: self.fc = nn.Linear(2048, 15)
[docs] def forward(self, input): """ Function to compute forward pass of the network Args: input (Tensor): Image tensor of shape (N X C X H X W), where N denotes minibatch size, C, H, W denotes image channels, width and height Returns: output (Tensor): Raw prediction scores for each class """ x = self.conv1(input) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) output = self.fc(self.relu(x)) return output