xvfi-cuda package

n00mkrad 2021-08-15 14:32:22 +02:00
parent 71b02a39de
commit d998ef2a6b
4 changed files with 1880 additions and 0 deletions

Pkgs/xvfi-cuda/XVFInet.py (new file, 496 lines)

@@ -0,0 +1,496 @@
import functools, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class XVFInet(nn.Module):
def __init__(self, args):
super(XVFInet, self).__init__()
self.args = args
self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
self.nf = args.nf
self.scale = args.module_scale_factor
self.vfinet = VFInet(args)
self.lrelu = nn.ReLU()
self.in_channels = 3
self.channel_converter = nn.Sequential(
nn.Conv3d(self.in_channels, self.nf, [1, 3, 3], [1, 1, 1], [0, 1, 1]),
nn.ReLU())
self.rec_ext_ds_module = [self.channel_converter]
self.rec_ext_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
for _ in range(int(np.log2(self.scale))):
self.rec_ext_ds_module.append(self.rec_ext_ds)
self.rec_ext_ds_module.append(nn.ReLU())
self.rec_ext_ds_module.append(nn.Conv3d(self.nf, self.nf, [1, 3, 3], 1, [0, 1, 1]))
self.rec_ext_ds_module.append(RResBlock2D_3D(args, T_reduce_flag=False))
self.rec_ext_ds_module = nn.Sequential(*self.rec_ext_ds_module)
self.rec_ctx_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
print("The lowest scale depth for training (S_trn): ", self.args.S_trn)
print("The lowest scale depth for test (S_tst): ", self.args.S_tst)
def forward(self, x, t_value, is_training=True):
'''
x shape : [B,C,T,H,W]
t_value shape : [B,1] ###############
'''
B, C, T, H, W = x.size()
B2, C2 = t_value.size()
assert C2 == 1, "t_value shape is [B,1]"
assert T % 2 == 0, "T must be an even number"
t_value = t_value.view(B, 1, 1, 1)
flow_l = None
feat_x = self.rec_ext_ds_module(x)
feat_x_list = [feat_x]
self.lowest_depth_level = self.args.S_trn if is_training else self.args.S_tst
for level in range(1, self.lowest_depth_level+1):
feat_x = self.rec_ctx_ds(feat_x)
feat_x_list.append(feat_x)
if is_training:
out_l_list = []
flow_refine_l_list = []
out_l, flow_l, flow_refine_l = self.vfinet(x, feat_x_list[self.args.S_trn], flow_l, t_value, level=self.args.S_trn, is_training=True)
out_l_list.append(out_l)
flow_refine_l_list.append(flow_refine_l)
for level in range(self.args.S_trn-1, 0, -1): ## S_trn-1, S_trn-2, ..., 1 (levels S_trn and 0 are handled separately)
out_l, flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=True)
out_l_list.append(out_l)
out_l, flow_l, flow_refine_l, occ_0_l0 = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=True)
out_l_list.append(out_l)
flow_refine_l_list.append(flow_refine_l)
return out_l_list[::-1], flow_refine_l_list[::-1], occ_0_l0, torch.mean(x, dim=2) # out_l_list should be reversed. [out_l0, out_l1, ...]
else: # Testing
for level in range(self.args.S_tst, 0, -1): ## self.args.S_tst, self.args.S_tst-1, ..., 1. level 0 is not included
flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=False)
out_l = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=False)
return out_l
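# --- Usage sketch (not part of the upstream file) ---------------------------
# A minimal inference call, assuming an argparse-style `args` carrying the
# fields XVFInet reads (gpu, nf, module_scale_factor, S_trn, S_tst, img_ch);
# the values below mirror main.py's defaults. H and W must be divisible by
# 2**S_tst * module_scale_factor * 4 (cf. `args.divide` in main.py), hence 512.
def _xvfinet_demo():
    import argparse
    args = argparse.Namespace(gpu=0, nf=64, module_scale_factor=4,
                              S_trn=3, S_tst=5, img_ch=3)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = XVFInet(args).to(device).eval()
    x = torch.randn(1, 3, 2, 512, 512, device=device)  # [B,C,T,H,W], T=2 anchor frames
    t_value = torch.full((1, 1), 0.5, device=device)   # temporal midpoint
    with torch.no_grad():
        out = net(x, t_value, is_training=False)       # -> [1, 3, 512, 512]
    return out
# -----------------------------------------------------------------------------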
class VFInet(nn.Module):
def __init__(self, args):
super(VFInet, self).__init__()
self.args = args
self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
self.nf = args.nf
self.scale = args.module_scale_factor
self.in_channels = 3
self.conv_flow_bottom = nn.Sequential(
nn.Conv2d(2*self.nf, 2*self.nf, [4,4], 2, [1,1]),
nn.ReLU(),
nn.Conv2d(2*self.nf, 4*self.nf, [4,4], 2, [1,1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.Conv2d(self.nf, 6, [3,3], 1, [1,1]),
)
self.conv_flow1 = nn.Conv2d(2*self.nf, self.nf, [3, 3], 1, [1, 1])
self.conv_flow2 = nn.Sequential(
nn.Conv2d(2*self.nf + 4, 2 * self.nf, [4, 4], 2, [1, 1]),
nn.ReLU(),
nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.Conv2d(self.nf, 6, [3, 3], 1, [1, 1]),
)
self.conv_flow3 = nn.Sequential(
nn.Conv2d(4 + self.nf * 4, self.nf, [1, 1], 1, [0, 0]),
nn.ReLU(),
nn.Conv2d(self.nf, 2 * self.nf, [4, 4], 2, [1, 1]),
nn.ReLU(),
nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
nn.ReLU(),
nn.Conv2d(self.nf, 4, [3, 3], 1, [1, 1]),
)
self.refine_unet = RefineUNet(args)
self.lrelu = nn.ReLU()
def forward(self, x, feat_x, flow_l_prev, t_value, level, is_training):
'''
x shape : [B,C,T,H,W]
t_value shape : [B,1] ###############
'''
B, C, T, H, W = x.size()
assert T % 2 == 0, "T must be an even number"
####################### For a single level
l = 2 ** level
x_l = x.permute(0,2,1,3,4)
x_l = x_l.contiguous().view(B * T, C, H, W)
if level != 0:
x_l = F.interpolate(x_l, scale_factor=(1.0 / l, 1.0 / l), mode='bicubic', align_corners=False)
'''
Down pixel-shuffle
'''
x_l = x_l.view(B, T, C, H//l, W//l)
x_l = x_l.permute(0,2,1,3,4)
B, C, T, H, W = x_l.size()
## Feature extraction
feat0_l = feat_x[:,:,0,:,:]
feat1_l = feat_x[:,:,1,:,:]
## Flow estimation
if flow_l_prev is None:
flow_l_tmp = self.conv_flow_bottom(torch.cat((feat0_l, feat1_l), dim=1))
flow_l = flow_l_tmp[:,:4,:,:]
else:
up_flow_l_prev = 2.0*F.interpolate(flow_l_prev.detach(), scale_factor=(2,2), mode='bilinear', align_corners=False)
warped_feat1_l = self.bwarp(feat1_l, up_flow_l_prev[:,:2,:,:])
warped_feat0_l = self.bwarp(feat0_l, up_flow_l_prev[:,2:,:,:])
flow_l_tmp = self.conv_flow2(torch.cat([self.conv_flow1(torch.cat([feat0_l, warped_feat1_l],dim=1)), self.conv_flow1(torch.cat([feat1_l, warped_feat0_l],dim=1)), up_flow_l_prev],dim=1))
flow_l = flow_l_tmp[:,:4,:,:] + up_flow_l_prev
if not is_training and level!=0:
return flow_l
flow_01_l = flow_l[:,:2,:,:]
flow_10_l = flow_l[:,2:,:,:]
z_01_l = torch.sigmoid(flow_l_tmp[:,4:5,:,:])
z_10_l = torch.sigmoid(flow_l_tmp[:,5:6,:,:])
## Complementary Flow Reversal (CFR)
flow_forward, norm0_l = self.z_fwarp(flow_01_l, t_value * flow_01_l, z_01_l) ## Actually, F (t) -> (t+1). Translation only. Not normalized yet
flow_backward, norm1_l = self.z_fwarp(flow_10_l, (1-t_value) * flow_10_l, z_10_l) ## Actually, F (1-t) -> (-t). Translation only. Not normalized yet
flow_t0_l = -(1-t_value) * ((t_value)*flow_forward) + (t_value) * ((t_value)*flow_backward) # The numerator of Eq.(1) in the paper.
flow_t1_l = (1-t_value) * ((1-t_value)*flow_forward) - (t_value) * ((1-t_value)*flow_backward) # The numerator of Eq.(2) in the paper.
norm_l = (1-t_value)*norm0_l + t_value*norm1_l
mask_ = (norm_l.detach() > 0).type(norm_l.type())
flow_t0_l = (1-mask_) * flow_t0_l + mask_ * (flow_t0_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(1)
flow_t1_l = (1-mask_) * flow_t1_l + mask_ * (flow_t1_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(2)
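# In words: the two flows are forward-warped ("splatted") to time t with the
# z maps as confidence weights and blended per Eq.(1)/(2); the mask above keeps
# pixels that received no contribution (norm == 0) at their unnormalized value
# instead of dividing by zero.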
## Feature warping
warped0_l = self.bwarp(feat0_l, flow_t0_l)
warped1_l = self.bwarp(feat1_l, flow_t1_l)
## Flow refinement
flow_refine_l = torch.cat([feat0_l, warped0_l, warped1_l, feat1_l, flow_t0_l, flow_t1_l], dim=1)
flow_refine_l = self.conv_flow3(flow_refine_l) + torch.cat([flow_t0_l, flow_t1_l], dim=1)
flow_t0_l = flow_refine_l[:, :2, :, :]
flow_t1_l = flow_refine_l[:, 2:4, :, :]
warped0_l = self.bwarp(feat0_l, flow_t0_l)
warped1_l = self.bwarp(feat1_l, flow_t1_l)
## Flow upscale
flow_t0_l = self.scale * F.interpolate(flow_t0_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
flow_t1_l = self.scale * F.interpolate(flow_t1_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
## Image warping and blending
warped_img0_l = self.bwarp(x_l[:,:,0,:,:], flow_t0_l)
warped_img1_l = self.bwarp(x_l[:,:,1,:,:], flow_t1_l)
refine_out = self.refine_unet(torch.cat([F.pixel_shuffle(torch.cat([feat0_l, feat1_l, warped0_l, warped1_l],dim=1), self.scale), x_l[:,:,0,:,:], x_l[:,:,1,:,:], warped_img0_l, warped_img1_l, flow_t0_l, flow_t1_l],dim=1))
occ_0_l = torch.sigmoid(refine_out[:, 0:1, :, :])
occ_1_l = 1-occ_0_l
out_l = (1-t_value)*occ_0_l*warped_img0_l + t_value*occ_1_l*warped_img1_l
out_l = out_l / ( (1-t_value)*occ_0_l + t_value*occ_1_l ) + refine_out[:, 1:4, :, :]
if not is_training and level==0:
return out_l
if is_training:
if flow_l_prev is None:
# if level == self.args.S_trn:
return out_l, flow_l, flow_refine_l[:, 0:4, :, :]
elif level != 0:
return out_l, flow_l
else: # level==0
return out_l, flow_l, flow_refine_l[:, 0:4, :, :], occ_0_l
def bwarp(self, x, flo):
'''
x: [B, C, H, W] (im2)
flo: [B, 2, H, W] flow
'''
B, C, H, W = x.size()
# mesh grid
xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
grid = torch.cat((xx, yy), 1).float()
if x.is_cuda:
grid = grid.to(self.device)
vgrid = torch.autograd.Variable(grid) + flo
# scale grid to [-1,1]
vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
output = nn.functional.grid_sample(x, vgrid, align_corners=True)
mask = torch.autograd.Variable(torch.ones(x.size())).to(self.device)
mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
# mask[mask<0.9999] = 0
# mask[mask>0] = 1
mask = mask.masked_fill_(mask < 0.999, 0)
mask = mask.masked_fill_(mask > 0, 1)
return output * mask
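# Sanity note (illustration only): with zero flow, vgrid is the identity grid
# and, because align_corners=True places normalized coordinates exactly on
# pixel centers, bwarp(x, torch.zeros(B, 2, H, W, device=x.device)) returns x
# unchanged with a mask of all ones.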
def fwarp(self, img, flo):
"""
-img: image (N, C, H, W)
-flo: optical flow (N, 2, H, W)
elements of flo are in [0, H] and [0, W] for dx, dy
https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
"""
# (x1, y1) (x1, y2)
# +---------------+
# | |
# | o(x, y) |
# | |
# | |
# | |
# | |
# +---------------+
# (x2, y1) (x2, y2)
N, C, _, _ = img.size()
# translate start-point optical flow to end-point optical flow
y = flo[:, 0:1, :, :]
x = flo[:, 1:2, :, :]
x = x.repeat(1, C, 1, 1)
y = y.repeat(1, C, 1, 1)
# Four corners of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
x1 = torch.floor(x)
x2 = x1 + 1
y1 = torch.floor(y)
y2 = y1 + 1
# firstly, get gaussian weights
w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
# secondly, sample each weighted corner
img11, o11 = self.sample_one(img, x1, y1, w11)
img12, o12 = self.sample_one(img, x1, y2, w12)
img21, o21 = self.sample_one(img, x2, y1, w21)
img22, o22 = self.sample_one(img, x2, y2, w22)
imgw = img11 + img12 + img21 + img22
o = o11 + o12 + o21 + o22
return imgw, o
def z_fwarp(self, img, flo, z):
"""
-img: image (N, C, H, W)
-flo: optical flow (N, 2, H, W)
elements of flo are in [0, H] and [0, W] for dx, dy
modified from https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
"""
# (x1, y1) (x1, y2)
# +---------------+
# | |
# | o(x, y) |
# | |
# | |
# | |
# | |
# +---------------+
# (x2, y1) (x2, y2)
N, C, _, _ = img.size()
# translate start-point optical flow to end-point optical flow
y = flo[:, 0:1, :, :]
x = flo[:, 1:2, :, :]
x = x.repeat(1, C, 1, 1)
y = y.repeat(1, C, 1, 1)
# Four corners of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
x1 = torch.floor(x)
x2 = x1 + 1
y1 = torch.floor(y)
y2 = y1 + 1
# firstly, get gaussian weights
w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2, z+1e-5)
# secondly, sample each weighted corner
img11, o11 = self.sample_one(img, x1, y1, w11)
img12, o12 = self.sample_one(img, x1, y2, w12)
img21, o21 = self.sample_one(img, x2, y1, w21)
img22, o22 = self.sample_one(img, x2, y2, w22)
imgw = img11 + img12 + img21 + img22
o = o11 + o12 + o21 + o22
return imgw, o
def get_gaussian_weights(self, x, y, x1, x2, y1, y2, z=1.0):
# z 0.0 ~ 1.0
w11 = z * torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))
w12 = z * torch.exp(-((x - x1) ** 2 + (y - y2) ** 2))
w21 = z * torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))
w22 = z * torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))
return w11, w12, w21, w22
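# Worked example: when the target lands exactly on a grid point (x == x1,
# y == y1), the weights are w11 = z, w12 = w21 = z * exp(-1), w22 = z * exp(-2),
# i.e. the Gaussian concentrates mass on the pixel the flow actually points at.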
def sample_one(self, img, shiftx, shifty, weight):
"""
Input:
-img (N, C, H, W)
-shiftx, shifty (N, c, H, W)
"""
N, C, H, W = img.size()
# flatten all (all restored as Tensors)
flat_shiftx = shiftx.view(-1)
flat_shifty = shifty.view(-1)
flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].to(self.device).long().repeat(N, C,1,W).view(-1)
flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].to(self.device).long().repeat(N, C,H,1).view(-1)
flat_weight = weight.view(-1)
flat_img = img.contiguous().view(-1)
# The corresponding positions in I1
idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).to(self.device).long().repeat(1, C, H, W).view(-1)
idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).to(self.device).long().repeat(N, 1, H, W).view(-1)
idxx = flat_shiftx.long() + flat_basex
idxy = flat_shifty.long() + flat_basey
# recording the inside part the shifted
mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
# Mask off points out of boundaries
ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
ids_mask = torch.masked_select(ids, mask).clone().to(self.device)
# Note: the accumulate flag must be True for proper backprop
img_warp = torch.zeros([N * C * H * W, ]).to(self.device)
img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True)
one_warp = torch.zeros([N * C * H * W, ]).to(self.device)
one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)
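# The accumulate=True scatter is what makes this a forward (splatting) warp:
# several sources may map to the same target index and their weighted
# contributions must sum. Minimal illustration with torch.Tensor.put_:
#   buf = torch.zeros(4)
#   buf.put_(torch.tensor([0, 0, 2]), torch.tensor([1., 2., 3.]), accumulate=True)
#   # buf == [3., 0., 3., 0.]   (index 0 received 1 + 2)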
class RefineUNet(nn.Module):
def __init__(self, args):
super(RefineUNet, self).__init__()
self.args = args
self.scale = args.module_scale_factor
self.nf = args.nf
self.conv1 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
self.conv2 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
self.lrelu = nn.ReLU()
self.NN = nn.UpsamplingNearest2d(scale_factor=2)
self.enc1 = nn.Conv2d((4*self.nf)//self.scale//self.scale + 4*args.img_ch + 4, self.nf, [4, 4], 2, [1, 1])
self.enc2 = nn.Conv2d(self.nf, 2*self.nf, [4, 4], 2, [1, 1])
self.enc3 = nn.Conv2d(2*self.nf, 4*self.nf, [4, 4], 2, [1, 1])
self.dec0 = nn.Conv2d(4*self.nf, 4*self.nf, [3, 3], 1, [1, 1])
self.dec1 = nn.Conv2d(4*self.nf + 2*self.nf, 2*self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc2
self.dec2 = nn.Conv2d(2*self.nf + self.nf, self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc1
self.dec3 = nn.Conv2d(self.nf, 1+args.img_ch, [3, 3], 1, [1, 1]) ## input added with warped image
def forward(self, concat):
enc1 = self.lrelu(self.enc1(concat))
enc2 = self.lrelu(self.enc2(enc1))
out = self.lrelu(self.enc3(enc2))
out = self.lrelu(self.dec0(out))
out = self.NN(out)
out = torch.cat((out,enc2),dim=1)
out = self.lrelu(self.dec1(out))
out = self.NN(out)
out = torch.cat((out,enc1),dim=1)
out = self.lrelu(self.dec2(out))
out = self.NN(out)
out = self.dec3(out)
return out
class ResBlock2D_3D(nn.Module):
## Shape of input [B,C,T,H,W]
## Shape of output [B,C,T,H,W]
def __init__(self, args):
super(ResBlock2D_3D, self).__init__()
self.args = args
self.nf = args.nf
self.conv3x3_1 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
self.conv3x3_2 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
self.lrelu = nn.ReLU()
def forward(self, x):
'''
x shape : [B,C,T,H,W]
'''
B, C, T, H, W = x.size()
out = self.conv3x3_2(self.lrelu(self.conv3x3_1(x)))
return x + out
class RResBlock2D_3D(nn.Module):
def __init__(self, args, T_reduce_flag=False):
super(RResBlock2D_3D, self).__init__()
self.args = args
self.nf = args.nf
self.T_reduce_flag = T_reduce_flag
self.resblock1 = ResBlock2D_3D(self.args)
self.resblock2 = ResBlock2D_3D(self.args)
if T_reduce_flag:
self.reduceT_conv = nn.Conv3d(self.nf, self.nf, [3,1,1], 1, [0,0,0])
def forward(self, x):
'''
x shape : [B,C,T,H,W]
'''
out = self.resblock1(x)
out = self.resblock2(out)
if self.T_reduce_flag:
return self.reduceT_conv(out + x)
else:
return out + x

Pkgs/xvfi-cuda/main.py (new file, 410 lines)

@@ -0,0 +1,410 @@
import argparse, os, shutil, time, random, torch, cv2, datetime, torch.utils.data, math
import torch.backends.cudnn as cudnn
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from utils import *
from XVFInet import *
from collections import Counter
def parse_args():
desc = "PyTorch implementation for XVFI"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument('--net_type', type=str, default='XVFInet', choices=['XVFInet'], help='The type of Net')
parser.add_argument('--net_object', default=XVFInet, choices=[XVFInet], help='The type of Net')
parser.add_argument('--exp_num', type=int, default=1, help='The experiment number')
parser.add_argument('--phase', type=str, default='test_custom', choices=['train', 'test', 'test_custom', 'metrics_evaluation',])
parser.add_argument('--continue_training', action='store_true', default=False, help='continue the training')
""" Information of directories """
parser.add_argument('--test_img_dir', type=str, default='./test_img_dir', help='test_img_dir path')
parser.add_argument('--text_dir', type=str, default='./text_dir', help='text_dir path')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_dir', help='checkpoint_dir')
parser.add_argument('--log_dir', type=str, default='./log_dir', help='Directory name to save training logs')
parser.add_argument('--dataset', default='X4K1000FPS', choices=['X4K1000FPS', 'Vimeo'],
help='Training/test Dataset')
# parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train')
# parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val')
# parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test')
parser.add_argument('--train_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/train')
parser.add_argument('--val_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/val')
parser.add_argument('--test_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/test')
parser.add_argument('--vimeo_data_path', type=str, default='./vimeo_triplet')
""" Hyperparameters for Training (when [phase=='train']) """
parser.add_argument('--epochs', type=int, default=200, help='The number of epochs to run')
parser.add_argument('--freq_display', type=int, default=100, help='The number of iterations frequency for display')
parser.add_argument('--save_img_num', type=int, default=4,
help='The number of images saved during training for visualization. It should be smaller than the batch_size')
parser.add_argument('--init_lr', type=float, default=1e-4, help='The initial learning rate')
parser.add_argument('--lr_dec_fac', type=float, default=0.25, help='step - lr_decreasing_factor')
parser.add_argument('--lr_milestones', type=int, default=[100, 150, 180])
parser.add_argument('--lr_dec_start', type=int, default=0,
help='When scheduler is StepLR, lr decreases from epoch at lr_dec_start')
parser.add_argument('--batch_size', type=int, default=8, help='The size of batch size.')
parser.add_argument('--weight_decay', type=float, default=0, help='for optim., weight decay (default: 0)')
parser.add_argument('--need_patch', default=True, help='extract patches from images while training')
parser.add_argument('--img_ch', type=int, default=3, help='base number of channels for image')
parser.add_argument('--nf', type=int, default=64, help='base number of channels for feature maps') # 64
parser.add_argument('--module_scale_factor', type=int, default=4, help='spatial reduction for pixel shuffle')
parser.add_argument('--patch_size', type=int, default=384, help='patch size')
parser.add_argument('--num_thrds', type=int, default=4, help='number of threads for data loading')
parser.add_argument('--loss_type', default='L1', choices=['L1', 'MSE', 'L1_Charbonnier_loss'], help='Loss type')
parser.add_argument('--S_trn', type=int, default=3, help='The lowest scale depth for training')
parser.add_argument('--S_tst', type=int, default=5, help='The lowest scale depth for test')
""" Weighting Parameters Lambda for Losses (when [phase=='train']) """
parser.add_argument('--rec_lambda', type=float, default=1.0, help='Lambda for Reconstruction Loss')
""" Settings for Testing (when [phase=='test' or 'test_custom']) """
parser.add_argument('--saving_flow_flag', default=False)
parser.add_argument('--multiple', type=int, default=8, help='Due to the indexing of the file names, we recommend using a power of 2 (e.g. 2, 4, 8, 16, ...). CAUTION: for the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32].')
parser.add_argument('--metrics_types', type=list, default=["PSNR", "SSIM", "tOF"], choices=["PSNR", "SSIM", "tOF"])
""" Settings for test_custom (when [phase=='test_custom']) """
parser.add_argument('--custom_path', type=str, default='./custom_path', help='path for custom video containing frames')
parser.add_argument('--output', type=str, default='./interp', help='output path')
parser.add_argument('--input', type=str, default='./frames', help='input path')
parser.add_argument('--img_format', type=str, default="png")
parser.add_argument('--mdl_dir', type=str)
return check_args(parser.parse_args())
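# Example invocation for interpolating a custom frame folder (paths are
# hypothetical; the flags are the ones defined above):
#   python main.py --phase test_custom --gpu 0 --multiple 2 \
#       --custom_path ./video_dir --input frames --output interp \
#       --img_format png --mdl_dir ./checkpoint_dir/XVFInet_X4K1000FPS_exp1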
def check_args(args):
# --checkpoint_dir
check_folder(args.checkpoint_dir)
# --text_dir
check_folder(args.text_dir)
# --log_dir
check_folder(args.log_dir)
# --test_img_dir
check_folder(args.test_img_dir)
return args
def main():
args = parse_args()
if args.dataset == 'Vimeo':
if args.phase != 'test_custom':
args.multiple = 2
args.S_trn = 1
args.S_tst = 1
args.module_scale_factor = 2
args.patch_size = 256
args.batch_size = 16
print('vimeo triplet data dir : ', args.vimeo_data_path)
print("Exp:", args.exp_num)
args.model_dir = args.net_type + '_' + args.dataset + '_exp' + str(
args.exp_num) # ex) model_dir = "XVFInet_X4K1000FPS_exp1"
if args is None:
exit()
for arg in vars(args):
print('# {} : {}'.format(arg, getattr(args, arg)))
device = torch.device(
'cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
torch.cuda.set_device(device) # change allocation of current GPU
# caution!!!! if not "torch.cuda.set_device()":
# RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0
print('Available devices: ', torch.cuda.device_count())
print('Current cuda device: ', torch.cuda.current_device())
print('Current cuda device name: ', torch.cuda.get_device_name(device))
if args.gpu is not None:
print("Use GPU: {} is used".format(args.gpu))
SM = save_manager(args)
""" Initialize a model """
model_net = args.net_object(args).apply(weights_init).to(device)
criterion = [set_rec_loss(args).to(device), set_smoothness_loss().to(device)]
# to enable the inbuilt cudnn auto-tuner
# to find the best algorithm to use for your hardware.
cudnn.benchmark = True
if args.phase == "train":
train(model_net, criterion, device, SM, args)
epoch = args.epochs - 1
elif args.phase == "test" or args.phase == "metrics_evaluation" or args.phase == 'test_custom':
checkpoint = SM.load_model(args.mdl_dir)
model_net.load_state_dict(checkpoint['state_dict_Model'])
epoch = checkpoint['last_epoch']
postfix = '_final_x' + str(args.multiple) + '_S_tst' + str(args.S_tst)
if args.phase != "metrics_evaluation":
print("\n-------------------------------------- Final Test starts -------------------------------------- ")
print('Evaluate on test set (final test) with multiple = %d ' % (args.multiple))
final_test_loader = get_test_data(args, multiple=args.multiple,
validation=False) # multiple is only used for X4K1000FPS
final_pred_save_path = test(final_test_loader, model_net,
criterion, epoch,
args, device,
multiple=args.multiple,
postfix=postfix, validation=False)
#SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR))
if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
final_pred_save_path = os.path.join(args.test_img_dir, args.model_dir, 'epoch_' + str(epoch).zfill(5)) + postfix
metrics_evaluation_X_Test(final_pred_save_path, args.test_data_path, args.metrics_types,
flow_flag=args.saving_flow_flag, multiple=args.multiple)
print("------------------------- Test has been ended. -------------------------\n")
quit()
print("Exp:", args.exp_num)
multi_scale_recon_loss = criterion[0]
smoothness_loss = criterion[1]
optimizer = optim.Adam(model_net.parameters(), lr=args.init_lr, betas=(0.9, 0.999),
weight_decay=args.weight_decay) # optimizer
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_dec_fac)
last_epoch = 0
best_PSNR = 0.0
if args.continue_training:
checkpoint = SM.load_model()
last_epoch = checkpoint['last_epoch'] + 1
best_PSNR = checkpoint['best_PSNR']
model_net.load_state_dict(checkpoint['state_dict_Model'])
optimizer.load_state_dict(checkpoint['state_dict_Optimizer'])
scheduler.load_state_dict(checkpoint['state_dict_Scheduler'])
print("Optimizer and Scheduler have been reloaded. ")
scheduler.milestones = Counter(args.lr_milestones)
scheduler.gamma = args.lr_dec_fac
print("scheduler.milestones : {}, scheduler.gamma : {}".format(scheduler.milestones, scheduler.gamma))
start_epoch = last_epoch
# switch to train mode
model_net.train()
start_time = time.time()
#SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n')
#print("[*] Training starts")
# Main training loop for total epochs (start from 'epoch=0')
valid_loader = get_test_data(args, multiple=4, validation=True) # multiple is only used for X4K1000FPS
for epoch in range(start_epoch, args.epochs):
train_loader = get_train_data(args,
max_t_step_size=32) # max_t_step_size (temporal distance) is only used for X4K1000FPS
batch_time = AverageClass('batch_time[s]:', ':6.3f')
losses = AverageClass('Loss:', ':.4e')
progress = ProgressMeter(len(train_loader), batch_time, losses, prefix="Epoch: [{}]".format(epoch))
print('Start epoch {} at [{:s}], learning rate : [{}]'.format(epoch, (str(datetime.now())[:-7]),
optimizer.param_groups[0]['lr']))
# train for one epoch
for trainIndex, (frames, t_value) in enumerate(train_loader):
input_frames = frames[:, :, :-1, :, :] # [B, C, T, H, W]
frameT = frames[:, :, -1, :, :] # [B, C, H, W]
# Getting the input and the target from the training set
input_frames = Variable(input_frames.to(device))
frameT = Variable(frameT.to(device)) # ground truth for frameT
t_value = Variable(t_value.to(device)) # [B,1]
optimizer.zero_grad()
# compute output
pred_frameT_pyramid, pred_flow_pyramid, occ_map, simple_mean = model_net(input_frames, t_value)
rec_loss = 0.0
smooth_loss = 0.0
for l, pred_frameT_l in enumerate(pred_frameT_pyramid):
rec_loss += args.rec_lambda * multi_scale_recon_loss(pred_frameT_l,
F.interpolate(frameT, scale_factor=1 / (2 ** l),
mode='bicubic', align_corners=False))
smooth_loss += 0.5 * smoothness_loss(pred_flow_pyramid[0],
F.interpolate(frameT, scale_factor=1 / args.module_scale_factor,
mode='bicubic',
align_corners=False)) # apply 1st-order edge-aware smoothness loss to the finest level
rec_loss /= len(pred_frameT_pyramid)
pred_frameT = pred_frameT_pyramid[0] # final result I^0_t at original scale (s=0)
pred_coarse_flow = 2 ** (args.S_trn) * F.interpolate(pred_flow_pyramid[-1], scale_factor=2 ** (
args.S_trn) * args.module_scale_factor, mode='bicubic', align_corners=False)
pred_fine_flow = F.interpolate(pred_flow_pyramid[0], scale_factor=args.module_scale_factor, mode='bicubic',
align_corners=False)
total_loss = rec_loss + smooth_loss
# compute gradient and do SGD step
total_loss.backward() # Backpropagate
optimizer.step() # Optimizer update
# measure accumulated time and update average "batch" time consumptions via "AverageClass"
# update average values via "AverageClass"
losses.update(total_loss.item(), 1)
batch_time.update(time.time() - start_time)
start_time = time.time()
if trainIndex % args.freq_display == 0:
progress.print(trainIndex)
batch_images = get_batch_images(args, save_img_num=args.save_img_num,
save_images=[pred_frameT, pred_coarse_flow, pred_fine_flow, frameT,
simple_mean, occ_map])
cv2.imwrite(os.path.join(args.log_dir, '{:03d}_{:04d}_training.png'.format(epoch, trainIndex)), batch_images)
if epoch >= args.lr_dec_start:
scheduler.step()
# if (epoch + 1) % 10 == 0 or epoch==0:
val_multiple = 4 if args.dataset == 'X4K1000FPS' else 2
print('\nEvaluate on test set (validation while training) with multiple = {}'.format(val_multiple))
postfix = '_val_' + str(val_multiple) + '_S_tst' + str(args.S_tst)
final_pred_save_path = test(valid_loader, model_net, criterion, epoch, args,
device, multiple=val_multiple, postfix=postfix,
validation=True)
# remember best best_PSNR and best_SSIM and save checkpoint
#print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR))
best_PSNR_flag = testPSNR > best_PSNR
best_PSNR = max(testPSNR, best_PSNR)
# save checkpoint.
combined_state_dict = {
'net_type': args.net_type,
'last_epoch': epoch,
'batch_size': args.batch_size,
'trainLoss': losses.avg,
'testLoss': testLoss,
'testPSNR': testPSNR,
'best_PSNR': best_PSNR,
'state_dict_Model': model_net.state_dict(),
'state_dict_Optimizer': optimizer.state_dict(),
'state_dict_Scheduler': scheduler.state_dict()}
SM.save_best_model(combined_state_dict, best_PSNR_flag)
if (epoch + 1) % 10 == 0:
SM.save_epc_model(combined_state_dict, epoch)
SM.write_info('{}\t{:.4}\t{:.4}\t{:.4}\n'.format(epoch, losses.avg, testPSNR, best_PSNR))
print("------------------------- Training has been ended. -------------------------\n")
print("information of model:", args.model_dir)
print("best_PSNR of model:", best_PSNR)
def test(test_loader, model_net, criterion, epoch, args, device, multiple, postfix, validation):
#os.chdir(interp_output_path)
batch_time = AverageClass('Time:', ':6.3f')
losses = AverageClass('testLoss:', ':.4e')
PSNRs = AverageClass('testPSNR:', ':.4e')
SSIMs = AverageClass('testSSIM:', ':.4e')
args.divide = 2 ** (args.S_tst) * args.module_scale_factor * 4
# progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
multi_scale_recon_loss = criterion[0]
# switch to evaluate mode
model_net.eval()
counter = 1
copied_src_frames = list()
last_frame = ""
print("------------------------------------------- Test ----------------------------------------------")
with torch.no_grad():
start_time = time.time()
for testIndex, (frames, t_value, scene_name, frameRange) in enumerate(test_loader):
# Shape of 'frames' : [1,C,T+1,H,W]
frameT = frames[:, :, -1, :, :] # [1,C,H,W]
It_Path, I0_Path, I1_Path = frameRange
#print(I0_Path)
#print(I1_Path)
input_filename = str(I0_Path).split("'")[1]
input_filename_next = str(I1_Path).split("'")[1]
last_frame = input_filename_next
frameT = Variable(frameT.to(device)) # ground truth for frameT
t_value = Variable(t_value.to(device))
if (testIndex % (multiple - 1)) == 0:
input_frames = frames[:, :, :-1, :, :] # [1,C,T,H,W]
input_frames = Variable(input_frames.to(device))
B, C, T, H, W = input_frames.size()
H_padding = (args.divide - H % args.divide) % args.divide
W_padding = (args.divide - W % args.divide) % args.divide
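# args.divide (= 2**S_tst * module_scale_factor * 4, set at the top of test())
# is the product of all stride-2 and pixel-shuffle factors the network applies,
# so padding H and W up to a multiple of it keeps every down/upsampling step at
# an integral size (e.g. S_tst=5, scale=4 -> multiples of 512).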
if H_padding != 0 or W_padding != 0:
input_frames = F.pad(input_frames, (0, W_padding, 0, H_padding), "constant")
pred_frameT = model_net(input_frames, t_value, is_training=False)
if H_padding != 0 or W_padding != 0:
pred_frameT = pred_frameT[:, :, :H, :W]
epoch_save_path = args.custom_path
scene_save_path = os.path.join(epoch_save_path, scene_name[0])
pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy())
frameT_np = np.squeeze(frameT.detach().cpu().numpy()) # unused here; custom frames have no ground truth
output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255]
#print(os.path.join(scene_save_path, It_Path[0]))
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
src_frame_path = os.path.join(args.custom_path, args.input, input_filename)
if os.path.isfile(src_frame_path):
if src_frame_path in copied_src_frames:
#print(f"Not copying source frame '{src_frame_path}' because it has already been copied before! - {len(copied_src_frames)}")
pass
else:
print(f"S => {os.path.basename(src_frame_path)} => {os.path.basename(frame_src_path)}")
shutil.copy(src_frame_path, frame_src_path)
copied_src_frames.append(src_frame_path)
counter += 1
frame_interp_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
print(f"I => {os.path.basename(frame_interp_path)}")
cv2.imwrite(frame_interp_path, output_img.astype(np.uint8))
counter += 1
#losses.update(0.0, 1)
#PSNRs.update(0.0, 1)
#SSIMs.update(0.0, 1)
print("-----------------------------------------------------------------------------------------------")
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
print(f"LAST S => {frame_src_path}")
src_frame_path = os.path.join(args.custom_path, args.input, last_frame)
shutil.copy(src_frame_path, frame_src_path)
return epoch_save_path
if __name__ == '__main__':
main()

@@ -0,0 +1,13 @@
[
{
"name": "Vimeo",
"desc": "Model trained on Vimeo90K dataset",
"dir": "vimeo"
},
{
"name": "X4K1000FPS",
"desc": "Model trained on X4K1000FPS dataset",
"dir": "x4k1000fps",
"isDefault": "true"
}
]

Pkgs/xvfi-cuda/utils.py (new file, 961 lines)

@@ -0,0 +1,961 @@
from __future__ import division
import os, glob, sys, torch, shutil, random, math, time, cv2
import numpy as np
import torch.utils.data as data
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from datetime import datetime
from torch.nn import init
# from skimage.measure import compare_ssim  # deprecated and removed in scikit-image >= 0.18; structural_similarity below replaces it
from skimage.metrics import structural_similarity
from torch.autograd import Variable
from torchvision import models
class save_manager():
def __init__(self, args):
self.args = args
self.model_dir = self.args.net_type + '_' + self.args.dataset + '_exp' + str(self.args.exp_num)
print("model_dir:", self.model_dir)
# ex) model_dir = "XVFInet_exp1"
self.checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
# './checkpoint_dir/XVFInet_exp1"
#check_folder(self.checkpoint_dir)
#print("checkpoint_dir:", self.checkpoint_dir)
#self.text_dir = os.path.join(self.args.text_dir, self.model_dir)
#print("text_dir:", self.text_dir)
#""" Save a text file """
#if not os.path.exists(self.text_dir + '.txt'):
# self.log_file = open(self.text_dir + '.txt', 'w')
# # "w" - Write - Opens a file for writing, creates the file if it does not exist
# self.log_file.write('----- Model parameters -----\n')
# self.log_file.write(str(datetime.now())[:-7] + '\n')
# for arg in vars(self.args):
# self.log_file.write('{} : {}\n'.format(arg, getattr(self.args, arg)))
# # ex) ./text_dir/XVFInet_exp1.txt
# self.log_file.close()
# "a" - Append - Opens a file for appending, creates the file if it does not exist
def write_info(self, strings):
self.log_file = open(self.text_dir + '.txt', 'a')
self.log_file.write(strings)
self.log_file.close()
def save_best_model(self, combined_state_dict, best_PSNR_flag):
file_name = os.path.join(self.checkpoint_dir, self.model_dir + '_latest.pt')
# file_name = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_latest.ckpt
torch.save(combined_state_dict, file_name)
if best_PSNR_flag:
shutil.copyfile(file_name, os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'))
# file_path = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_best_PSNR.ckpt
def save_epc_model(self, combined_state_dict, epoch):
file_name = os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch) + '.pt')
# file_name = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_epc10.ckpt
torch.save(combined_state_dict, file_name)
def load_epc_model(self, epoch):
checkpoint = torch.load(os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch - 1) + '.pt'))
print("load model '{}', epoch: {}, best_PSNR: {:3f}".format(
os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch - 1) + '.pt'), checkpoint['last_epoch'] + 1,
checkpoint['best_PSNR']))
return checkpoint
def load_model(self, mdl_dir):
# checkpoint = torch.load(self.checkpoint_dir + '/' + self.model_dir + '_latest.pt')
checkpoint = torch.load(os.path.join(mdl_dir, "checkpoint.pt"), map_location='cuda:0')
print("load model '{}', epoch: {},".format(
os.path.join(mdl_dir, "checkpoint.pt"), checkpoint['last_epoch'] + 1))
return checkpoint
def load_best_PSNR_model(self, ):
checkpoint = torch.load(os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'))
print("load _best_PSNR model '{}', epoch: {}, best_PSNR: {:3f}, best_SSIM: {:3f}".format(
os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'), checkpoint['last_epoch'] + 1,
checkpoint['best_PSNR'], checkpoint['best_SSIM']))
return checkpoint
def check_folder(log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir)
return log_dir
def weights_init(m):
classname = m.__class__.__name__
if (classname.find('Conv2d') != -1) or (classname.find('Conv3d') != -1):
init.xavier_normal_(m.weight)
# init.kaiming_normal_(m.weight, nonlinearity='relu')
if hasattr(m, 'bias') and m.bias is not None:
init.zeros_(m.bias)
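# Applied recursively via nn.Module.apply, as main.py does:
#   model_net = args.net_object(args).apply(weights_init).to(device)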
def get_train_data(args, max_t_step_size):
if args.dataset == 'X4K1000FPS':
data_train = X_Train(args, max_t_step_size)
elif args.dataset == 'Vimeo':
data_train = Vimeo_Train(args)
dataloader = torch.utils.data.DataLoader(data_train, batch_size=args.batch_size, drop_last=True, shuffle=True,
num_workers=int(args.num_thrds), pin_memory=False)
return dataloader
def get_test_data(args, multiple, validation):
if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
data_test = X_Test(args, multiple, validation) # 'validation' for validation while training for simplicity
elif args.dataset == 'Vimeo' and args.phase != 'test_custom':
data_test = Vimeo_Test(args, validation)
elif args.phase == 'test_custom':
data_test = Custom_Test(args, multiple)
dataloader = torch.utils.data.DataLoader(data_test, batch_size=1, drop_last=True, shuffle=False, pin_memory=False)
return dataloader
def frames_loader_train(args, candidate_frames, frameRange):
frames = []
for frameIndex in frameRange:
frame = cv2.imread(candidate_frames[frameIndex])
frames.append(frame)
(ih, iw, c) = frame.shape
frames = np.stack(frames, axis=0) # (T, H, W, 3)
if args.need_patch: ## random crop
ps = args.patch_size
ix = random.randrange(0, iw - ps + 1)
iy = random.randrange(0, ih - ps + 1)
frames = frames[:, iy:iy + ps, ix:ix + ps, :]
if random.random() < 0.5: # random horizontal flip
frames = frames[:, :, ::-1, :]
# No vertical flip
rot = random.randint(0, 3) # random rotate
frames = np.rot90(frames, rot, (1, 2))
""" np2Tensor [-1,1] normalized """
frames = RGBframes_np2Tensor(frames, args.img_ch)
return frames
def frames_loader_test(args, I0I1It_Path, validation):
frames = []
for path in I0I1It_Path:
frame = cv2.imread(path)
frames.append(frame)
(ih, iw, c) = frame.shape
frames = np.stack(frames, axis=0) # (T, H, W, 3)
if args.dataset == 'X4K1000FPS':
if validation:
ps = 512
ix = (iw - ps) // 2
iy = (ih - ps) // 2
frames = frames[:, iy:iy + ps, ix:ix + ps, :]
""" np2Tensor [-1,1] normalized """
frames = RGBframes_np2Tensor(frames, args.img_ch)
return frames
def RGBframes_np2Tensor(imgIn, channel):
## input : T, H, W, C
if channel == 1:
# rgb --> Y (gray)
imgIn = np.sum(imgIn * np.reshape([65.481, 128.553, 24.966], [1, 1, 1, 3]) / 255.0, axis=3,
keepdims=True) + 16.0
# to Tensor
ts = (3, 0, 1, 2) ############# dimension order should be [C, T, H, W]
imgIn = torch.Tensor(imgIn.transpose(ts).astype(float)).mul_(1.0)
# normalization [-1,1]
imgIn = (imgIn / 255.0 - 0.5) * 2
return imgIn
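# The inverse mapping back to [0, 255] is x -> (x / 2 + 0.5) * 255; this is
# what denorm255_np applies before predictions are saved in main.py.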
def make_2D_dataset_X_Train(dir):
framesPath = []
# Find and loop over all the clips in root `dir`.
for scene_path in sorted(glob.glob(os.path.join(dir, '*', ''))):
sample_paths = sorted(glob.glob(os.path.join(scene_path, '*', '')))
for sample_path in sample_paths:
frame65_list = []
for frame in sorted(glob.glob(os.path.join(sample_path, '*.png'))):
frame65_list.append(frame)
framesPath.append(frame65_list)
print("The number of total training samples : {} which has 65 frames each.".format(
len(framesPath))) ## 4408 folders which have 65 frames each
return framesPath
class X_Train(data.Dataset):
def __init__(self, args, max_t_step_size):
self.args = args
self.max_t_step_size = max_t_step_size
self.framesPath = make_2D_dataset_X_Train(self.args.train_data_path)
self.nScenes = len(self.framesPath)
# Raise error if no images found in train_data_path.
if self.nScenes == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.train_data_path + "\n"))
def __getitem__(self, idx):
t_step_size = random.randint(2, self.max_t_step_size)
t_list = np.linspace((1 / t_step_size), (1 - (1 / t_step_size)), (t_step_size - 1))
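# e.g. t_step_size=4 -> t_list = [0.25, 0.5, 0.75]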
candidate_frames = self.framesPath[idx]
firstFrameIdx = random.randint(0, (64 - t_step_size))
interIdx = random.randint(1, t_step_size - 1) # relative index, 1~self.t_step_size-1
interFrameIdx = firstFrameIdx + interIdx # absolute index
t_value = t_list[interIdx - 1] # [0,1]
if (random.randint(0, 1)):
frameRange = [firstFrameIdx, firstFrameIdx + t_step_size, interFrameIdx]
else: ## temporally reversed order
frameRange = [firstFrameIdx + t_step_size, firstFrameIdx, interFrameIdx]
interIdx = t_step_size - interIdx # (self.t_step_size-1) ~ 1
t_value = 1.0 - t_value
frames = frames_loader_train(self.args, candidate_frames,
frameRange) # including "np2Tensor [-1,1] normalized"
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0)
def __len__(self):
return self.nScenes
def make_2D_dataset_X_Test(dir, multiple, t_step_size):
""" make [I0,I1,It,t,scene_folder] """
""" 1D (accumulated) """
testPath = []
t = np.linspace((1 / multiple), (1 - (1 / multiple)), (multiple - 1))
for type_folder in sorted(glob.glob(os.path.join(dir, '*', ''))): # [type1,type2,type3,...]
for scene_folder in sorted(glob.glob(os.path.join(type_folder, '*', ''))): # [scene1,scene2,..]
frame_folder = sorted(glob.glob(scene_folder + '*.png')) # 32 multiple, ['00000.png',...,'00032.png']
for idx in range(0, len(frame_folder), t_step_size): # 0,32,64,...
if idx == len(frame_folder) - 1:
break
for mul in range(multiple - 1):
I0I1It_paths = []
I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
I0I1It_paths.append(frame_folder[idx + t_step_size]) # I1 (fix)
I0I1It_paths.append(frame_folder[idx + int((t_step_size // multiple) * (mul + 1))]) # It
I0I1It_paths.append(t[mul])
I0I1It_paths.append(scene_folder.split(os.path.join(dir, ''))[-1]) # type1/scene1
testPath.append(I0I1It_paths)
return testPath
class X_Test(data.Dataset):
def __init__(self, args, multiple, validation):
self.args = args
self.multiple = multiple
self.validation = validation
if validation:
self.testPath = make_2D_dataset_X_Test(self.args.val_data_path, multiple, t_step_size=32)
else: ## test
self.testPath = make_2D_dataset_X_Test(self.args.test_data_path, multiple, t_step_size=32)
self.nIterations = len(self.testPath)
# Raise error if no images found in test_data_path.
if len(self.testPath) == 0:
if validation:
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.val_data_path + "\n"))
else:
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.test_data_path + "\n"))
def __getitem__(self, idx):
I0, I1, It, t_value, scene_name = self.testPath[idx]
I0I1It_Path = [I0, I1, It]
frames = frames_loader_test(self.args, I0I1It_Path, self.validation)
# including "np2Tensor [-1,1] normalized"
I0_path = I0.split(os.sep)[-1]
I1_path = I1.split(os.sep)[-1]
It_path = It.split(os.sep)[-1]
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
def __len__(self):
return self.nIterations
class Vimeo_Train(data.Dataset):
def __init__(self, args):
self.args = args
self.t = 0.5
self.framesPath = []
f = open(os.path.join(args.vimeo_data_path, 'tri_trainlist.txt'),
'r') # '../Datasets/vimeo_triplet/sequences/tri_trainlist.txt'
while True:
scene_path = f.readline().split('\n')[0]
if not scene_path: break
frames_list = sorted(glob.glob(os.path.join(args.vimeo_data_path, 'sequences', scene_path,
'*.png'))) # '../Datasets/vimeo_triplet/sequences/%05d/%04d/*.png'
self.framesPath.append(frames_list)
f.close()
# self.framesPath = self.framesPath[:20]
self.nScenes = len(self.framesPath)
if self.nScenes == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + args.vimeo_data_path + "\n"))
print("nScenes of Vimeo train triplet : ", self.nScenes)
def __getitem__(self, idx):
candidate_frames = self.framesPath[idx]
""" Randomly reverse frames """
if (random.randint(0, 1)):
frameRange = [0, 2, 1]
else:
frameRange = [2, 0, 1]
frames = frames_loader_train(self.args, candidate_frames,
frameRange) # including "np2Tensor [-1,1] normalized"
return frames, np.expand_dims(np.array(0.5, dtype=np.float32), 0)
def __len__(self):
return self.nScenes
class Vimeo_Test(data.Dataset):
def __init__(self, args, validation):
self.args = args
self.framesPath = []
f = open(os.path.join(args.vimeo_data_path, 'tri_testlist.txt'), 'r')
while True:
scene_path = f.readline().split('\n')[0]
if not scene_path: break
frames_list = sorted(glob.glob(os.path.join(args.vimeo_data_path, 'sequences', scene_path,
'*.png'))) # '../Datasets/vimeo_triplet/sequences/%05d/%04d/*.png'
self.framesPath.append(frames_list)
if validation:
self.framesPath = self.framesPath[::37]
f.close()
self.num_scene = len(self.framesPath) # total test scenes
if len(self.framesPath) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + args.vimeo_data_path + "\n"))
else:
print("# of Vimeo triplet testset : ", self.num_scene)
def __getitem__(self, idx):
scene_name = self.framesPath[idx][0].split(os.sep)
scene_name = os.path.join(scene_name[-3], scene_name[-2])
I0, It, I1 = self.framesPath[idx]
I0I1It_Path = [I0, I1, It]
frames = frames_loader_test(self.args, I0I1It_Path, validation=False)
I0_path = I0.split(os.sep)[-1]
I1_path = I1.split(os.sep)[-1]
It_path = It.split(os.sep)[-1]
return frames, np.expand_dims(np.array(0.5, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
def __len__(self):
return self.num_scene
def make_2D_dataset_Custom_Test(dir, multiple):
""" make [I0,I1,It,t,scene_folder] """
""" 1D (accumulated) """
testPath = []
t = np.linspace((1 / multiple), (1 - (1 / multiple)), (multiple - 1))
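# e.g. multiple=2 -> t = [0.5]; multiple=8 -> t = [0.125, 0.25, ..., 0.875]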
for scene_folder in sorted(glob.glob(os.path.join(dir, '*', ''))): # [scene1, scene2, scene3, ...]
frame_folder = sorted(glob.glob(scene_folder + '*.png')) # ex) ['00000.png',...,'00123.png']
for idx in range(0, len(frame_folder)):
if idx == len(frame_folder) - 1:
break
for suffix, mul in enumerate(range(multiple - 1)):
I0I1It_paths = []
I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
I0I1It_paths.append(frame_folder[idx + 1]) # I1 (fix)
target_t_Idx = frame_folder[idx].split(os.sep)[-1].split('.')[0]+'_' + str(suffix).zfill(3) + '.png'
# ex) target name: '00017.png' => '00017_000.png' (suffix zero-padded to 3 digits)
I0I1It_paths.append(os.path.join(scene_folder, target_t_Idx)) # It
I0I1It_paths.append(t[mul]) # t
I0I1It_paths.append(frame_folder[idx].split(os.path.join(dir, ''))[-1].split(os.sep)[0]) # scene1
testPath.append(I0I1It_paths)
break # limit to 1 directory - nmkd
return testPath
# def make_2D_dataset_Custom_Test(dir):
# """ make [I0,I1,It,t,scene_folder] """
# """ 1D (accumulated) """
# testPath = []
# for scene_folder in sorted(glob.glob(os.path.join(dir, '*/'))): # [scene1, scene2, scene3, ...]
# frame_folder = sorted(glob.glob(scene_folder + '*.png')) # ex) ['00000.png',...,'00123.png']
# for idx in range(0, len(frame_folder)):
# if idx == len(frame_folder) - 1:
# break
# I0I1It_paths = []
# I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
# I0I1It_paths.append(frame_folder[idx + 1]) # I1 (fix)
# target_t_Idx = frame_folder[idx].split('/')[-1].split('.')[0]+'_x2.png'
# # ex) target t name: 00017.png => '00017_1.png'
# I0I1It_paths.append(os.path.join(scene_folder, target_t_Idx)) # It
# I0I1It_paths.append(0.5) # t
# I0I1It_paths.append(frame_folder[idx].split(os.path.join(dir, ''))[-1].split('/')[0]) # scene1
# testPath.append(I0I1It_paths)
# for asdf in testPath:
# print(asdf)
# return testPath
class Custom_Test(data.Dataset):
def __init__(self, args, multiple):
self.args = args
self.multiple = multiple
self.testPath = make_2D_dataset_Custom_Test(self.args.custom_path, self.multiple)
self.nIterations = len(self.testPath)
# Raise error if no images found in test_data_path.
if len(self.testPath) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.custom_path + "\n"))
def __getitem__(self, idx):
I0, I1, It, t_value, scene_name = self.testPath[idx]
dummy_dir = I1 # there is no ground-truth intermediate frame, so I1 is reused as a placeholder.
I0I1It_Path = [I0, I1, dummy_dir]
frames = frames_loader_test(self.args, I0I1It_Path, None)
# including "np2Tensor [-1,1] normalized"
I0_path = I0.split(os.sep)[-1]
I1_path = I1.split(os.sep)[-1]
It_path = It.split(os.sep)[-1]
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
def __len__(self):
return self.nIterations
class L1_Charbonnier_loss(nn.Module):
"""L1 Charbonnierloss."""
def __init__(self):
super(L1_Charbonnier_loss, self).__init__()
self.epsilon = 1e-3
def forward(self, X, Y):
loss = torch.mean(torch.sqrt((X - Y) ** 2 + self.epsilon ** 2))
return loss
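# Behaves like L1 for |X - Y| >> epsilon while staying smooth (roughly
# quadratic) near zero; with epsilon = 1e-3 the loss is floored at 1e-3.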
def set_rec_loss(args):
loss_type = args.loss_type
if loss_type == 'MSE':
lossfunction = nn.MSELoss()
elif loss_type == 'L1':
lossfunction = nn.L1Loss()
elif loss_type == 'L1_Charbonnier_loss':
lossfunction = L1_Charbonnier_loss()
return lossfunction
class AverageClass(object):
""" For convenience of averaging values """
""" refer from "https://github.com/pytorch/examples/blob/master/imagenet/main.py" """
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1.0):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} (avg:{avg' + self.fmt + '})'
# Accm_Time[s]: 1263.517 (avg:639.701) (<== if AverageClass('Accm_Time[s]:', ':6.3f'))
return fmtstr.format(**self.__dict__)
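# Usage sketch (mirrors the training loop in main.py):
#   losses = AverageClass('Loss:', ':.4e')
#   losses.update(total_loss.item(), 1)
#   print(losses)   # -> Loss: 1.2345e-01 (avg:1.2345e-01)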
class ProgressMeter(object):
""" For convenience of printing diverse values by using "AverageClass" """
""" refer from "https://github.com/pytorch/examples/blob/master/imagenet/main.py" """
def __init__(self, num_batches, *meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def print(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
# # Epoch: [0][ 0/196]
entries += [str(meter) for meter in self.meters]
print('\t'.join(entries))
def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = '{:' + str(num_digits) + 'd}'
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def metrics_evaluation_X_Test(pred_save_path, test_data_path, metrics_types, flow_flag=False, multiple=8, server=None):
"""
pred_save_path = './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
test_data_path = ex) 'F:/Jihyong/4K_1000fps_dataset/VIC_4K_1000FPS/X_TEST'
format: -type1
-scene1
:
-scene5
-type2
:
-type3
:
-scene5
"metrics_types": ["PSNR", "SSIM", "LPIPS", "tOF", "tLP100"]
"flow_flag": option for saving motion visualization
"final_test_type": ['first_interval', 1, 2, 3, 4]
"multiple": x4, x8, x16, x32 for interpolation
"""
pred_framesPath = []
for type_folder in sorted(glob.glob(os.path.join(pred_save_path, '*', ''))): # [type1,type2,type3,...]
for scene_folder in sorted(glob.glob(os.path.join(type_folder, '*', ''))): # [scene1,scene2,..]
scene_framesPath = []
for frame_path in sorted(glob.glob(scene_folder + '*.png')):
scene_framesPath.append(frame_path)
pred_framesPath.append(scene_framesPath)
if len(pred_framesPath) == 0:
raise (RuntimeError("Found 0 files in " + pred_save_path + "\n"))
# GT_framesPath = make_2D_dataset_X_Test(test_data_path, multiple, t_step_size=32)
# pred_framesPath = make_2D_dataset_X_Test(pred_save_path, multiple, t_step_size=32)
# ex) pred_save_path: './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
# ex) framesPath: [['./VIC_4K_1000FPS/VIC_Test/Fast/003_TEST_Fast/00000.png',...], ..., []] 2D List, len=30
# ex) scenesFolder: ['Fast/003_TEST_Fast',...]
keys = metrics_types
len_dict = dict.fromkeys(keys, 0)
Total_avg_dict = dict.fromkeys(["TotalAvg_" + _ for _ in keys], 0)
Type1_dict = dict.fromkeys(["Type1Avg_" + _ for _ in keys], 0)
Type2_dict = dict.fromkeys(["Type2Avg_" + _ for _ in keys], 0)
Type3_dict = dict.fromkeys(["Type3Avg_" + _ for _ in keys], 0)
# LPIPSnet = dm.DistModel()
# LPIPSnet.initialize(model='net-lin', net='alex', use_gpu=True)
total_list_dict = {}
key_str = 'Metrics -->'
for key_i in keys:
total_list_dict[key_i] = []
key_str += ' ' + str(key_i)
key_str += ' will be measured.'
print(key_str)
for scene_idx, scene_folder in enumerate(pred_framesPath):
per_scene_list_dict = {}
for key_i in keys:
per_scene_list_dict[key_i] = []
pred_candidate = pred_framesPath[scene_idx] # get all frames in pred_framesPath
# GT_candidate = GT_framesPath[scene_idx] # get 4800 frames
# num_pred_frame_per_folder = len(pred_candidate)
# save_path = os.path.join(pred_save_path, pred_scenesFolder[scene_idx])
save_path = scene_folder[0]
# './test_img_dir/XVFInet_exp1/epoch_00099/type1/scene1'
# excluding both frame0 and frame1 (multiple of 32 indices)
for frameIndex, pred_frame in enumerate(pred_candidate):
# if server==87:
# GTinterFrameIdx = pred_frame.split('/')[-1] # ex) 8, when multiple = 4, # 87 server
# else:
# GTinterFrameIdx = pred_frame.split('\\')[-1] # ex) 8, when multiple = 4
# if not (GTinterFrameIdx % 32) == 0:
if frameIndex > 0 and frameIndex < multiple:
""" only compute predicted frames (excluding multiples of 32 indices), ex) 8, 16, 24, 40, 48, 56, ... """
output_img = cv2.imread(pred_frame).astype(np.float32) # BGR, [0,255]
target_img = cv2.imread(pred_frame.replace(pred_save_path, test_data_path)).astype(
np.float32) # BGR, [0,255]
pred_frame_split = pred_frame.split(os.sep)
msg = "[x%d] frame %s, " % (
multiple, os.path.join(pred_frame_split[-3], pred_frame_split[-2], pred_frame_split[-1])) # per frame
if "tOF" in keys: # tOF
# if (GTinterFrameIdx % 32) == int(32/multiple):
# if (frameIndex % multiple) == 1:
if frameIndex == 1:
# when first predicted frame in each interval
pre_out_grey = cv2.cvtColor(cv2.imread(pred_candidate[0]).astype(np.float32),
cv2.COLOR_BGR2GRAY) #### CAUTION: BGR
# pre_tar_grey = cv2.cvtColor(cv2.imread(pred_candidate[0].replace(pred_save_path, test_data_path)), cv2.COLOR_BGR2GRAY) #### CAUTION: BGR
pre_tar_grey = pre_out_grey #### CAUTION: BGR
# if not H_match_flag or not W_match_flag:
# pre_tar_grey = pre_tar_grey[:new_t_H, :new_t_W, :]
# pre_tar_grey = pre_out_grey
output_grey = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
target_grey = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY)
target_OF = cv2.calcOpticalFlowFarneback(pre_tar_grey, target_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
output_OF = cv2.calcOpticalFlowFarneback(pre_out_grey, output_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
# target_OF, ofy, ofx = crop_8x8(target_OF) #check for size reason
# output_OF, ofy, ofx = crop_8x8(output_OF)
OF_diff = np.absolute(target_OF - output_OF)
if flow_flag:
""" motion visualization """
flow_path = save_path + '_tOF_flow'
check_folder(flow_path)
# './test_img_dir/XVFInet_exp1/epoch_00099/Fast/003_TEST_Fast_tOF_flow'
tOFpath = os.path.join(flow_path, "tOF_flow_%05d.png" % (frameIndex)) # was GTinterFrameIdx, which is commented out above
# ex) "./test_img_dir/epoch_005/Fast/003_TEST_Fast/00008_tOF" when start_idx=0, multiple=4, frameIndex=0
hsv = np.zeros_like(output_img) # check for size reason
hsv[..., 1] = 255
mag, ang = cv2.cartToPolar(OF_diff[..., 0], OF_diff[..., 1])
# print("tar max %02.6f, min %02.6f, avg %02.6f" % (mag.max(), mag.min(), mag.mean()))
maxV = 0.4
mag = np.clip(mag, 0.0, maxV) / maxV
hsv[..., 0] = ang * 180 / np.pi / 2
hsv[..., 2] = mag * 255.0 #
bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
cv2.imwrite(tOFpath, bgr)
print("png for motion visualization has been saved in [%s]" %
(flow_path))
OF_diff_tmp = np.sqrt(np.sum(OF_diff * OF_diff, axis=-1)).mean() # mean L2 (endpoint) norm of the flow difference
# OF_diff, ofy, ofx = crop_8x8(OF_diff)
total_list_dict["tOF"].append(OF_diff_tmp)
per_scene_list_dict["tOF"].append(OF_diff_tmp)
msg += "tOF %02.2f, " % (total_list_dict["tOF"][-1])
pre_out_grey = output_grey
pre_tar_grey = target_grey
if "PSNR" in keys: # psnr
psnr_tmp = psnr(target_img, output_img)
total_list_dict["PSNR"].append(psnr_tmp)
per_scene_list_dict["PSNR"].append(psnr_tmp)
msg += "PSNR %02.2f" % (total_list_dict["PSNR"][-1])
if "SSIM" in keys: # ssim
ssim_tmp = ssim_bgr(target_img, output_img)
total_list_dict["SSIM"].append(ssim_tmp)
per_scene_list_dict["SSIM"].append(ssim_tmp)
msg += ", SSIM %02.2f" % (total_list_dict["SSIM"][-1])
# msg += ", crop (%d, %d)" % (ofy, ofx) # per frame (not scene)
print(msg)
""" after finishing one scene """
per_scene_pd_dict = {} # per scene
for cur_key in keys:
# save_path = './test_img_dir/XVFInet_exp1/epoch_00099/Fast/003_TEST_Fast'
num_data = cur_key + "_[x%d]_[%s]" % (multiple, save_path.split(os.sep)[-2]) # '003_TEST_Fast'
# num_data => ex) PSNR_[x8]_[041_TEST_Fast]
""" per scene """
per_scene_cur_list = np.float32(per_scene_list_dict[cur_key])
per_scene_pd_dict[num_data] = pd.Series(per_scene_cur_list) # dictionary
per_scene_num_data_sum = per_scene_cur_list.sum()
per_scene_num_data_len = per_scene_cur_list.shape[0]
per_scene_num_data_mean = per_scene_num_data_sum / per_scene_num_data_len
""" accumulation """
cur_list = np.float32(total_list_dict[cur_key])
num_data_sum = cur_list.sum()
num_data_len = cur_list.shape[0] # accum
num_data_mean = num_data_sum / num_data_len
print(" %s, (per scene) max %02.4f, min %02.4f, avg %02.4f" %
(num_data, per_scene_cur_list.max(), per_scene_cur_list.min(), per_scene_num_data_mean)) #
Total_avg_dict["TotalAvg_" + cur_key] = num_data_mean # accum, update every iteration.
len_dict[cur_key] = num_data_len # accum, update every iteration.
# folder_dict["FolderAvg_" + cur_key] += num_data_mean
# scenes are grouped five per difficulty type: 0-4 -> Type1, 5-9 -> Type2, 10-14 -> Type3
if scene_idx < 5:
Type1_dict["Type1Avg_" + cur_key] += per_scene_num_data_mean
elif 5 <= scene_idx < 10:
Type2_dict["Type2Avg_" + cur_key] += per_scene_num_data_mean
elif 10 <= scene_idx < 15:
Type3_dict["Type3Avg_" + cur_key] += per_scene_num_data_mean
mode = 'w' if scene_idx == 0 else 'a'
total_csv_path = os.path.join(pred_save_path, "total_metrics.csv")
# ex) pred_save_path: './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
pd.DataFrame(per_scene_pd_dict).to_csv(total_csv_path, mode=mode)
""" combining all results after looping all scenes. """
for key in keys:
Total_avg_dict["TotalAvg_" + key] = pd.Series(
np.float32(Total_avg_dict["TotalAvg_" + key])) # replace key (update)
Type1_dict["Type1Avg_" + key] = pd.Series(np.float32(Type1_dict["Type1Avg_" + key] / 5)) # replace key (update)
Type2_dict["Type2Avg_" + key] = pd.Series(np.float32(Type2_dict["Type2Avg_" + key] / 5)) # replace key (update)
Type3_dict["Type3Avg_" + key] = pd.Series(np.float32(Type3_dict["Type3Avg_" + key] / 5)) # replace key (update)
print("%s, total frames %d, total avg %02.4f, Type1 avg %02.4f, Type2 avg %02.4f, Type3 avg %02.4f" %
(key, len_dict[key], Total_avg_dict["TotalAvg_" + key],
Type1_dict["Type1Avg_" + key], Type2_dict["Type2Avg_" + key], Type3_dict["Type3Avg_" + key]))
pd.DataFrame(Total_avg_dict).to_csv(total_csv_path, mode='a')
pd.DataFrame(Type1_dict).to_csv(total_csv_path, mode='a')
pd.DataFrame(Type2_dict).to_csv(total_csv_path, mode='a')
pd.DataFrame(Type3_dict).to_csv(total_csv_path, mode='a')
print("csv file of all metrics for all scenes has been saved in [%s]" %
(total_csv_path))
print("Finished.")
def to_uint8(x, vmin, vmax):
##### linear range mapping to [0, 255], originally from https://github.com/yhjo09/VSR-DUF #####
x = x.astype('float32')
x = (x - vmin) / (vmax - vmin) * 255  # [vmin, vmax] -> [0, 255]
return np.clip(np.round(x), 0, 255)
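# Hedged sanity-check sketch (not part of the original file): to_uint8 linearly
# maps [vmin, vmax] onto [0, 255], then rounds and clips, so the interval
# endpoints land exactly on 0 and 255.
def _example_to_uint8():
    x = np.array([-1.0, 0.0, 1.0])
    print(to_uint8(x, vmin=-1.0, vmax=1.0))  # expected: [  0. 128. 255.]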
def psnr(img_true, img_pred):
##### PSNR, originally from https://github.com/yhjo09/VSR-DUF #####
"""
img format: [h, w, c]; computed directly on the raw [0, 255] pixel values,
so the channel order (RGB vs. BGR) does not change the result
"""
diff = img_true - img_pred
rmse = np.sqrt(np.mean(np.power(diff, 2)))
if rmse == 0:
return float('inf')
return 20 * np.log10(255. / rmse)
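# Hedged sanity-check sketch (assumption, not in the original): two reference
# points for this PSNR definition -- identical images give inf, and a uniform
# error of 1 gives 20 * log10(255) ~= 48.13 dB.
def _example_psnr():
    a = np.zeros((4, 4, 3), np.float32)
    print(psnr(a, a))        # inf (guarded rmse == 0 case)
    print(psnr(a, a + 1.0))  # ~48.13 dB (rmse == 1)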
def ssim_bgr(img_true, img_pred):  ##### SSIM on the Y channel, for BGR inputs #####
"""
img format: [h, w, c], BGR; channels are reversed to RGB before the YCbCr transform
"""
Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255)[:, :, ::-1], 255)[:, :, 0]
Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255)[:, :, ::-1], 255)[:, :, 0]
return structural_similarity(Y_true, Y_pred, data_range=Y_pred.max() - Y_pred.min())
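# Hedged usage sketch: `structural_similarity` is assumed to be imported from
# skimage.metrics near the top of this file. SSIM of an image with itself is 1.
def _example_ssim_bgr():
    img = np.random.uniform(0, 255, size=(32, 32, 3)).astype(np.float32)  # BGR
    print(ssim_bgr(img, img))  # 1.0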
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
# [0, 255] HWC image -> [-1, 1] torch tensor shaped [1, C, H, W]
return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
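# Hedged shape-check sketch (assumption): with the default cent/factor this
# maps pixel 0 -> -1.0 and pixel 255 -> +1.0, and an [H, W, C] image becomes a
# [1, C, H, W] tensor.
def _example_im2tensor():
    img = np.full((8, 8, 3), 255.0)
    t = im2tensor(img)
    print(t.shape, float(t.min()))  # torch.Size([1, 3, 8, 8]) 1.0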
def denorm255(x):
# [-1, 1] -> [0, 255] (torch)
out = (x + 1.0) / 2.0
return out.clamp_(0.0, 1.0) * 255.0
def denorm255_np(x):
# [-1, 1] -> [0, 255] (numpy)
out = (x + 1.0) / 2.0
return out.clip(0.0, 1.0) * 255.0
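# Hedged round-trip sketch (assumption): denorm255 / denorm255_np invert the
# [-1, 1] normalization used by im2tensor, mapping values back to [0, 255].
def _example_denorm_roundtrip():
    x = np.array([-1.0, 0.0, 1.0])
    print(denorm255_np(x))  # [  0.  127.5 255. ]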
def _rgb2ycbcr(img, maxVal=255):
##### RGB -> YCbCr (BT.601, studio range), originally from https://github.com/yhjo09/VSR-DUF #####
O = np.array([[16],
[128],
[128]])
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
if maxVal == 1:
O = O / 255.0
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
t = np.dot(t, np.transpose(T))
t[:, 0] += O[0]
t[:, 1] += O[1]
t[:, 2] += O[2]
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
return ycbcr
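# Hedged sanity-check sketch: the matrix above is the BT.601 RGB -> YCbCr
# transform into studio range, so black maps to (Y, Cb, Cr) = (16, 128, 128)
# and white to approximately (235, 128, 128).
def _example_rgb2ycbcr():
    black = np.zeros((1, 1, 3))
    white = np.full((1, 1, 3), 255.0)
    print(_rgb2ycbcr(black)[0, 0])  # ~[ 16. 128. 128.]
    print(_rgb2ycbcr(white)[0, 0])  # ~[235. 128. 128.]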
class set_smoothness_loss(nn.Module):
def __init__(self, weight=150.0, edge_aware=True):
super(set_smoothness_loss, self).__init__()
self.edge_aware = edge_aware
self.weight = weight ** 2
def forward(self, flow, img):
# squared image gradients (mean over channels), used as edge-awareness weights
img_gh = torch.mean(torch.pow((img[:, :, 1:, :] - img[:, :, :-1, :]), 2), dim=1, keepdim=True)
img_gw = torch.mean(torch.pow((img[:, :, :, 1:] - img[:, :, :, :-1]), 2), dim=1, keepdim=True)
weight_gh = torch.exp(-self.weight * img_gh)
weight_gw = torch.exp(-self.weight * img_gw)
flow_gh = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
flow_gw = torch.abs(flow[:, :, :, 1:] - flow[:, :, :, :-1])
if self.edge_aware:
return (torch.mean(weight_gh * flow_gh) + torch.mean(weight_gw * flow_gw)) * 0.5
else:
return (torch.mean(flow_gh) + torch.mean(flow_gw)) * 0.5
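# Hedged usage sketch (shapes inferred from the forward pass above, not stated
# in the original): flow is [B, 2, H, W] and img is [B, 3, H, W]; the loss is a
# scalar, and a spatially constant flow incurs zero smoothness penalty.
def _example_smoothness_loss():
    loss_fn = set_smoothness_loss()
    img = torch.rand(1, 3, 16, 16)
    flow = torch.ones(1, 2, 16, 16)  # constant flow -> zero gradients
    print(loss_fn(flow, img).item())  # 0.0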
def get_batch_images(args, save_img_num, save_images): ## For visualization during training phase
width_num = len(save_images)
log_img = np.zeros((save_img_num * args.patch_size, width_num * args.patch_size, 3), dtype=np.uint8)
pred_frameT, pred_coarse_flow, pred_fine_flow, frameT, simple_mean, occ_map = save_images
for b in range(save_img_num):
output_img_tmp = denorm255(pred_frameT[b, :])
output_coarse_flow_tmp = pred_coarse_flow[b, :2, :, :]
output_fine_flow_tmp = pred_fine_flow[b, :2, :, :]
gt_img_tmp = denorm255(frameT[b, :])
simple_mean_img_tmp = denorm255(simple_mean[b, :])
occ_map_tmp = occ_map[b, :]
output_img_tmp = np.transpose(output_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
output_coarse_flow_tmp = flow2img(np.transpose(output_coarse_flow_tmp.detach().cpu().numpy(), [1, 2, 0]))
output_fine_flow_tmp = flow2img(np.transpose(output_fine_flow_tmp.detach().cpu().numpy(), [1, 2, 0]))
gt_img_tmp = np.transpose(gt_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
simple_mean_img_tmp = np.transpose(simple_mean_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
occ_map_tmp = np.transpose(occ_map_tmp.detach().cpu().numpy() * 255.0, [1, 2, 0]).astype(np.uint8)
occ_map_tmp = np.concatenate([occ_map_tmp, occ_map_tmp, occ_map_tmp], axis=2)
# tile the six visualizations side by side: mean | prediction | GT | coarse flow | fine flow | occlusion
column_images = [simple_mean_img_tmp, output_img_tmp, gt_img_tmp,
output_coarse_flow_tmp, output_fine_flow_tmp, occ_map_tmp]
for c, col_img in enumerate(column_images):
log_img[b * args.patch_size:(b + 1) * args.patch_size,
c * args.patch_size:(c + 1) * args.patch_size, :] = col_img
return log_img
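# Hedged usage sketch (all shapes are assumptions inferred from the code
# above): six tensors, one per visualization column, are tiled into a grid of
# shape [save_img_num * patch_size, 6 * patch_size, 3].
def _example_get_batch_images():
    import argparse
    args = argparse.Namespace(patch_size=32)  # stand-in for the training args
    b, ps = 2, 32
    pred, gt, mean = (torch.rand(b, 3, ps, ps) * 2 - 1 for _ in range(3))  # in [-1, 1]
    coarse, fine = (torch.rand(b, 2, ps, ps) for _ in range(2))  # flow fields
    occ = torch.rand(b, 1, ps, ps)  # occlusion map in [0, 1]
    grid = get_batch_images(args, b, (pred, coarse, fine, gt, mean, occ))
    print(grid.shape)  # (64, 192, 3)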
def flow2img(flow, logscale=True, scaledown=6, output=False):
"""
topleft is zero, u is horiz, v is vertical
red is 3 o'clock, yellow is 6, light blue is 9, blue/purple is 12
"""
# channel 1 is treated as the horizontal component (u), channel 0 as the vertical (v)
u = flow[:, :, 1]
v = flow[:, :, 0]
colorwheel = makecolorwheel()
ncols = colorwheel.shape[0]
radius = np.sqrt(u ** 2 + v ** 2)
if output:
print("Maximum flow magnitude: %04f" % np.max(radius))
if logscale:
radius = np.log(radius + 1)
if output:
print("Maximum flow magnitude (after log): %0.4f" % np.max(radius))
radius = radius / scaledown
if output:
print("Maximum flow magnitude (after scaledown): %0.4f" % np.max(radius))
rot = np.arctan2(v, u) / np.pi
fk = (rot + 1) / 2 * (ncols - 1)  # [-1, 1] mapped to [0, ncols-1]
k0 = fk.astype(np.uint8)  # integer wheel index: 0, 1, ..., ncols-1
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
ncolors = colorwheel.shape[1]
img = np.zeros(u.shape + (ncolors,))
for i in range(ncolors):
tmp = colorwheel[:, i]
col0 = tmp[k0]
col1 = tmp[k1]
col = (1 - f) * col0 + f * col1
idx = radius <= 1
# increase saturation with radius
col[idx] = 1 - radius[idx] * (1 - col[idx])
# out of range
col[~idx] *= 0.75
img[:, :, i] = np.clip(255 * col, 0.0, 255.0).astype(np.uint8)
return img
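# Hedged sanity-check sketch: a constant flow field maps to one solid hue whose
# brightness grows with the log-scaled magnitude; the output is an [H, W, 3]
# array with values in [0, 255].
def _example_flow2img():
    flow = np.zeros((8, 8, 2))
    flow[:, :, 1] = 1.0  # constant horizontal motion (u channel)
    img = flow2img(flow)
    print(img.shape, img.min(), img.max())  # (8, 8, 3), one uniform color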
def makecolorwheel():
# Create the Middlebury-style color wheel used for optical-flow visualization
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[col:col + RY, 0] = 1
colorwheel[col:col + RY, 1] = np.arange(0, 1, 1. / RY)
col += RY
# YG
colorwheel[col:col + YG, 0] = np.arange(1, 0, -1. / YG)
colorwheel[col:col + YG, 1] = 1
col += YG
# GC
colorwheel[col:col + GC, 1] = 1
colorwheel[col:col + GC, 2] = np.arange(0, 1, 1. / GC)
col += GC
# CB
colorwheel[col:col + CB, 1] = np.arange(1, 0, -1. / CB)
colorwheel[col:col + CB, 2] = 1
col += CB
# BM
colorwheel[col:col + BM, 2] = 1
colorwheel[col:col + BM, 0] = np.arange(0, 1, 1. / BM)
col += BM
# MR
colorwheel[col:col + MR, 2] = np.arange(1, 0, -1. / MR)
colorwheel[col:col + MR, 0] = 1
return colorwheel
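# Hedged sanity-check sketch: the six segment lengths above sum to 55, so the
# wheel is a [55, 3] array of RGB values in [0, 1].
def _example_makecolorwheel():
    wheel = makecolorwheel()
    print(wheel.shape, wheel.min(), wheel.max())  # (55, 3) 0.0 1.0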