Update dataset.py
This commit is contained in:
parent
5f5ffb5e85
commit
981ae76ca2
42
dataset.py
42
dataset.py
|
@ -35,9 +35,8 @@ class VimeoDataset(Dataset):
|
|||
self.meta_data = self.testlist
|
||||
else:
|
||||
self.meta_data = self.trainlist[cnt:]
|
||||
|
||||
|
||||
def aug(self, img0, gt, img1, h, w):
|
||||
|
||||
def crop(self, img0, gt, img1, h, w):
|
||||
ih, iw, _ = img0.shape
|
||||
x = np.random.randint(0, ih - h + 1)
|
||||
y = np.random.randint(0, iw - w + 1)
|
||||
|
@ -54,12 +53,24 @@ class VimeoDataset(Dataset):
|
|||
img0 = cv2.imread(imgpaths[0])
|
||||
gt = cv2.imread(imgpaths[1])
|
||||
img1 = cv2.imread(imgpaths[2])
|
||||
return img0, gt, img1
|
||||
timestep = 0.5
|
||||
return img0, gt, img1, timestep
|
||||
|
||||
# RIFEm with Vimeo-Septuplet
|
||||
# imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
|
||||
# ind = [0, 1, 2, 3, 4, 5, 6]
|
||||
# random.shuffle(ind)
|
||||
# ind = ind[:3]
|
||||
# ind.sort()
|
||||
# img0 = cv2.imread(imgpaths[ind[0]])
|
||||
# gt = cv2.imread(imgpaths[ind[1]])
|
||||
# img1 = cv2.imread(imgpaths[ind[2]])
|
||||
# timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img0, gt, img1 = self.getimg(index)
|
||||
img0, gt, img1, timestep = self.getimg(index)
|
||||
if self.dataset_name == 'train':
|
||||
img0, gt, img1 = self.aug(img0, gt, img1, 224, 224)
|
||||
img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
|
||||
if random.uniform(0, 1) < 0.5:
|
||||
img0 = img0[:, :, ::-1]
|
||||
img1 = img1[:, :, ::-1]
|
||||
|
@ -76,8 +87,23 @@ class VimeoDataset(Dataset):
|
|||
tmp = img1
|
||||
img1 = img0
|
||||
img0 = tmp
|
||||
# timestep = 1 - timestep
|
||||
timestep = 1 - timestep
|
||||
# random rotation
|
||||
p = random.uniform(0, 1)
|
||||
if p < 0.25:
|
||||
img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
|
||||
gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
|
||||
img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
|
||||
elif p < 0.5:
|
||||
img0 = cv2.rotate(img0, cv2.ROTATE_180)
|
||||
gt = cv2.rotate(gt, cv2.ROTATE_180)
|
||||
img1 = cv2.rotate(img1, cv2.ROTATE_180)
|
||||
elif p < 0.75:
|
||||
img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
|
||||
img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
|
||||
gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
|
||||
return torch.cat((img0, img1, gt), 0)
|
||||
timestep = torch.tensor(timestep).reshape(1, 1, 1)
|
||||
return torch.cat((img0, img1, gt), 0), timestep
|
||||
|
|
Loading…
Reference in New Issue