mirror of https://github.com/n00mkrad/flowframes
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
import os
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
import random
|
|
import glob
|
|
|
|
|
|
class Middelburry(Dataset):
|
|
def __init__(self, data_root , ext="png"):
|
|
|
|
super().__init__()
|
|
|
|
self.data_root = data_root
|
|
self.file_list = os.listdir(self.data_root)
|
|
|
|
self.transforms = transforms.Compose([
|
|
transforms.ToTensor()
|
|
])
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
imgpath = os.path.join(self.data_root , self.file_list[idx])
|
|
name = self.file_list[idx]
|
|
if name == "Teddy": ## Handle inputs with just two inout frames. FLAVR takes atleast 4.
|
|
imgpaths = [os.path.join(imgpath , "frame10.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame11.png") ]
|
|
else:
|
|
imgpaths = [os.path.join(imgpath , "frame09.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame12.png") ]
|
|
|
|
images = [Image.open(img).convert('RGB') for img in imgpaths]
|
|
images = [self.transforms(img) for img in images]
|
|
|
|
sizes = images[0].shape
|
|
|
|
return images , name
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.file_list)
|
|
|
|
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
|
|
|
|
dataset = Middelburry(data_root)
|
|
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
|