Update dataset.py

This commit is contained in:
hzwer 2022-04-11 11:43:19 +08:00 committed by GitHub
parent 5f5ffb5e85
commit 981ae76ca2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 8 deletions

View File

@ -35,9 +35,8 @@ class VimeoDataset(Dataset):
self.meta_data = self.testlist
else:
self.meta_data = self.trainlist[cnt:]
def aug(self, img0, gt, img1, h, w):
def crop(self, img0, gt, img1, h, w):
ih, iw, _ = img0.shape
x = np.random.randint(0, ih - h + 1)
y = np.random.randint(0, iw - w + 1)
@ -54,12 +53,24 @@ class VimeoDataset(Dataset):
img0 = cv2.imread(imgpaths[0])
gt = cv2.imread(imgpaths[1])
img1 = cv2.imread(imgpaths[2])
return img0, gt, img1
timestep = 0.5
return img0, gt, img1, timestep
# RIFEm with Vimeo-Septuplet
# imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
# ind = [0, 1, 2, 3, 4, 5, 6]
# random.shuffle(ind)
# ind = ind[:3]
# ind.sort()
# img0 = cv2.imread(imgpaths[ind[0]])
# gt = cv2.imread(imgpaths[ind[1]])
# img1 = cv2.imread(imgpaths[ind[2]])
# timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)
def __getitem__(self, index):
img0, gt, img1 = self.getimg(index)
img0, gt, img1, timestep = self.getimg(index)
if self.dataset_name == 'train':
img0, gt, img1 = self.aug(img0, gt, img1, 224, 224)
img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
if random.uniform(0, 1) < 0.5:
img0 = img0[:, :, ::-1]
img1 = img1[:, :, ::-1]
@ -76,8 +87,23 @@ class VimeoDataset(Dataset):
tmp = img1
img1 = img0
img0 = tmp
# timestep = 1 - timestep
timestep = 1 - timestep
# random rotation
p = random.uniform(0, 1)
if p < 0.25:
img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
elif p < 0.5:
img0 = cv2.rotate(img0, cv2.ROTATE_180)
gt = cv2.rotate(gt, cv2.ROTATE_180)
img1 = cv2.rotate(img1, cv2.ROTATE_180)
elif p < 0.75:
img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
return torch.cat((img0, img1, gt), 0)
timestep = torch.tensor(timestep).reshape(1, 1, 1)
return torch.cat((img0, img1, gt), 0), timestep