Source code for utils.common_utils

"""
    Utility file consisting of common functions and variables used during training and evaluation
"""

import json
import torch
import torchvision.transforms as transforms
import utils.enums as enums

# Baseline pipeline (no augmentation): convert the image to a tensor, then
# normalise each of the three channels as (x - 0.5) / 0.5.
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training-only augmentation: random colour jitter (hue=0.2, saturation=0.2).
# The pipeline begins with ToPILImage, so its input is converted to a PIL image
# first — NOTE(review): presumably callers pass a tensor/ndarray; confirm.
image_transform_jitter = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(hue=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Training-only augmentation: random horizontal flip (p=0.5) followed by a
# random rotation, RandomRotation(10), i.e. an angle drawn from [-10, 10] degrees.
image_transform_flip = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Training-only augmentation combining both of the above, applied in order:
# colour jitter, horizontal flip (p=0.5), then rotation within [-10, 10] degrees.
image_transform_jitter_flip = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(hue=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


# Maps each clothing category to an integer class index for numerical use.
# A category's index is its position in the list below, so reordering the
# list changes the class indices; add new categories at the end.
CLOTH_CATEGORIES = {
    category: class_index
    for class_index, category in enumerate([
        enums.SAREE,
        enums.WOMEN_KURTA,
        enums.LEHENGA,
        enums.BLOUSE,
        enums.GOWNS,
        enums.DUPATTAS,
        enums.LEGGINGS_AND_SALWARS,
        enums.PALAZZOS,
        enums.PETTICOATS,
        enums.MOJARIS_WOMEN,
        enums.DHOTI_PANTS,
        enums.KURTA_MEN,
        enums.NEHRU_JACKETS,
        enums.SHERWANIS,
        enums.MOJARIS_MEN,
    ])
}


def read_json_data(file_name):
    """
    Utility function to read data from a JSON Lines file (one JSON document per line)

    Args:
        file_name (str): Path to the JSON Lines file to be read

    Returns:
        article_list (List<dict>): List with one parsed entry per non-blank line;
            each line is expected to hold a dict with the metadata of one item

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON
    """
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(file_name, encoding="utf-8") as f:
        # Skip blank / whitespace-only lines (e.g. a stray empty line between
        # records), which json.loads would otherwise reject with an error.
        article_list = [json.loads(line) for line in f if line.strip()]
    return article_list


def get_accuracy(y_pred, y_actual):
    """
    Utility function to count the correct predictions in the minibatch

    Args:
        y_pred (Tensor): Predicted class scores; the predicted label of each
            sample is the index of the maximum along dim 1
        y_actual (Tensor): Ground Truth class labels

    Returns:
        correct (int): Number of samples whose predicted label equals the ground
            truth label. This is a count, not a ratio; divide by the number of
            samples to obtain accuracy.
    """
    _, predicted = torch.max(y_pred, 1)
    # Labels shaped e.g. (N, 1) compared against predictions of shape (N,) would
    # broadcast to an (N, N) comparison and silently over-count. When a label
    # tensor has the same number of elements, align it with the predictions.
    if (
        isinstance(y_actual, torch.Tensor)
        and y_actual.shape != predicted.shape
        and y_actual.numel() == predicted.numel()
    ):
        y_actual = y_actual.reshape(predicted.shape)
    correct = (predicted == y_actual).sum().item()
    return correct