05 AI application

Predict Lung Disease

written by hyojung chang

The results we want from the AI application are as follows. To achieve this, we adopted Faster R-CNN among the object detection methods and performed training with Torchvision.



0. About Torchvision

Torchvision is a PyTorch package consisting of popular datasets, model architectures, and common image transformations for computer vision. We trained our model with the Faster R-CNN implementation (torchvision.models.detection.faster_rcnn) provided by Torchvision.


1. Define a custom Dataset in the structure required by Torchvision

Prepared data 

  • images
  • annotation.csv

We prepared the images and an annotation.csv file whose rows consist of [filename, coordinates of RoI, classname (= label)], and converted them into a custom dataset for training. The Dataset structure required by Torchvision is as follows.


Dataset structure required by Torchvision:

  • image: a PIL Image of size (H, W)
  • target: a dict containing the following fields

         - boxes (FloatTensor[N, 4]): the coordinates of the N bounding boxes in [x0, y0, x1, y1] format, ranging from 0 to W and 0 to H

         - labels (Int64Tensor[N]): the label for each bounding box. 0 always represents the background class.

         - image_id (Int64Tensor[1]): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation

         - area (Tensor[N]): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.

         - iscrowd (UInt8Tensor[N]): instances with iscrowd=True will be ignored during evaluation.

         - (optionally) masks (UInt8Tensor[N, H, W]): The segmentation masks for each one of the objects

         - (optionally) keypoints (FloatTensor[N, K, 3]): For each one of the N objects, it contains the K keypoints in [x, y, visibility] format, defining the object. visibility=0 means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt references/detection/transforms.py for your new keypoint representation


We parsed each image's annotations from the annotation.csv file and defined our custom Dataset class. Each element of the dataset consists of one original image and the information about the RoIs that we want to train on.
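A custom Dataset following the structure above could be sketched as below. The class name, CSV-derived record layout, and file paths are assumptions for illustration, not the authors' actual code:

```python
import torch
from PIL import Image
from torch.utils.data import Dataset


class LungDataset(Dataset):
    """Yields (image, target) pairs in the structure Torchvision expects."""

    def __init__(self, records, image_dir, transforms=None):
        # `records`: list of (filename, boxes, labels) tuples parsed from
        # annotation.csv; grouping boxes by filename is assumed done upstream.
        self.records = records
        self.image_dir = image_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        filename, boxes, labels = self.records[idx]
        image = Image.open(f"{self.image_dir}/{filename}").convert("RGB")

        boxes = torch.as_tensor(boxes, dtype=torch.float32)   # [N, 4]
        labels = torch.as_tensor(labels, dtype=torch.int64)   # [N]
        # Box area, used by the COCO evaluation metric.
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.uint8),
        }
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target
```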



2. Train and evaluate the model

We adjusted the detailed training parameters based on our comparison results.

Detailed parameters of training

  • epoch = 40
  • batch size = 16
  • the number of workers = 4
  • learning rate = 0.05
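A training loop using these parameters could look like the sketch below. The `train` function, SGD momentum value, and the assumption that `model`, `train_dataset`, and `device` are defined elsewhere are all illustrative, not the authors' exact code:

```python
import torch

# Hyperparameters from the experiment above.
NUM_EPOCHS = 40
BATCH_SIZE = 16
NUM_WORKERS = 4
LEARNING_RATE = 0.05


def collate_fn(batch):
    # Detection targets vary in size per image, so the batch is kept
    # as a tuple of image lists and target lists instead of stacked tensors.
    return tuple(zip(*batch))


def train(model, train_dataset, device):
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, collate_fn=collate_fn)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=0.9)

    model.train()
    for epoch in range(NUM_EPOCHS):
        for images, targets in loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]
            # In train mode, Torchvision detection models return a dict
            # of loss components; their sum is the total loss.
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```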

Result of evaluation

  • AP = 47.32%
  • Total training time = 15h 38m


(Note) We split the data into training : validation : testing sets at a ratio of 6 : 2 : 2. As a result, there are 10,989 images for training, 3,638 images for validation, and 3,679 images for testing.
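A 6 : 2 : 2 split of this kind could be reproduced with `torch.utils.data.random_split`; the helper name and fixed seed below are assumptions for illustration:

```python
import torch
from torch.utils.data import random_split


def split_622(dataset, seed=42):
    """Split a dataset into 60% train / 20% validation / 20% test subsets."""
    n = len(dataset)
    n_train = int(n * 0.6)
    n_val = int(n * 0.2)
    n_test = n - n_train - n_val  # remainder absorbs rounding
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val, n_test],
                        generator=generator)
```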


3. Predict lung disease