pytorch 基本操作



Original Source Here

二、Dataset

pytorch是一個深度網路的訓練框架,所以或多或少一定會有資料集合,而且多少會需要對資料做一些操作,比方說設定batch,這些操作在pytorch提供的模組DataLoader下,有很自動化方便的操作,但他只接受pytorch的Dataset,所以要先建立好Dataset,再將他送入DataLoader。

自定義的Dataset,必定要有三個方法在內,基本格式很固定,如下:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):#繼承Dataset
def __init__(self, data, label):
# 假設data是一個numpy array的格式,也可以是DataFrame
# 要對資料有所理解才知道這邊要放啥
self.data = data # feature
self.label = label # label
def __len__(self):
return len(self.data)

def __getitem__(self, index):
x_data = torch.FloatTensor(self.data[index])
y_label = torch.FloatTensor(self.label[index])
return (x_data, y_label)

以上__init__(), __len()__, __getitem__()是必要的,名稱不能改。

建立好類別後就分別建立訓練和測試的dataset

train_dataset = CustomDataset(train_data,train_label)
test_dataset = CustomDataset(test_data, test_label)

remark: CustomDataset這個名稱可以按照識別度自己取

AI/ML

Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot

%d bloggers like this: