pytorch有关 Dataset和 DataLoader的心得

2022-07-28,,,

先来看看官方文档的说法:https://pytorch.org/docs/stable/data.html

DataLoader支持两种数据集:map-style datasets 和 iterable-style datasets.

一般我们用的最多的是map-style datasets,因此这里只讲map类型的,还有我到目前为止也没用过iterable类型的。(无知导致无能,很抱歉,这部分我不知道~)

我们要使用map-style datasets,要实现两种方法__getitem__()和__len__(),这里我拿出我最近写的一个demo

class myDataset(Dataset):
    def __init__(self, data, label):
        self.data_list = data
        self.label_list = label
    def __getitem__(self, index):
        data_idx = []
        data_idx.append(word2idx.transform(self.data_list[index].split(), max_len=max_len))
        text = torch.LongTensor(data_idx)
        label = torch.LongTensor([self.label_list[index]])
        return text, label
    def __len__(self):
        return len(self.data_list)
dataset = myDataset(data_list, label_list)
data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

注意点:

1、__getitem__(self,index)里面每次的返回值,是一对数据,即文本和标签,我们通过参数index来确定返回哪个数据。

(我之前是返回的一批数据直接导致内存爆了,真的是无知者无畏啊,给大家看看爆的多少,一百多G的GPU)

 

 

本文地址:https://blog.csdn.net/qq_40819945/article/details/109622591

《pytorch有关 Dataset和 DataLoader的心得.doc》

下载本文的Word格式文档,以方便收藏与打印。