전공관련/Deep Learning

[Pytorch] Custom Dataloader를 사용하자

매직블럭 2019. 12. 23. 10:21

torch를 사용하다보면 단순 classification이면 기본 폴더로부터 읽어오는 dataloader를 써도 되지만 

새로운 작업을 하거나 classification 외의 다른 학습을 위해서는 custom dataloader가 필요한 경우가 있다.

 

필요한 기능에 따라 custum dataloader의 기본형을 정리 해 둘 예정!

 


1. 각 class 폴더 안에 이미지 데이터가 있는 경우

    (단, 이 경우 실제 class 이름(폴더명)을 가져오려면 폴더명 list에서 id번째 폴더명을 받아와야함)

class customDL(Dataset):
    def read_data_set(self):
        all_img_files = []
        all_labels = []

        class_names = os.walk(self.data_set_path).__next__()[1]

        for index, class_name in enumerate(class_names):
            label = index
            img_dir = os.path.join(self.data_set_path, class_name)
            img_files = os.walk(img_dir).__next__()[2]

            for img_file in img_files:
                img_file = os.path.join(img_dir, img_file)
                img = cv2.imread(img_file, flags=cv2.IMREAD_GRAYSCALE)
                if img is not None:
                    all_img_files.append(img_file)
                    all_labels.append(label)

        return all_img_files, all_labels, len(all_img_files), len(class_names)

    def __init__(self, folder_path, transform=None):
        self.data_set_path = folder_path
        self.image_files_path, self.labels, self.length, self.num_classes = self.read_data_set()
        self.transform = transform

    def __getitem__(self, idx):
        img_origin = cv2.imread(self.image_files_path[idx], flags=cv2.IMREAD_GRAYSCALE)

       # 필요한 연산 수행.

        id = self.labels[idx]

        return img_origin, id

    def __len__(self):
        return self.length

 

 

 

[추가 예정..]