Dataset和DataLoader用法

Dataset和DataLoader用法

在d2l中有简洁的加载固定数据的方式,如下

d2l.load_data_fashion_mnist()
# 源码
Signature: d2l.load_data_fashion_mnist(batch_size, resize=None)
Source:   
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.

    Defined in :numref:`sec_fashion_mnist`"""
    # Always convert PIL images to tensors; optionally resize first.
    trans = [transforms.ToTensor()]
    if resize:
        # Resize must run before ToTensor, so it is inserted at the front.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    # Training loader shuffles every epoch; the test loader keeps a fixed order.
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
File:      ~/anaconda3/envs/d2l/lib/python3.9/site-packages/d2l/torch.py
Type:      function

如果我们要自定义需要加载的数据集

数据集:一个图片文件夹,用csv文件来表示训练数据和标签

# 定义Dataset
import os

import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Image-classification dataset backed by a CSV index file.

    The CSV's first column holds image paths relative to ``root_dir``;
    its ``label`` column holds string class names, which are encoded as
    integer ids with classes sorted alphabetically (the same codes
    ``sklearn.preprocessing.LabelEncoder`` would produce).

    Args:
        csv_file: Path to the CSV index file.
        root_dir: Directory the image paths in the CSV are relative to.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # pd.Categorical sorts the unique labels, so these codes are
        # identical to LabelEncoder.fit_transform on the same column.
        categories = pd.Categorical(self.data['label'])
        self.labels = categories.codes
        # Hoisted out of __getitem__: the class count never changes,
        # so there is no need to call nunique() on every item fetch.
        self.num_classes = len(categories.categories)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])
        # Load the image and apply augmentation / preprocessing.
        image = Image.open(img_name)
        if self.transform is not None:
            image = self.transform(image)
        # One-hot encode the integer label; float so it can be used
        # directly as a loss target.
        label = F.one_hot(torch.tensor(int(self.labels[idx])),
                          num_classes=self.num_classes).float()
        return image, label

# Hyperparameters for training.
batch_size = 256
# Tuple unpacking: lr and num_epoch receive separate scalar values.
# (The original chained assignment `lr = num_epoch = 0.9, 10` bound
# BOTH names to the whole tuple (0.9, 10).)
lr, num_epoch = 0.9, 10

# Kaggle "classify-leaves" competition file locations.
sample = '/kaggle/input/classify-leaves/sample_submission.csv'
ts_path = "/kaggle/input/classify-leaves/test.csv"
tr_path = "/kaggle/input/classify-leaves/train.csv"
image_path = '/kaggle/input/classify-leaves'

# NOTE(review): the dataset is built from sample_submission.csv —
# presumably train.csv (tr_path) was intended; confirm before training.
dataset = CustomDataset(csv_file=sample, root_dir=image_path,
                        transform=transform_train)

# Random 80/20 split into training and validation subsets.
n_total = len(dataset)
train_size = int(0.8 * n_total)
valid_size = n_total - train_size
tr_dataset, te_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size])

# Shuffle only the training loader; keep validation order deterministic.
tr_dataloader = DataLoader(tr_dataset, batch_size, shuffle=True)
ts_dataloader = DataLoader(te_dataset, batch_size, shuffle=False)

总结

需要将__init__、__len__、__getitem__按照数据集和模型的要求,对应地编写好代码。