model_archs package
Submodules
model_archs.img_models module
Module that lists model architectures for image classification.
- class model_archs.img_models.ResNetFeaturesFlatten(model_key)
Bases: torch.nn.modules.module.Module
- __init__(model_key)
Initializes the model architecture.
- Parameters
model_key (str) – Key identifying the model architecture to build
- Returns
None
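The accepted `model_key` values are not documented here. As a minimal, framework-free sketch of how such a key might be resolved, the block below assumes hypothetical keys named after torchvision's standard ResNet variants; `RESNET_FEATURE_DIMS` and `resolve_model_key` are illustrative names, not part of `model_archs`. The widths are those of the last ResNet stage (512 for the basic-block networks, 2048 for the bottleneck networks), i.e. the length of the feature vector after global average pooling and flattening.

```python
# Hypothetical registry: the real accepted model_key values are not documented,
# so these keys are assumptions. Values are the channel widths of the last stage
# of torchvision's standard ResNets (the flattened feature length after pooling).
RESNET_FEATURE_DIMS = {
    "resnet18": 512,
    "resnet34": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
}


def resolve_model_key(model_key: str) -> int:
    """Look up a (hypothetical) model_key and return its feature width."""
    if not isinstance(model_key, str):
        raise TypeError(f"model_key must be a str, got {type(model_key).__name__}")
    key = model_key.lower()  # accept "ResNet50" as well as "resnet50"
    if key not in RESNET_FEATURE_DIMS:
        valid = ", ".join(sorted(RESNET_FEATURE_DIMS))
        raise KeyError(f"unknown model_key {model_key!r}; expected one of: {valid}")
    return RESNET_FEATURE_DIMS[key]


print(resolve_model_key("resnet50"))  # 2048
```

Failing fast on an unknown key, with the list of valid keys in the message, makes a typo in `model_key` surface at construction time rather than later in `forward`.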
- forward(input)
Computes the forward pass of the network.
- Parameters
input (Tensor) – Image tensor of shape (N x C x H x W), where N is the minibatch size and C, H and W are the number of image channels, the image height and the image width
- Returns
output (Tensor) – Raw prediction scores for each class
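The input layout and the meaning of "raw" scores can be illustrated with a small framework-free sketch. It assumes the raw scores are unnormalized logits, as is usual for classification heads, so a softmax over each row turns them into class probabilities (in PyTorch this would be `torch.softmax(output, dim=1)`). `NCHW`, `parse_nchw`, `softmax` and the example numbers are hypothetical and not part of `model_archs`; plain Python is used so the example runs without PyTorch.

```python
import math
from typing import List, NamedTuple, Sequence


class NCHW(NamedTuple):
    """Named dimensions of an image batch in (N, C, H, W) layout."""
    n: int  # minibatch size
    c: int  # image channels
    h: int  # image height
    w: int  # image width


def parse_nchw(shape: Sequence[int]) -> NCHW:
    """Validate a 4-D (N, C, H, W) shape and name its dimensions."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) shape, got {len(shape)}-D")
    if any(d <= 0 for d in shape):
        raise ValueError(f"all dimensions must be positive, got {tuple(shape)}")
    return NCHW(*shape)


def softmax(scores: Sequence[float]) -> List[float]:
    """Turn one row of raw class scores (logits) into probabilities."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# A batch of 8 RGB images, each 240 pixels high and 320 pixels wide.
dims = parse_nchw((8, 3, 240, 320))
print(dims)  # NCHW(n=8, c=3, h=240, w=320)

# Hypothetical raw scores for one image over three classes.
probs = softmax([2.0, 1.0, 0.1])
print(probs.index(max(probs)))  # 0 (the class with the highest raw score)
```

A non-square example is used on purpose: with H = 240 and W = 320 it is easy to see that height comes before width in the shape.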
- training: bool (inherited from torch.nn.Module; True in training mode, set by train() and eval())