flowframes/Pkgs/flavr-cuda/dataset/Middleburry.py

47 lines
1.5 KiB
Python

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class Middelburry(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.file_list = os.listdir(self.data_root)
self.transforms = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, idx):
imgpath = os.path.join(self.data_root , self.file_list[idx])
name = self.file_list[idx]
if name == "Teddy": ## Handle inputs with just two inout frames. FLAVR takes atleast 4.
imgpaths = [os.path.join(imgpath , "frame10.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame11.png") ]
else:
imgpaths = [os.path.join(imgpath , "frame09.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame12.png") ]
images = [Image.open(img).convert('RGB') for img in imgpaths]
images = [self.transforms(img) for img in images]
sizes = images[0].shape
return images , name
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = Middelburry(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)