mirror of https://github.com/n00mkrad/flowframes
xvfi-cuda package
This commit is contained in:
parent
71b02a39de
commit
d998ef2a6b
|
@ -0,0 +1,496 @@
|
|||
import functools, random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
|
||||
|
||||
class XVFInet(nn.Module):
|
||||
|
||||
def __init__(self, args):
|
||||
super(XVFInet, self).__init__()
|
||||
self.args = args
|
||||
self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
|
||||
self.nf = args.nf
|
||||
self.scale = args.module_scale_factor
|
||||
self.vfinet = VFInet(args)
|
||||
self.lrelu = nn.ReLU()
|
||||
self.in_channels = 3
|
||||
self.channel_converter = nn.Sequential(
|
||||
nn.Conv3d(self.in_channels, self.nf, [1, 3, 3], [1, 1, 1], [0, 1, 1]),
|
||||
nn.ReLU())
|
||||
|
||||
self.rec_ext_ds_module = [self.channel_converter]
|
||||
self.rec_ext_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
|
||||
for _ in range(int(np.log2(self.scale))):
|
||||
self.rec_ext_ds_module.append(self.rec_ext_ds)
|
||||
self.rec_ext_ds_module.append(nn.ReLU())
|
||||
self.rec_ext_ds_module.append(nn.Conv3d(self.nf, self.nf, [1, 3, 3], 1, [0, 1, 1]))
|
||||
self.rec_ext_ds_module.append(RResBlock2D_3D(args, T_reduce_flag=False))
|
||||
self.rec_ext_ds_module = nn.Sequential(*self.rec_ext_ds_module)
|
||||
|
||||
self.rec_ctx_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
|
||||
|
||||
print("The lowest scale depth for training (S_trn): ", self.args.S_trn)
|
||||
print("The lowest scale depth for test (S_tst): ", self.args.S_tst)
|
||||
|
||||
def forward(self, x, t_value, is_training=True):
|
||||
'''
|
||||
x shape : [B,C,T,H,W]
|
||||
t_value shape : [B,1] ###############
|
||||
'''
|
||||
B, C, T, H, W = x.size()
|
||||
B2, C2 = t_value.size()
|
||||
assert C2 == 1, "t_value shape is [B,]"
|
||||
assert T % 2 == 0, "T must be an even number"
|
||||
t_value = t_value.view(B, 1, 1, 1)
|
||||
|
||||
flow_l = None
|
||||
feat_x = self.rec_ext_ds_module(x)
|
||||
feat_x_list = [feat_x]
|
||||
self.lowest_depth_level = self.args.S_trn if is_training else self.args.S_tst
|
||||
for level in range(1, self.lowest_depth_level+1):
|
||||
feat_x = self.rec_ctx_ds(feat_x)
|
||||
feat_x_list.append(feat_x)
|
||||
|
||||
if is_training:
|
||||
out_l_list = []
|
||||
flow_refine_l_list = []
|
||||
out_l, flow_l, flow_refine_l = self.vfinet(x, feat_x_list[self.args.S_trn], flow_l, t_value, level=self.args.S_trn, is_training=True)
|
||||
out_l_list.append(out_l)
|
||||
flow_refine_l_list.append(flow_refine_l)
|
||||
for level in range(self.args.S_trn-1, 0, -1): ## self.args.S_trn, self.args.S_trn-1, ..., 1. level 0 is not included
|
||||
out_l, flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=True)
|
||||
out_l_list.append(out_l)
|
||||
out_l, flow_l, flow_refine_l, occ_0_l0 = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=True)
|
||||
out_l_list.append(out_l)
|
||||
flow_refine_l_list.append(flow_refine_l)
|
||||
return out_l_list[::-1], flow_refine_l_list[::-1], occ_0_l0, torch.mean(x, dim=2) # out_l_list should be reversed. [out_l0, out_l1, ...]
|
||||
|
||||
else: # Testing
|
||||
for level in range(self.args.S_tst, 0, -1): ## self.args.S_tst, self.args.S_tst-1, ..., 1. level 0 is not included
|
||||
flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=False)
|
||||
out_l = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=False)
|
||||
return out_l
|
||||
|
||||
|
||||
class VFInet(nn.Module):
|
||||
|
||||
def __init__(self, args):
|
||||
super(VFInet, self).__init__()
|
||||
self.args = args
|
||||
self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
|
||||
self.nf = args.nf
|
||||
self.scale = args.module_scale_factor
|
||||
self.in_channels = 3
|
||||
|
||||
self.conv_flow_bottom = nn.Sequential(
|
||||
nn.Conv2d(2*self.nf, 2*self.nf, [4,4], 2, [1,1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(2*self.nf, 4*self.nf, [4,4], 2, [1,1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.nf, 6, [3,3], 1, [1,1]),
|
||||
)
|
||||
|
||||
self.conv_flow1 = nn.Conv2d(2*self.nf, self.nf, [3, 3], 1, [1, 1])
|
||||
|
||||
self.conv_flow2 = nn.Sequential(
|
||||
nn.Conv2d(2*self.nf + 4, 2 * self.nf, [4, 4], 2, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.nf, 6, [3, 3], 1, [1, 1]),
|
||||
)
|
||||
|
||||
self.conv_flow3 = nn.Sequential(
|
||||
nn.Conv2d(4 + self.nf * 4, self.nf, [1, 1], 1, [0, 0]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.nf, 2 * self.nf, [4, 4], 2, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.UpsamplingNearest2d(scale_factor=2),
|
||||
nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(self.nf, 4, [3, 3], 1, [1, 1]),
|
||||
)
|
||||
|
||||
self.refine_unet = RefineUNet(args)
|
||||
self.lrelu = nn.ReLU()
|
||||
|
||||
def forward(self, x, feat_x, flow_l_prev, t_value, level, is_training):
|
||||
'''
|
||||
x shape : [B,C,T,H,W]
|
||||
t_value shape : [B,1] ###############
|
||||
'''
|
||||
B, C, T, H, W = x.size()
|
||||
assert T % 2 == 0, "T must be an even number"
|
||||
|
||||
####################### For a single level
|
||||
l = 2 ** level
|
||||
x_l = x.permute(0,2,1,3,4)
|
||||
x_l = x_l.contiguous().view(B * T, C, H, W)
|
||||
|
||||
if level == 0:
|
||||
pass
|
||||
else:
|
||||
x_l = F.interpolate(x_l, scale_factor=(1.0 / l, 1.0 / l), mode='bicubic', align_corners=False)
|
||||
'''
|
||||
Down pixel-shuffle
|
||||
'''
|
||||
x_l = x_l.view(B, T, C, H//l, W//l)
|
||||
x_l = x_l.permute(0,2,1,3,4)
|
||||
|
||||
B, C, T, H, W = x_l.size()
|
||||
|
||||
## Feature extraction
|
||||
feat0_l = feat_x[:,:,0,:,:]
|
||||
feat1_l = feat_x[:,:,1,:,:]
|
||||
|
||||
## Flow estimation
|
||||
if flow_l_prev is None:
|
||||
flow_l_tmp = self.conv_flow_bottom(torch.cat((feat0_l, feat1_l), dim=1))
|
||||
flow_l = flow_l_tmp[:,:4,:,:]
|
||||
else:
|
||||
up_flow_l_prev = 2.0*F.interpolate(flow_l_prev.detach(), scale_factor=(2,2), mode='bilinear', align_corners=False)
|
||||
warped_feat1_l = self.bwarp(feat1_l, up_flow_l_prev[:,:2,:,:])
|
||||
warped_feat0_l = self.bwarp(feat0_l, up_flow_l_prev[:,2:,:,:])
|
||||
flow_l_tmp = self.conv_flow2(torch.cat([self.conv_flow1(torch.cat([feat0_l, warped_feat1_l],dim=1)), self.conv_flow1(torch.cat([feat1_l, warped_feat0_l],dim=1)), up_flow_l_prev],dim=1))
|
||||
flow_l = flow_l_tmp[:,:4,:,:] + up_flow_l_prev
|
||||
|
||||
if not is_training and level!=0:
|
||||
return flow_l
|
||||
|
||||
flow_01_l = flow_l[:,:2,:,:]
|
||||
flow_10_l = flow_l[:,2:,:,:]
|
||||
z_01_l = torch.sigmoid(flow_l_tmp[:,4:5,:,:])
|
||||
z_10_l = torch.sigmoid(flow_l_tmp[:,5:6,:,:])
|
||||
|
||||
## Complementary Flow Reversal (CFR)
|
||||
flow_forward, norm0_l = self.z_fwarp(flow_01_l, t_value * flow_01_l, z_01_l) ## Actually, F (t) -> (t+1). Translation only. Not normalized yet
|
||||
flow_backward, norm1_l = self.z_fwarp(flow_10_l, (1-t_value) * flow_10_l, z_10_l) ## Actually, F (1-t) -> (-t). Translation only. Not normalized yet
|
||||
|
||||
flow_t0_l = -(1-t_value) * ((t_value)*flow_forward) + (t_value) * ((t_value)*flow_backward) # The numerator of Eq.(1) in the paper.
|
||||
flow_t1_l = (1-t_value) * ((1-t_value)*flow_forward) - (t_value) * ((1-t_value)*flow_backward) # The numerator of Eq.(2) in the paper.
|
||||
|
||||
norm_l = (1-t_value)*norm0_l + t_value*norm1_l
|
||||
mask_ = (norm_l.detach() > 0).type(norm_l.type())
|
||||
flow_t0_l = (1-mask_) * flow_t0_l + mask_ * (flow_t0_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(1)
|
||||
flow_t1_l = (1-mask_) * flow_t1_l + mask_ * (flow_t1_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(2)
|
||||
|
||||
## Feature warping
|
||||
warped0_l = self.bwarp(feat0_l, flow_t0_l)
|
||||
warped1_l = self.bwarp(feat1_l, flow_t1_l)
|
||||
|
||||
## Flow refinement
|
||||
flow_refine_l = torch.cat([feat0_l, warped0_l, warped1_l, feat1_l, flow_t0_l, flow_t1_l], dim=1)
|
||||
flow_refine_l = self.conv_flow3(flow_refine_l) + torch.cat([flow_t0_l, flow_t1_l], dim=1)
|
||||
flow_t0_l = flow_refine_l[:, :2, :, :]
|
||||
flow_t1_l = flow_refine_l[:, 2:4, :, :]
|
||||
|
||||
warped0_l = self.bwarp(feat0_l, flow_t0_l)
|
||||
warped1_l = self.bwarp(feat1_l, flow_t1_l)
|
||||
|
||||
## Flow upscale
|
||||
flow_t0_l = self.scale * F.interpolate(flow_t0_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
|
||||
flow_t1_l = self.scale * F.interpolate(flow_t1_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
|
||||
|
||||
## Image warping and blending
|
||||
warped_img0_l = self.bwarp(x_l[:,:,0,:,:], flow_t0_l)
|
||||
warped_img1_l = self.bwarp(x_l[:,:,1,:,:], flow_t1_l)
|
||||
|
||||
refine_out = self.refine_unet(torch.cat([F.pixel_shuffle(torch.cat([feat0_l, feat1_l, warped0_l, warped1_l],dim=1), self.scale), x_l[:,:,0,:,:], x_l[:,:,1,:,:], warped_img0_l, warped_img1_l, flow_t0_l, flow_t1_l],dim=1))
|
||||
occ_0_l = torch.sigmoid(refine_out[:, 0:1, :, :])
|
||||
occ_1_l = 1-occ_0_l
|
||||
|
||||
out_l = (1-t_value)*occ_0_l*warped_img0_l + t_value*occ_1_l*warped_img1_l
|
||||
out_l = out_l / ( (1-t_value)*occ_0_l + t_value*occ_1_l ) + refine_out[:, 1:4, :, :]
|
||||
|
||||
if not is_training and level==0:
|
||||
return out_l
|
||||
|
||||
if is_training:
|
||||
if flow_l_prev is None:
|
||||
# if level == self.args.S_trn:
|
||||
return out_l, flow_l, flow_refine_l[:, 0:4, :, :]
|
||||
elif level != 0:
|
||||
return out_l, flow_l
|
||||
else: # level==0
|
||||
return out_l, flow_l, flow_refine_l[:, 0:4, :, :], occ_0_l
|
||||
|
||||
def bwarp(self, x, flo):
|
||||
'''
|
||||
x: [B, C, H, W] (im2)
|
||||
flo: [B, 2, H, W] flow
|
||||
'''
|
||||
B, C, H, W = x.size()
|
||||
# mesh grid
|
||||
xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
|
||||
yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
|
||||
grid = torch.cat((xx, yy), 1).float()
|
||||
|
||||
if x.is_cuda:
|
||||
grid = grid.to(self.device)
|
||||
vgrid = torch.autograd.Variable(grid) + flo
|
||||
|
||||
# scale grid to [-1,1]
|
||||
vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
|
||||
vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
|
||||
|
||||
vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
|
||||
output = nn.functional.grid_sample(x, vgrid, align_corners=True)
|
||||
mask = torch.autograd.Variable(torch.ones(x.size())).to(self.device)
|
||||
mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
|
||||
|
||||
# mask[mask<0.9999] = 0
|
||||
# mask[mask>0] = 1
|
||||
mask = mask.masked_fill_(mask < 0.999, 0)
|
||||
mask = mask.masked_fill_(mask > 0, 1)
|
||||
|
||||
return output * mask
|
||||
|
||||
def fwarp(self, img, flo):
|
||||
|
||||
"""
|
||||
-img: image (N, C, H, W)
|
||||
-flo: optical flow (N, 2, H, W)
|
||||
elements of flo is in [0, H] and [0, W] for dx, dy
|
||||
https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
|
||||
"""
|
||||
|
||||
# (x1, y1) (x1, y2)
|
||||
# +---------------+
|
||||
# | |
|
||||
# | o(x, y) |
|
||||
# | |
|
||||
# | |
|
||||
# | |
|
||||
# | |
|
||||
# +---------------+
|
||||
# (x2, y1) (x2, y2)
|
||||
|
||||
N, C, _, _ = img.size()
|
||||
|
||||
# translate start-point optical flow to end-point optical flow
|
||||
y = flo[:, 0:1:, :]
|
||||
x = flo[:, 1:2, :, :]
|
||||
|
||||
x = x.repeat(1, C, 1, 1)
|
||||
y = y.repeat(1, C, 1, 1)
|
||||
|
||||
# Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2)
|
||||
x1 = torch.floor(x)
|
||||
x2 = x1 + 1
|
||||
y1 = torch.floor(y)
|
||||
y2 = y1 + 1
|
||||
|
||||
# firstly, get gaussian weights
|
||||
w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
|
||||
|
||||
# secondly, sample each weighted corner
|
||||
img11, o11 = self.sample_one(img, x1, y1, w11)
|
||||
img12, o12 = self.sample_one(img, x1, y2, w12)
|
||||
img21, o21 = self.sample_one(img, x2, y1, w21)
|
||||
img22, o22 = self.sample_one(img, x2, y2, w22)
|
||||
|
||||
imgw = img11 + img12 + img21 + img22
|
||||
o = o11 + o12 + o21 + o22
|
||||
|
||||
return imgw, o
|
||||
|
||||
|
||||
def z_fwarp(self, img, flo, z):
|
||||
"""
|
||||
-img: image (N, C, H, W)
|
||||
-flo: optical flow (N, 2, H, W)
|
||||
elements of flo is in [0, H] and [0, W] for dx, dy
|
||||
modified from https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
|
||||
"""
|
||||
|
||||
# (x1, y1) (x1, y2)
|
||||
# +---------------+
|
||||
# | |
|
||||
# | o(x, y) |
|
||||
# | |
|
||||
# | |
|
||||
# | |
|
||||
# | |
|
||||
# +---------------+
|
||||
# (x2, y1) (x2, y2)
|
||||
|
||||
N, C, _, _ = img.size()
|
||||
|
||||
# translate start-point optical flow to end-point optical flow
|
||||
y = flo[:, 0:1:, :]
|
||||
x = flo[:, 1:2, :, :]
|
||||
|
||||
x = x.repeat(1, C, 1, 1)
|
||||
y = y.repeat(1, C, 1, 1)
|
||||
|
||||
# Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2)
|
||||
x1 = torch.floor(x)
|
||||
x2 = x1 + 1
|
||||
y1 = torch.floor(y)
|
||||
y2 = y1 + 1
|
||||
|
||||
# firstly, get gaussian weights
|
||||
w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2, z+1e-5)
|
||||
|
||||
# secondly, sample each weighted corner
|
||||
img11, o11 = self.sample_one(img, x1, y1, w11)
|
||||
img12, o12 = self.sample_one(img, x1, y2, w12)
|
||||
img21, o21 = self.sample_one(img, x2, y1, w21)
|
||||
img22, o22 = self.sample_one(img, x2, y2, w22)
|
||||
|
||||
imgw = img11 + img12 + img21 + img22
|
||||
o = o11 + o12 + o21 + o22
|
||||
|
||||
return imgw, o
|
||||
|
||||
|
||||
def get_gaussian_weights(self, x, y, x1, x2, y1, y2, z=1.0):
|
||||
# z 0.0 ~ 1.0
|
||||
w11 = z * torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))
|
||||
w12 = z * torch.exp(-((x - x1) ** 2 + (y - y2) ** 2))
|
||||
w21 = z * torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))
|
||||
w22 = z * torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))
|
||||
|
||||
return w11, w12, w21, w22
|
||||
|
||||
def sample_one(self, img, shiftx, shifty, weight):
|
||||
"""
|
||||
Input:
|
||||
-img (N, C, H, W)
|
||||
-shiftx, shifty (N, c, H, W)
|
||||
"""
|
||||
|
||||
N, C, H, W = img.size()
|
||||
|
||||
# flatten all (all restored as Tensors)
|
||||
flat_shiftx = shiftx.view(-1)
|
||||
flat_shifty = shifty.view(-1)
|
||||
flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].to(self.device).long().repeat(N, C,1,W).view(-1)
|
||||
flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].to(self.device).long().repeat(N, C,H,1).view(-1)
|
||||
flat_weight = weight.view(-1)
|
||||
flat_img = img.contiguous().view(-1)
|
||||
|
||||
# The corresponding positions in I1
|
||||
idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).to(self.device).long().repeat(1, C, H, W).view(-1)
|
||||
idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).to(self.device).long().repeat(N, 1, H, W).view(-1)
|
||||
idxx = flat_shiftx.long() + flat_basex
|
||||
idxy = flat_shifty.long() + flat_basey
|
||||
|
||||
# recording the inside part the shifted
|
||||
mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
|
||||
|
||||
# Mask off points out of boundaries
|
||||
ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
|
||||
ids_mask = torch.masked_select(ids, mask).clone().to(self.device)
|
||||
|
||||
# Note here! accmulate fla must be true for proper bp
|
||||
img_warp = torch.zeros([N * C * H * W, ]).to(self.device)
|
||||
img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True)
|
||||
|
||||
one_warp = torch.zeros([N * C * H * W, ]).to(self.device)
|
||||
one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
|
||||
|
||||
return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)
|
||||
|
||||
class RefineUNet(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(RefineUNet, self).__init__()
|
||||
self.args = args
|
||||
self.scale = args.module_scale_factor
|
||||
self.nf = args.nf
|
||||
self.conv1 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
|
||||
self.conv2 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
|
||||
self.lrelu = nn.ReLU()
|
||||
self.NN = nn.UpsamplingNearest2d(scale_factor=2)
|
||||
|
||||
self.enc1 = nn.Conv2d((4*self.nf)//self.scale//self.scale + 4*args.img_ch + 4, self.nf, [4, 4], 2, [1, 1])
|
||||
self.enc2 = nn.Conv2d(self.nf, 2*self.nf, [4, 4], 2, [1, 1])
|
||||
self.enc3 = nn.Conv2d(2*self.nf, 4*self.nf, [4, 4], 2, [1, 1])
|
||||
self.dec0 = nn.Conv2d(4*self.nf, 4*self.nf, [3, 3], 1, [1, 1])
|
||||
self.dec1 = nn.Conv2d(4*self.nf + 2*self.nf, 2*self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc2
|
||||
self.dec2 = nn.Conv2d(2*self.nf + self.nf, self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc1
|
||||
self.dec3 = nn.Conv2d(self.nf, 1+args.img_ch, [3, 3], 1, [1, 1]) ## input added with warped image
|
||||
|
||||
def forward(self, concat):
|
||||
enc1 = self.lrelu(self.enc1(concat))
|
||||
enc2 = self.lrelu(self.enc2(enc1))
|
||||
out = self.lrelu(self.enc3(enc2))
|
||||
|
||||
out = self.lrelu(self.dec0(out))
|
||||
out = self.NN(out)
|
||||
|
||||
out = torch.cat((out,enc2),dim=1)
|
||||
out = self.lrelu(self.dec1(out))
|
||||
|
||||
out = self.NN(out)
|
||||
out = torch.cat((out,enc1),dim=1)
|
||||
out = self.lrelu(self.dec2(out))
|
||||
|
||||
out = self.NN(out)
|
||||
out = self.dec3(out)
|
||||
return out
|
||||
|
||||
class ResBlock2D_3D(nn.Module):
|
||||
## Shape of input [B,C,T,H,W]
|
||||
## Shape of output [B,C,T,H,W]
|
||||
def __init__(self, args):
|
||||
super(ResBlock2D_3D, self).__init__()
|
||||
self.args = args
|
||||
self.nf = args.nf
|
||||
|
||||
self.conv3x3_1 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
|
||||
self.conv3x3_2 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
|
||||
self.lrelu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x shape : [B,C,T,H,W]
|
||||
'''
|
||||
B, C, T, H, W = x.size()
|
||||
|
||||
out = self.conv3x3_2(self.lrelu(self.conv3x3_1(x)))
|
||||
|
||||
return x + out
|
||||
|
||||
class RResBlock2D_3D(nn.Module):
|
||||
|
||||
def __init__(self, args, T_reduce_flag=False):
|
||||
super(RResBlock2D_3D, self).__init__()
|
||||
self.args = args
|
||||
self.nf = args.nf
|
||||
self.T_reduce_flag = T_reduce_flag
|
||||
self.resblock1 = ResBlock2D_3D(self.args)
|
||||
self.resblock2 = ResBlock2D_3D(self.args)
|
||||
if T_reduce_flag:
|
||||
self.reduceT_conv = nn.Conv3d(self.nf, self.nf, [3,1,1], 1, [0,0,0])
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x shape : [B,C,T,H,W]
|
||||
'''
|
||||
out = self.resblock1(x)
|
||||
out = self.resblock2(out)
|
||||
if self.T_reduce_flag:
|
||||
return self.reduceT_conv(out + x)
|
||||
else:
|
||||
return out + x
|
|
@ -0,0 +1,410 @@
|
|||
import argparse, os, shutil, time, random, torch, cv2, datetime, torch.utils.data, math
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
|
||||
from torch.autograd import Variable
|
||||
from utils import *
|
||||
from XVFInet import *
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def parse_args():
|
||||
desc = "PyTorch implementation for XVFI"
|
||||
parser = argparse.ArgumentParser(description=desc)
|
||||
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
|
||||
parser.add_argument('--net_type', type=str, default='XVFInet', choices=['XVFInet'], help='The type of Net')
|
||||
parser.add_argument('--net_object', default=XVFInet, choices=[XVFInet], help='The type of Net')
|
||||
parser.add_argument('--exp_num', type=int, default=1, help='The experiment number')
|
||||
parser.add_argument('--phase', type=str, default='test_custom', choices=['train', 'test', 'test_custom', 'metrics_evaluation',])
|
||||
parser.add_argument('--continue_training', action='store_true', default=False, help='continue the training')
|
||||
|
||||
""" Information of directories """
|
||||
parser.add_argument('--test_img_dir', type=str, default='./test_img_dir', help='test_img_dir path')
|
||||
parser.add_argument('--text_dir', type=str, default='./text_dir', help='text_dir path')
|
||||
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_dir', help='checkpoint_dir')
|
||||
parser.add_argument('--log_dir', type=str, default='./log_dir', help='Directory name to save training logs')
|
||||
|
||||
parser.add_argument('--dataset', default='X4K1000FPS', choices=['X4K1000FPS', 'Vimeo'],
|
||||
help='Training/test Dataset')
|
||||
|
||||
# parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train')
|
||||
# parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val')
|
||||
# parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test')
|
||||
parser.add_argument('--train_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/train')
|
||||
parser.add_argument('--val_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/val')
|
||||
parser.add_argument('--test_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/test')
|
||||
|
||||
|
||||
parser.add_argument('--vimeo_data_path', type=str, default='./vimeo_triplet')
|
||||
|
||||
""" Hyperparameters for Training (when [phase=='train']) """
|
||||
parser.add_argument('--epochs', type=int, default=200, help='The number of epochs to run')
|
||||
parser.add_argument('--freq_display', type=int, default=100, help='The number of iterations frequency for display')
|
||||
parser.add_argument('--save_img_num', type=int, default=4,
|
||||
help='The number of saved image while training for visualization. It should smaller than the batch_size')
|
||||
parser.add_argument('--init_lr', type=float, default=1e-4, help='The initial learning rate')
|
||||
parser.add_argument('--lr_dec_fac', type=float, default=0.25, help='step - lr_decreasing_factor')
|
||||
parser.add_argument('--lr_milestones', type=int, default=[100, 150, 180])
|
||||
parser.add_argument('--lr_dec_start', type=int, default=0,
|
||||
help='When scheduler is StepLR, lr decreases from epoch at lr_dec_start')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='The size of batch size.')
|
||||
parser.add_argument('--weight_decay', type=float, default=0, help='for optim., weight decay (default: 0)')
|
||||
|
||||
parser.add_argument('--need_patch', default=True, help='get patch form image while training')
|
||||
parser.add_argument('--img_ch', type=int, default=3, help='base number of channels for image')
|
||||
parser.add_argument('--nf', type=int, default=64, help='base number of channels for feature maps') # 64
|
||||
parser.add_argument('--module_scale_factor', type=int, default=4, help='sptial reduction for pixelshuffle')
|
||||
parser.add_argument('--patch_size', type=int, default=384, help='patch size')
|
||||
parser.add_argument('--num_thrds', type=int, default=4, help='number of threads for data loading')
|
||||
parser.add_argument('--loss_type', default='L1', choices=['L1', 'MSE', 'L1_Charbonnier_loss'], help='Loss type')
|
||||
|
||||
parser.add_argument('--S_trn', type=int, default=3, help='The lowest scale depth for training')
|
||||
parser.add_argument('--S_tst', type=int, default=5, help='The lowest scale depth for test')
|
||||
|
||||
""" Weighting Parameters Lambda for Losses (when [phase=='train']) """
|
||||
parser.add_argument('--rec_lambda', type=float, default=1.0, help='Lambda for Reconstruction Loss')
|
||||
|
||||
""" Settings for Testing (when [phase=='test' or 'test_custom']) """
|
||||
parser.add_argument('--saving_flow_flag', default=False)
|
||||
parser.add_argument('--multiple', type=int, default=8, help='Due to the indexing problem of the file names, we recommend to use the power of 2. (e.g. 2, 4, 8, 16 ...). CAUTION : For the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32].')
|
||||
parser.add_argument('--metrics_types', type=list, default=["PSNR", "SSIM", "tOF"], choices=["PSNR", "SSIM", "tOF"])
|
||||
|
||||
""" Settings for test_custom (when [phase=='test_custom']) """
|
||||
parser.add_argument('--custom_path', type=str, default='./custom_path', help='path for custom video containing frames')
|
||||
parser.add_argument('--output', type=str, default='./interp', help='output path')
|
||||
parser.add_argument('--input', type=str, default='./frames', help='input path')
|
||||
parser.add_argument('--img_format', type=str, default="png")
|
||||
parser.add_argument('--mdl_dir', type=str)
|
||||
|
||||
return check_args(parser.parse_args())
|
||||
|
||||
|
||||
def check_args(args):
|
||||
# --checkpoint_dir
|
||||
check_folder(args.checkpoint_dir)
|
||||
|
||||
# --text_dir
|
||||
check_folder(args.text_dir)
|
||||
|
||||
# --log_dir
|
||||
check_folder(args.log_dir)
|
||||
|
||||
# --test_img_dir
|
||||
check_folder(args.test_img_dir)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.dataset == 'Vimeo':
|
||||
if args.phase != 'test_custom':
|
||||
args.multiple = 2
|
||||
args.S_trn = 1
|
||||
args.S_tst = 1
|
||||
args.module_scale_factor = 2
|
||||
args.patch_size = 256
|
||||
args.batch_size = 16
|
||||
print('vimeo triplet data dir : ', args.vimeo_data_path)
|
||||
|
||||
print("Exp:", args.exp_num)
|
||||
args.model_dir = args.net_type + '_' + args.dataset + '_exp' + str(
|
||||
args.exp_num) # ex) model_dir = "XVFInet_X4K1000FPS_exp1"
|
||||
|
||||
if args is None:
|
||||
exit()
|
||||
for arg in vars(args):
|
||||
print('# {} : {}'.format(arg, getattr(args, arg)))
|
||||
device = torch.device(
|
||||
'cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
|
||||
torch.cuda.set_device(device) # change allocation of current GPU
|
||||
# caution!!!! if not "torch.cuda.set_device()":
|
||||
# RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0
|
||||
print('Available devices: ', torch.cuda.device_count())
|
||||
print('Current cuda device: ', torch.cuda.current_device())
|
||||
print('Current cuda device name: ', torch.cuda.get_device_name(device))
|
||||
if args.gpu is not None:
|
||||
print("Use GPU: {} is used".format(args.gpu))
|
||||
|
||||
SM = save_manager(args)
|
||||
|
||||
""" Initialize a model """
|
||||
model_net = args.net_object(args).apply(weights_init).to(device)
|
||||
criterion = [set_rec_loss(args).to(device), set_smoothness_loss().to(device)]
|
||||
|
||||
# to enable the inbuilt cudnn auto-tuner
|
||||
# to find the best algorithm to use for your hardware.
|
||||
cudnn.benchmark = True
|
||||
|
||||
if args.phase == "train":
|
||||
train(model_net, criterion, device, SM, args)
|
||||
epoch = args.epochs - 1
|
||||
|
||||
elif args.phase == "test" or args.phase == "metrics_evaluation" or args.phase == 'test_custom':
|
||||
checkpoint = SM.load_model(args.mdl_dir)
|
||||
model_net.load_state_dict(checkpoint['state_dict_Model'])
|
||||
epoch = checkpoint['last_epoch']
|
||||
|
||||
postfix = '_final_x' + str(args.multiple) + '_S_tst' + str(args.S_tst)
|
||||
if args.phase != "metrics_evaluation":
|
||||
print("\n-------------------------------------- Final Test starts -------------------------------------- ")
|
||||
print('Evaluate on test set (final test) with multiple = %d ' % (args.multiple))
|
||||
|
||||
final_test_loader = get_test_data(args, multiple=args.multiple,
|
||||
validation=False) # multiple is only used for X4K1000FPS
|
||||
|
||||
final_pred_save_path = test(final_test_loader, model_net,
|
||||
criterion, epoch,
|
||||
args, device,
|
||||
multiple=args.multiple,
|
||||
postfix=postfix, validation=False)
|
||||
#SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR))
|
||||
|
||||
if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
|
||||
final_pred_save_path = os.path.join(args.test_img_dir, args.model_dir, 'epoch_' + str(epoch).zfill(5)) + postfix
|
||||
metrics_evaluation_X_Test(final_pred_save_path, args.test_data_path, args.metrics_types,
|
||||
flow_flag=args.saving_flow_flag, multiple=args.multiple)
|
||||
|
||||
|
||||
|
||||
print("------------------------- Test has been ended. -------------------------\n")
|
||||
|
||||
quit()
|
||||
|
||||
print("Exp:", args.exp_num)
|
||||
SM = save_manager
|
||||
multi_scale_recon_loss = criterion[0]
|
||||
smoothness_loss = criterion[1]
|
||||
|
||||
optimizer = optim.Adam(model_net.parameters(), lr=args.init_lr, betas=(0.9, 0.999),
|
||||
weight_decay=args.weight_decay) # optimizer
|
||||
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_dec_fac)
|
||||
|
||||
last_epoch = 0
|
||||
best_PSNR = 0.0
|
||||
|
||||
if args.continue_training:
|
||||
checkpoint = SM.load_model()
|
||||
last_epoch = checkpoint['last_epoch'] + 1
|
||||
best_PSNR = checkpoint['best_PSNR']
|
||||
model_net.load_state_dict(checkpoint['state_dict_Model'])
|
||||
optimizer.load_state_dict(checkpoint['state_dict_Optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['state_dict_Scheduler'])
|
||||
print("Optimizer and Scheduler have been reloaded. ")
|
||||
scheduler.milestones = Counter(args.lr_milestones)
|
||||
scheduler.gamma = args.lr_dec_fac
|
||||
print("scheduler.milestones : {}, scheduler.gamma : {}".format(scheduler.milestones, scheduler.gamma))
|
||||
start_epoch = last_epoch
|
||||
|
||||
# switch to train mode
|
||||
model_net.train()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
#SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n')
|
||||
#print("[*] Training starts")
|
||||
|
||||
# Main training loop for total epochs (start from 'epoch=0')
|
||||
valid_loader = get_test_data(args, multiple=4, validation=True) # multiple is only used for X4K1000FPS
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_loader = get_train_data(args,
|
||||
max_t_step_size=32) # max_t_step_size (temporal distance) is only used for X4K1000FPS
|
||||
|
||||
batch_time = AverageClass('batch_time[s]:', ':6.3f')
|
||||
losses = AverageClass('Loss:', ':.4e')
|
||||
progress = ProgressMeter(len(train_loader), batch_time, losses, prefix="Epoch: [{}]".format(epoch))
|
||||
|
||||
print('Start epoch {} at [{:s}], learning rate : [{}]'.format(epoch, (str(datetime.now())[:-7]),
|
||||
optimizer.param_groups[0]['lr']))
|
||||
|
||||
# train for one epoch
|
||||
for trainIndex, (frames, t_value) in enumerate(train_loader):
|
||||
|
||||
input_frames = frames[:, :, :-1, :] # [B, C, T, H, W]
|
||||
frameT = frames[:, :, -1, :] # [B, C, H, W]
|
||||
|
||||
# Getting the input and the target from the training set
|
||||
input_frames = Variable(input_frames.to(device))
|
||||
frameT = Variable(frameT.to(device)) # ground truth for frameT
|
||||
t_value = Variable(t_value.to(device)) # [B,1]
|
||||
|
||||
optimizer.zero_grad()
|
||||
# compute output
|
||||
pred_frameT_pyramid, pred_flow_pyramid, occ_map, simple_mean = model_net(input_frames, t_value)
|
||||
rec_loss = 0.0
|
||||
smooth_loss = 0.0
|
||||
for l, pred_frameT_l in enumerate(pred_frameT_pyramid):
|
||||
rec_loss += args.rec_lambda * multi_scale_recon_loss(pred_frameT_l,
|
||||
F.interpolate(frameT, scale_factor=1 / (2 ** l),
|
||||
mode='bicubic', align_corners=False))
|
||||
smooth_loss += 0.5 * smoothness_loss(pred_flow_pyramid[0],
|
||||
F.interpolate(frameT, scale_factor=1 / args.module_scale_factor,
|
||||
mode='bicubic',
|
||||
align_corners=False)) # Apply 1st order edge-aware smoothness loss to the fineset level
|
||||
rec_loss /= len(pred_frameT_pyramid)
|
||||
pred_frameT = pred_frameT_pyramid[0] # final result I^0_t at original scale (s=0)
|
||||
pred_coarse_flow = 2 ** (args.S_trn) * F.interpolate(pred_flow_pyramid[-1], scale_factor=2 ** (
|
||||
args.S_trn) * args.module_scale_factor, mode='bicubic', align_corners=False)
|
||||
pred_fine_flow = F.interpolate(pred_flow_pyramid[0], scale_factor=args.module_scale_factor, mode='bicubic',
|
||||
align_corners=False)
|
||||
|
||||
total_loss = rec_loss + smooth_loss
|
||||
|
||||
# compute gradient and do SGD step
|
||||
total_loss.backward() # Backpropagate
|
||||
optimizer.step() # Optimizer update
|
||||
|
||||
# measure accumulated time and update average "batch" time consumptions via "AverageClass"
|
||||
# update average values via "AverageClass"
|
||||
losses.update(total_loss.item(), 1)
|
||||
batch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
if trainIndex % args.freq_display == 0:
|
||||
progress.print(trainIndex)
|
||||
batch_images = get_batch_images(args, save_img_num=args.save_img_num,
|
||||
save_images=[pred_frameT, pred_coarse_flow, pred_fine_flow, frameT,
|
||||
simple_mean, occ_map])
|
||||
cv2.imwrite(os.path.join(args.log_dir, '{:03d}_{:04d}_training.png'.format(epoch, trainIndex)), batch_images)
|
||||
|
||||
|
||||
|
||||
if epoch >= args.lr_dec_start:
|
||||
scheduler.step()
|
||||
|
||||
# if (epoch + 1) % 10 == 0 or epoch==0:
|
||||
val_multiple = 4 if args.dataset == 'X4K1000FPS' else 2
|
||||
print('\nEvaluate on test set (validation while training) with multiple = {}'.format(val_multiple))
|
||||
postfix = '_val_' + str(val_multiple) + '_S_tst' + str(args.S_tst)
|
||||
final_pred_save_path = test(valid_loader, model_net, criterion, epoch, args,
|
||||
device, multiple=val_multiple, postfix=postfix,
|
||||
validation=True)
|
||||
|
||||
# remember best best_PSNR and best_SSIM and save checkpoint
|
||||
#print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR))
|
||||
best_PSNR_flag = testPSNR > best_PSNR
|
||||
best_PSNR = max(testPSNR, best_PSNR)
|
||||
# save checkpoint.
|
||||
combined_state_dict = {
|
||||
'net_type': args.net_type,
|
||||
'last_epoch': epoch,
|
||||
'batch_size': args.batch_size,
|
||||
'trainLoss': losses.avg,
|
||||
'testLoss': testLoss,
|
||||
'testPSNR': testPSNR,
|
||||
'best_PSNR': best_PSNR,
|
||||
'state_dict_Model': model_net.state_dict(),
|
||||
'state_dict_Optimizer': optimizer.state_dict(),
|
||||
'state_dict_Scheduler': scheduler.state_dict()}
|
||||
|
||||
SM.save_best_model(combined_state_dict, best_PSNR_flag)
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
SM.save_epc_model(combined_state_dict, epoch)
|
||||
SM.write_info('{}\t{:.4}\t{:.4}\t{:.4}\n'.format(epoch, losses.avg, testPSNR, best_PSNR))
|
||||
|
||||
print("------------------------- Training has been ended. -------------------------\n")
|
||||
print("information of model:", args.model_dir)
|
||||
print("best_PSNR of model:", best_PSNR)
|
||||
|
||||
|
||||
def test(test_loader, model_net, criterion, epoch, args, device, multiple, postfix, validation):
|
||||
#os.chdir(interp_output_path)
|
||||
|
||||
batch_time = AverageClass('Time:', ':6.3f')
|
||||
losses = AverageClass('testLoss:', ':.4e')
|
||||
PSNRs = AverageClass('testPSNR:', ':.4e')
|
||||
SSIMs = AverageClass('testSSIM:', ':.4e')
|
||||
args.divide = 2 ** (args.S_tst) * args.module_scale_factor * 4
|
||||
|
||||
# progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
|
||||
progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
|
||||
|
||||
multi_scale_recon_loss = criterion[0]
|
||||
|
||||
# switch to evaluate mode
|
||||
model_net.eval()
|
||||
|
||||
counter = 1
|
||||
copied_src_frames = list()
|
||||
last_frame = ""
|
||||
|
||||
print("------------------------------------------- Test ----------------------------------------------")
|
||||
with torch.no_grad():
|
||||
start_time = time.time()
|
||||
for testIndex, (frames, t_value, scene_name, frameRange) in enumerate(test_loader):
|
||||
# Shape of 'frames' : [1,C,T+1,H,W]
|
||||
frameT = frames[:, :, -1, :, :] # [1,C,H,W]
|
||||
It_Path, I0_Path, I1_Path = frameRange
|
||||
|
||||
#print(I0_Path)
|
||||
#print(I1_Path)
|
||||
|
||||
input_filename = str(I0_Path).split("'")[1];
|
||||
input_filename_next = str(I1_Path).split("'")[1];
|
||||
last_frame = input_filename_next
|
||||
|
||||
frameT = Variable(frameT.to(device)) # ground truth for frameT
|
||||
t_value = Variable(t_value.to(device))
|
||||
|
||||
if (testIndex % (multiple - 1)) == 0:
|
||||
input_frames = frames[:, :, :-1, :, :] # [1,C,T,H,W]
|
||||
input_frames = Variable(input_frames.to(device))
|
||||
|
||||
B, C, T, H, W = input_frames.size()
|
||||
H_padding = (args.divide - H % args.divide) % args.divide
|
||||
W_padding = (args.divide - W % args.divide) % args.divide
|
||||
if H_padding != 0 or W_padding != 0:
|
||||
input_frames = F.pad(input_frames, (0, W_padding, 0, H_padding), "constant")
|
||||
|
||||
|
||||
pred_frameT = model_net(input_frames, t_value, is_training=False)
|
||||
|
||||
if H_padding != 0 or W_padding != 0:
|
||||
pred_frameT = pred_frameT[:, :, :H, :W]
|
||||
|
||||
|
||||
epoch_save_path = args.custom_path
|
||||
scene_save_path = os.path.join(epoch_save_path, scene_name[0])
|
||||
pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy())
|
||||
test = np.squeeze(frameT.detach().cpu().numpy())
|
||||
output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255]
|
||||
#print(os.path.join(scene_save_path, It_Path[0]))
|
||||
|
||||
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
|
||||
src_frame_path = os.path.join(args.custom_path, args.input, input_filename)
|
||||
|
||||
|
||||
if os.path.isfile(src_frame_path):
|
||||
if src_frame_path in copied_src_frames:
|
||||
#print(f"Not copying source frame '{src_frame_path}' because it has already been copied before! - {len(copied_src_frames)}")
|
||||
pass
|
||||
else:
|
||||
print(f"S => {os.path.basename(src_frame_path)} => {os.path.basename(frame_src_path)}")
|
||||
shutil.copy(src_frame_path, frame_src_path)
|
||||
copied_src_frames.append(src_frame_path)
|
||||
counter += 1
|
||||
|
||||
frame_interp_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
|
||||
print(f"I => {os.path.basename(frame_interp_path)}")
|
||||
cv2.imwrite(frame_interp_path, output_img.astype(np.uint8))
|
||||
counter += 1
|
||||
|
||||
#losses.update(0.0, 1)
|
||||
#PSNRs.update(0.0, 1)
|
||||
#SSIMs.update(0.0, 1)
|
||||
|
||||
print("-----------------------------------------------------------------------------------------------")
|
||||
|
||||
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
|
||||
print(f"LAST S => {frame_src_path}")
|
||||
src_frame_path = os.path.join(args.custom_path, args.input, last_frame)
|
||||
shutil.copy(src_frame_path, frame_src_path)
|
||||
|
||||
return epoch_save_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,13 @@
|
|||
[
|
||||
{
|
||||
"name": "Vimeo",
|
||||
"desc": "Model trained on Vimeo90K dataset",
|
||||
"dir": "vimeo"
|
||||
},
|
||||
{
|
||||
"name": "X4K1000FPS",
|
||||
"desc": "Model trained on X4K1000FPS dataset",
|
||||
"dir": "x4k1000fps",
|
||||
"isDefault": "true"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,961 @@
|
|||
from __future__ import division
|
||||
import os, glob, sys, torch, shutil, random, math, time, cv2
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
import torch.nn as nn
|
||||
import pandas as pd
|
||||
import torch.nn.functional as F
|
||||
from datetime import datetime
|
||||
from torch.nn import init
|
||||
from skimage.measure import compare_ssim
|
||||
from skimage.metrics import structural_similarity
|
||||
from torch.autograd import Variable
|
||||
from torchvision import models
|
||||
|
||||
|
||||
class save_manager():
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.model_dir = self.args.net_type + '_' + self.args.dataset + '_exp' + str(self.args.exp_num)
|
||||
print("model_dir:", self.model_dir)
|
||||
# ex) model_dir = "XVFInet_exp1"
|
||||
|
||||
self.checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
|
||||
# './checkpoint_dir/XVFInet_exp1"
|
||||
#check_folder(self.checkpoint_dir)
|
||||
|
||||
#print("checkpoint_dir:", self.checkpoint_dir)
|
||||
|
||||
#self.text_dir = os.path.join(self.args.text_dir, self.model_dir)
|
||||
#print("text_dir:", self.text_dir)
|
||||
|
||||
#""" Save a text file """
|
||||
#if not os.path.exists(self.text_dir + '.txt'):
|
||||
# self.log_file = open(self.text_dir + '.txt', 'w')
|
||||
# # "w" - Write - Opens a file for writing, creates the file if it does not exist
|
||||
# self.log_file.write('----- Model parameters -----\n')
|
||||
# self.log_file.write(str(datetime.now())[:-7] + '\n')
|
||||
# for arg in vars(self.args):
|
||||
# self.log_file.write('{} : {}\n'.format(arg, getattr(self.args, arg)))
|
||||
# # ex) ./text_dir/XVFInet_exp1.txt
|
||||
# self.log_file.close()
|
||||
|
||||
# "a" - Append - Opens a file for appending, creates the file if it does not exist
|
||||
|
||||
def write_info(self, strings):
|
||||
self.log_file = open(self.text_dir + '.txt', 'a')
|
||||
self.log_file.write(strings)
|
||||
self.log_file.close()
|
||||
|
||||
def save_best_model(self, combined_state_dict, best_PSNR_flag):
|
||||
file_name = os.path.join(self.checkpoint_dir, self.model_dir + '_latest.pt')
|
||||
# file_name = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_latest.ckpt
|
||||
torch.save(combined_state_dict, file_name)
|
||||
if best_PSNR_flag:
|
||||
shutil.copyfile(file_name, os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'))
|
||||
|
||||
# file_path = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_best_PSNR.ckpt
|
||||
|
||||
def save_epc_model(self, combined_state_dict, epoch):
|
||||
file_name = os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch) + '.pt')
|
||||
# file_name = "./checkpoint_dir/XVFInet_exp1/XVFInet_exp1_epc10.ckpt
|
||||
torch.save(combined_state_dict, file_name)
|
||||
|
||||
def load_epc_model(self, epoch):
|
||||
checkpoint = torch.load(os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch - 1) + '.pt'))
|
||||
print("load model '{}', epoch: {}, best_PSNR: {:3f}".format(
|
||||
os.path.join(self.checkpoint_dir, self.model_dir + '_epc' + str(epoch - 1) + '.pt'), checkpoint['last_epoch'] + 1,
|
||||
checkpoint['best_PSNR']))
|
||||
return checkpoint
|
||||
|
||||
def load_model(self, mdl_dir):
|
||||
# checkpoint = torch.load(self.checkpoint_dir + '/' + self.model_dir + '_latest.pt')
|
||||
checkpoint = torch.load(os.path.join(mdl_dir, "checkpoint.pt"), map_location='cuda:0')
|
||||
print("load model '{}', epoch: {},".format(
|
||||
os.path.join(mdl_dir, "checkpoint.pt"), checkpoint['last_epoch'] + 1))
|
||||
return checkpoint
|
||||
|
||||
def load_best_PSNR_model(self, ):
|
||||
checkpoint = torch.load(os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'))
|
||||
print("load _best_PSNR model '{}', epoch: {}, best_PSNR: {:3f}, best_SSIM: {:3f}".format(
|
||||
os.path.join(self.checkpoint_dir, self.model_dir + '_best_PSNR.pt'), checkpoint['last_epoch'] + 1,
|
||||
checkpoint['best_PSNR'], checkpoint['best_SSIM']))
|
||||
return checkpoint
|
||||
|
||||
|
||||
def check_folder(log_dir):
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
return log_dir
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if (classname.find('Conv2d') != -1) or (classname.find('Conv3d') != -1):
|
||||
init.xavier_normal_(m.weight)
|
||||
# init.kaiming_normal_(m.weight, nonlinearity='relu')
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
init.zeros_(m.bias)
|
||||
|
||||
|
||||
def get_train_data(args, max_t_step_size):
|
||||
if args.dataset == 'X4K1000FPS':
|
||||
data_train = X_Train(args, max_t_step_size)
|
||||
elif args.dataset == 'Vimeo':
|
||||
data_train = Vimeo_Train(args)
|
||||
dataloader = torch.utils.data.DataLoader(data_train, batch_size=args.batch_size, drop_last=True, shuffle=True,
|
||||
num_workers=int(args.num_thrds), pin_memory=False)
|
||||
return dataloader
|
||||
|
||||
|
||||
def get_test_data(args, multiple, validation):
|
||||
if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
|
||||
data_test = X_Test(args, multiple, validation) # 'validation' for validation while training for simplicity
|
||||
elif args.dataset == 'Vimeo' and args.phase != 'test_custom':
|
||||
data_test = Vimeo_Test(args, validation)
|
||||
elif args.phase == 'test_custom':
|
||||
data_test = Custom_Test(args, multiple)
|
||||
dataloader = torch.utils.data.DataLoader(data_test, batch_size=1, drop_last=True, shuffle=False, pin_memory=False)
|
||||
return dataloader
|
||||
|
||||
|
||||
def frames_loader_train(args, candidate_frames, frameRange):
|
||||
frames = []
|
||||
for frameIndex in frameRange:
|
||||
frame = cv2.imread(candidate_frames[frameIndex])
|
||||
frames.append(frame)
|
||||
(ih, iw, c) = frame.shape
|
||||
frames = np.stack(frames, axis=0) # (T, H, W, 3)
|
||||
if args.need_patch: ## random crop
|
||||
ps = args.patch_size
|
||||
ix = random.randrange(0, iw - ps + 1)
|
||||
iy = random.randrange(0, ih - ps + 1)
|
||||
frames = frames[:, iy:iy + ps, ix:ix + ps, :]
|
||||
|
||||
if random.random() < 0.5: # random horizontal flip
|
||||
frames = frames[:, :, ::-1, :]
|
||||
|
||||
# No vertical flip
|
||||
|
||||
rot = random.randint(0, 3) # random rotate
|
||||
frames = np.rot90(frames, rot, (1, 2))
|
||||
|
||||
""" np2Tensor [-1,1] normalized """
|
||||
frames = RGBframes_np2Tensor(frames, args.img_ch)
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def frames_loader_test(args, I0I1It_Path, validation):
|
||||
frames = []
|
||||
for path in I0I1It_Path:
|
||||
frame = cv2.imread(path)
|
||||
frames.append(frame)
|
||||
(ih, iw, c) = frame.shape
|
||||
frames = np.stack(frames, axis=0) # (T, H, W, 3)
|
||||
|
||||
if args.dataset == 'X4K1000FPS':
|
||||
if validation:
|
||||
ps = 512
|
||||
ix = (iw - ps) // 2
|
||||
iy = (ih - ps) // 2
|
||||
frames = frames[:, iy:iy + ps, ix:ix + ps, :]
|
||||
|
||||
""" np2Tensor [-1,1] normalized """
|
||||
frames = RGBframes_np2Tensor(frames, args.img_ch)
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def RGBframes_np2Tensor(imgIn, channel):
|
||||
## input : T, H, W, C
|
||||
if channel == 1:
|
||||
# rgb --> Y (gray)
|
||||
imgIn = np.sum(imgIn * np.reshape([65.481, 128.553, 24.966], [1, 1, 1, 3]) / 255.0, axis=3,
|
||||
keepdims=True) + 16.0
|
||||
|
||||
# to Tensor
|
||||
ts = (3, 0, 1, 2) ############# dimension order should be [C, T, H, W]
|
||||
imgIn = torch.Tensor(imgIn.transpose(ts).astype(float)).mul_(1.0)
|
||||
|
||||
# normalization [-1,1]
|
||||
imgIn = (imgIn / 255.0 - 0.5) * 2
|
||||
|
||||
return imgIn
|
||||
|
||||
|
||||
def make_2D_dataset_X_Train(dir):
|
||||
framesPath = []
|
||||
# Find and loop over all the clips in root `dir`.
|
||||
for scene_path in sorted(glob.glob(os.path.join(dir, '*', ''))):
|
||||
sample_paths = sorted(glob.glob(os.path.join(scene_path, '*', '')))
|
||||
for sample_path in sample_paths:
|
||||
frame65_list = []
|
||||
for frame in sorted(glob.glob(os.path.join(sample_path, '*.png'))):
|
||||
frame65_list.append(frame)
|
||||
framesPath.append(frame65_list)
|
||||
|
||||
print("The number of total training samples : {} which has 65 frames each.".format(
|
||||
len(framesPath))) ## 4408 folders which have 65 frames each
|
||||
return framesPath
|
||||
|
||||
|
||||
class X_Train(data.Dataset):
|
||||
def __init__(self, args, max_t_step_size):
|
||||
self.args = args
|
||||
self.max_t_step_size = max_t_step_size
|
||||
|
||||
self.framesPath = make_2D_dataset_X_Train(self.args.train_data_path)
|
||||
self.nScenes = len(self.framesPath)
|
||||
|
||||
# Raise error if no images found in train_data_path.
|
||||
if self.nScenes == 0:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.train_data_path + "\n"))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
t_step_size = random.randint(2, self.max_t_step_size)
|
||||
t_list = np.linspace((1 / t_step_size), (1 - (1 / t_step_size)), (t_step_size - 1))
|
||||
|
||||
candidate_frames = self.framesPath[idx]
|
||||
firstFrameIdx = random.randint(0, (64 - t_step_size))
|
||||
interIdx = random.randint(1, t_step_size - 1) # relative index, 1~self.t_step_size-1
|
||||
interFrameIdx = firstFrameIdx + interIdx # absolute index
|
||||
t_value = t_list[interIdx - 1] # [0,1]
|
||||
|
||||
if (random.randint(0, 1)):
|
||||
frameRange = [firstFrameIdx, firstFrameIdx + t_step_size, interFrameIdx]
|
||||
else: ## temporally reversed order
|
||||
frameRange = [firstFrameIdx + t_step_size, firstFrameIdx, interFrameIdx]
|
||||
interIdx = t_step_size - interIdx # (self.t_step_size-1) ~ 1
|
||||
t_value = 1.0 - t_value
|
||||
|
||||
frames = frames_loader_train(self.args, candidate_frames,
|
||||
frameRange) # including "np2Tensor [-1,1] normalized"
|
||||
|
||||
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0)
|
||||
|
||||
def __len__(self):
|
||||
return self.nScenes
|
||||
|
||||
|
||||
def make_2D_dataset_X_Test(dir, multiple, t_step_size):
|
||||
""" make [I0,I1,It,t,scene_folder] """
|
||||
""" 1D (accumulated) """
|
||||
testPath = []
|
||||
t = np.linspace((1 / multiple), (1 - (1 / multiple)), (multiple - 1))
|
||||
for type_folder in sorted(glob.glob(os.path.join(dir, '*', ''))): # [type1,type2,type3,...]
|
||||
for scene_folder in sorted(glob.glob(os.path.join(type_folder, '*', ''))): # [scene1,scene2,..]
|
||||
frame_folder = sorted(glob.glob(scene_folder + '*.png')) # 32 multiple, ['00000.png',...,'00032.png']
|
||||
for idx in range(0, len(frame_folder), t_step_size): # 0,32,64,...
|
||||
if idx == len(frame_folder) - 1:
|
||||
break
|
||||
for mul in range(multiple - 1):
|
||||
I0I1It_paths = []
|
||||
I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
|
||||
I0I1It_paths.append(frame_folder[idx + t_step_size]) # I1 (fix)
|
||||
I0I1It_paths.append(frame_folder[idx + int((t_step_size // multiple) * (mul + 1))]) # It
|
||||
I0I1It_paths.append(t[mul])
|
||||
I0I1It_paths.append(scene_folder.split(os.path.join(dir, ''))[-1]) # type1/scene1
|
||||
testPath.append(I0I1It_paths)
|
||||
return testPath
|
||||
|
||||
|
||||
class X_Test(data.Dataset):
|
||||
def __init__(self, args, multiple, validation):
|
||||
self.args = args
|
||||
self.multiple = multiple
|
||||
self.validation = validation
|
||||
if validation:
|
||||
self.testPath = make_2D_dataset_X_Test(self.args.val_data_path, multiple, t_step_size=32)
|
||||
else: ## test
|
||||
self.testPath = make_2D_dataset_X_Test(self.args.test_data_path, multiple, t_step_size=32)
|
||||
|
||||
self.nIterations = len(self.testPath)
|
||||
|
||||
# Raise error if no images found in test_data_path.
|
||||
if len(self.testPath) == 0:
|
||||
if validation:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.val_data_path + "\n"))
|
||||
else:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.test_data_path + "\n"))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
I0, I1, It, t_value, scene_name = self.testPath[idx]
|
||||
|
||||
I0I1It_Path = [I0, I1, It]
|
||||
|
||||
frames = frames_loader_test(self.args, I0I1It_Path, self.validation)
|
||||
# including "np2Tensor [-1,1] normalized"
|
||||
|
||||
I0_path = I0.split(os.sep)[-1]
|
||||
I1_path = I1.split(os.sep)[-1]
|
||||
It_path = It.split(os.sep)[-1]
|
||||
|
||||
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
|
||||
|
||||
def __len__(self):
|
||||
return self.nIterations
|
||||
|
||||
|
||||
class Vimeo_Train(data.Dataset):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.t = 0.5
|
||||
self.framesPath = []
|
||||
f = open(os.path.join(args.vimeo_data_path, 'tri_trainlist.txt'),
|
||||
'r') # '../Datasets/vimeo_triplet/sequences/tri_trainlist.txt'
|
||||
while True:
|
||||
scene_path = f.readline().split('\n')[0]
|
||||
if not scene_path: break
|
||||
frames_list = sorted(glob.glob(os.path.join(args.vimeo_data_path, 'sequences', scene_path,
|
||||
'*.png'))) # '../Datasets/vimeo_triplet/sequences/%05d/%04d/*.png'
|
||||
self.framesPath.append(frames_list)
|
||||
f.close
|
||||
# self.framesPath = self.framesPath[:20]
|
||||
self.nScenes = len(self.framesPath)
|
||||
if self.nScenes == 0:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + args.vimeo_data_path + "\n"))
|
||||
print("nScenes of Vimeo train triplet : ", self.nScenes)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
candidate_frames = self.framesPath[idx]
|
||||
|
||||
""" Randomly reverse frames """
|
||||
if (random.randint(0, 1)):
|
||||
frameRange = [0, 2, 1]
|
||||
else:
|
||||
frameRange = [2, 0, 1]
|
||||
frames = frames_loader_train(self.args, candidate_frames,
|
||||
frameRange) # including "np2Tensor [-1,1] normalized"
|
||||
|
||||
return frames, np.expand_dims(np.array(0.5, dtype=np.float32), 0)
|
||||
|
||||
def __len__(self):
|
||||
return self.nScenes
|
||||
|
||||
|
||||
class Vimeo_Test(data.Dataset):
|
||||
def __init__(self, args, validation):
|
||||
self.args = args
|
||||
self.framesPath = []
|
||||
f = open(os.path.join(args.vimeo_data_path, 'tri_testlist.txt'), 'r')
|
||||
while True:
|
||||
scene_path = f.readline().split('\n')[0]
|
||||
if not scene_path: break
|
||||
frames_list = sorted(glob.glob(os.path.join(args.vimeo_data_path, 'sequences', scene_path,
|
||||
'*.png'))) # '../Datasets/vimeo_triplet/sequences/%05d/%04d/*.png'
|
||||
self.framesPath.append(frames_list)
|
||||
if validation:
|
||||
self.framesPath = self.framesPath[::37]
|
||||
f.close
|
||||
|
||||
self.num_scene = len(self.framesPath) # total test scenes
|
||||
if len(self.framesPath) == 0:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + args.vimeo_data_path + "\n"))
|
||||
else:
|
||||
print("# of Vimeo triplet testset : ", self.num_scene)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
scene_name = self.framesPath[idx][0].split(os.sep)
|
||||
scene_name = os.path.join(scene_name[-3], scene_name[-2])
|
||||
I0, It, I1 = self.framesPath[idx]
|
||||
I0I1It_Path = [I0, I1, It]
|
||||
frames = frames_loader_test(self.args, I0I1It_Path, validation=False)
|
||||
|
||||
I0_path = I0.split(os.sep)[-1]
|
||||
I1_path = I1.split(os.sep)[-1]
|
||||
It_path = It.split(os.sep)[-1]
|
||||
|
||||
return frames, np.expand_dims(np.array(0.5, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_scene
|
||||
|
||||
def make_2D_dataset_Custom_Test(dir, multiple):
|
||||
""" make [I0,I1,It,t,scene_folder] """
|
||||
""" 1D (accumulated) """
|
||||
testPath = []
|
||||
t = np.linspace((1 / multiple), (1 - (1 / multiple)), (multiple - 1))
|
||||
for scene_folder in sorted(glob.glob(os.path.join(dir, '*', ''))): # [scene1, scene2, scene3, ...]
|
||||
frame_folder = sorted(glob.glob(scene_folder + '*.png')) # ex) ['00000.png',...,'00123.png']
|
||||
for idx in range(0, len(frame_folder)):
|
||||
if idx == len(frame_folder) - 1:
|
||||
break
|
||||
for suffix, mul in enumerate(range(multiple - 1)):
|
||||
I0I1It_paths = []
|
||||
I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
|
||||
I0I1It_paths.append(frame_folder[idx + 1]) # I1 (fix)
|
||||
target_t_Idx = frame_folder[idx].split(os.sep)[-1].split('.')[0]+'_' + str(suffix).zfill(3) + '.png'
|
||||
# ex) target t name: 00017.png => '00017_1.png'
|
||||
I0I1It_paths.append(os.path.join(scene_folder, target_t_Idx)) # It
|
||||
I0I1It_paths.append(t[mul]) # t
|
||||
I0I1It_paths.append(frame_folder[idx].split(os.path.join(dir, ''))[-1].split(os.sep)[0]) # scene1
|
||||
testPath.append(I0I1It_paths)
|
||||
break # limit to 1 directory - nmkd
|
||||
return testPath
|
||||
|
||||
|
||||
# def make_2D_dataset_Custom_Test(dir):
|
||||
# """ make [I0,I1,It,t,scene_folder] """
|
||||
# """ 1D (accumulated) """
|
||||
# testPath = []
|
||||
# for scene_folder in sorted(glob.glob(os.path.join(dir, '*/'))): # [scene1, scene2, scene3, ...]
|
||||
# frame_folder = sorted(glob.glob(scene_folder + '*.png')) # ex) ['00000.png',...,'00123.png']
|
||||
# for idx in range(0, len(frame_folder)):
|
||||
# if idx == len(frame_folder) - 1:
|
||||
# break
|
||||
# I0I1It_paths = []
|
||||
# I0I1It_paths.append(frame_folder[idx]) # I0 (fix)
|
||||
# I0I1It_paths.append(frame_folder[idx + 1]) # I1 (fix)
|
||||
# target_t_Idx = frame_folder[idx].split('/')[-1].split('.')[0]+'_x2.png'
|
||||
# # ex) target t name: 00017.png => '00017_1.png'
|
||||
# I0I1It_paths.append(os.path.join(scene_folder, target_t_Idx)) # It
|
||||
# I0I1It_paths.append(0.5) # t
|
||||
# I0I1It_paths.append(frame_folder[idx].split(os.path.join(dir, ''))[-1].split('/')[0]) # scene1
|
||||
# testPath.append(I0I1It_paths)
|
||||
# for asdf in testPath:
|
||||
# print(asdf)
|
||||
# return testPath
|
||||
|
||||
|
||||
class Custom_Test(data.Dataset):
|
||||
def __init__(self, args, multiple):
|
||||
self.args = args
|
||||
self.multiple = multiple
|
||||
self.testPath = make_2D_dataset_Custom_Test(self.args.custom_path, self.multiple)
|
||||
self.nIterations = len(self.testPath)
|
||||
|
||||
# Raise error if no images found in test_data_path.
|
||||
if len(self.testPath) == 0:
|
||||
raise (RuntimeError("Found 0 files in subfolders of: " + self.args.custom_path + "\n"))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
I0, I1, It, t_value, scene_name = self.testPath[idx]
|
||||
dummy_dir = I1 # due to there is not ground truth intermediate frame.
|
||||
I0I1It_Path = [I0, I1, dummy_dir]
|
||||
|
||||
frames = frames_loader_test(self.args, I0I1It_Path, None)
|
||||
# including "np2Tensor [-1,1] normalized"
|
||||
|
||||
I0_path = I0.split(os.sep)[-1]
|
||||
I1_path = I1.split(os.sep)[-1]
|
||||
It_path = It.split(os.sep)[-1]
|
||||
|
||||
return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0), scene_name, [It_path, I0_path, I1_path]
|
||||
|
||||
def __len__(self):
|
||||
return self.nIterations
|
||||
|
||||
|
||||
class L1_Charbonnier_loss(nn.Module):
|
||||
"""L1 Charbonnierloss."""
|
||||
|
||||
def __init__(self):
|
||||
super(L1_Charbonnier_loss, self).__init__()
|
||||
self.epsilon = 1e-3
|
||||
|
||||
def forward(self, X, Y):
|
||||
loss = torch.mean(torch.sqrt((X - Y) ** 2 + self.epsilon ** 2))
|
||||
return loss
|
||||
|
||||
|
||||
def set_rec_loss(args):
|
||||
loss_type = args.loss_type
|
||||
if loss_type == 'MSE':
|
||||
lossfunction = nn.MSELoss()
|
||||
elif loss_type == 'L1':
|
||||
lossfunction = nn.L1Loss()
|
||||
elif loss_type == 'L1_Charbonnier_loss':
|
||||
lossfunction = L1_Charbonnier_loss()
|
||||
|
||||
return lossfunction
|
||||
|
||||
|
||||
class AverageClass(object):
|
||||
""" For convenience of averaging values """
|
||||
""" refer from "https://github.com/pytorch/examples/blob/master/imagenet/main.py" """
|
||||
|
||||
def __init__(self, name, fmt=':f'):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1.0):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name} {val' + self.fmt + '} (avg:{avg' + self.fmt + '})'
|
||||
# Accm_Time[s]: 1263.517 (avg:639.701) (<== if AverageClass('Accm_Time[s]:', ':6.3f'))
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class ProgressMeter(object):
|
||||
""" For convenience of printing diverse values by using "AverageClass" """
|
||||
""" refer from "https://github.com/pytorch/examples/blob/master/imagenet/main.py" """
|
||||
|
||||
def __init__(self, num_batches, *meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.prefix = prefix
|
||||
|
||||
def print(self, batch):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
# # Epoch: [0][ 0/196]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
print('\t'.join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = '{:' + str(num_digits) + 'd}'
|
||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
||||
|
||||
|
||||
def metrics_evaluation_X_Test(pred_save_path, test_data_path, metrics_types, flow_flag=False, multiple=8, server=None):
|
||||
"""
|
||||
pred_save_path = './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
|
||||
test_data_path = ex) 'F:/Jihyong/4K_1000fps_dataset/VIC_4K_1000FPS/X_TEST'
|
||||
format: -type1
|
||||
-scene1
|
||||
:
|
||||
-scene5
|
||||
-type2
|
||||
:
|
||||
-type3
|
||||
:
|
||||
-scene5
|
||||
"metrics_types": ["PSNR", "SSIM", "LPIPS", "tOF", "tLP100"]
|
||||
"flow_flag": option for saving motion visualization
|
||||
"final_test_type": ['first_interval', 1, 2, 3, 4]
|
||||
"multiple": x4, x8, x16, x32 for interpolation
|
||||
"""
|
||||
|
||||
pred_framesPath = []
|
||||
for type_folder in sorted(glob.glob(os.path.join(pred_save_path, '*', ''))): # [type1,type2,type3,...]
|
||||
for scene_folder in sorted(glob.glob(os.path.join(type_folder, '*', ''))): # [scene1,scene2,..]
|
||||
scene_framesPath = []
|
||||
for frame_path in sorted(glob.glob(scene_folder + '*.png')):
|
||||
scene_framesPath.append(frame_path)
|
||||
pred_framesPath.append(scene_framesPath)
|
||||
if len(pred_framesPath) == 0:
|
||||
raise (RuntimeError("Found 0 files in " + pred_save_path + "\n"))
|
||||
|
||||
# GT_framesPath = make_2D_dataset_X_Test(test_data_path, multiple, t_step_size=32)
|
||||
# pred_framesPath = make_2D_dataset_X_Test(pred_save_path, multiple, t_step_size=32)
|
||||
|
||||
# ex) pred_save_path: './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
|
||||
# ex) framesPath: [['./VIC_4K_1000FPS/VIC_Test/Fast/003_TEST_Fast/00000.png',...], ..., []] 2D List, len=30
|
||||
# ex) scenesFolder: ['Fast/003_TEST_Fast',...]
|
||||
|
||||
keys = metrics_types
|
||||
len_dict = dict.fromkeys(keys, 0)
|
||||
Total_avg_dict = dict.fromkeys(["TotalAvg_" + _ for _ in keys], 0)
|
||||
Type1_dict = dict.fromkeys(["Type1Avg_" + _ for _ in keys], 0)
|
||||
Type2_dict = dict.fromkeys(["Type2Avg_" + _ for _ in keys], 0)
|
||||
Type3_dict = dict.fromkeys(["Type3Avg_" + _ for _ in keys], 0)
|
||||
|
||||
# LPIPSnet = dm.DistModel()
|
||||
# LPIPSnet.initialize(model='net-lin', net='alex', use_gpu=True)
|
||||
|
||||
total_list_dict = {}
|
||||
key_str = 'Metrics -->'
|
||||
for key_i in keys:
|
||||
total_list_dict[key_i] = []
|
||||
key_str += ' ' + str(key_i)
|
||||
key_str += ' will be measured.'
|
||||
print(key_str)
|
||||
|
||||
for scene_idx, scene_folder in enumerate(pred_framesPath):
|
||||
per_scene_list_dict = {}
|
||||
for key_i in keys:
|
||||
per_scene_list_dict[key_i] = []
|
||||
pred_candidate = pred_framesPath[scene_idx] # get all frames in pred_framesPath
|
||||
# GT_candidate = GT_framesPath[scene_idx] # get 4800 frames
|
||||
# num_pred_frame_per_folder = len(pred_candidate)
|
||||
|
||||
# save_path = os.path.join(pred_save_path, pred_scenesFolder[scene_idx])
|
||||
save_path = scene_folder[0]
|
||||
# './test_img_dir/XVFInet_exp1/epoch_00099/type1/scene1'
|
||||
|
||||
# excluding both frame0 and frame1 (multiple of 32 indices)
|
||||
for frameIndex, pred_frame in enumerate(pred_candidate):
|
||||
# if server==87:
|
||||
# GTinterFrameIdx = pred_frame.split('/')[-1] # ex) 8, when multiple = 4, # 87 server
|
||||
# else:
|
||||
# GTinterFrameIdx = pred_frame.split('\\')[-1] # ex) 8, when multiple = 4
|
||||
# if not (GTinterFrameIdx % 32) == 0:
|
||||
if frameIndex > 0 and frameIndex < multiple:
|
||||
""" only compute predicted frames (excluding multiples of 32 indices), ex) 8, 16, 24, 40, 48, 56, ... """
|
||||
output_img = cv2.imread(pred_frame).astype(np.float32) # BGR, [0,255]
|
||||
target_img = cv2.imread(pred_frame.replace(pred_save_path, test_data_path)).astype(
|
||||
np.float32) # BGR, [0,255]
|
||||
pred_frame_split = pred_frame.split(os.sep)
|
||||
msg = "[x%d] frame %s, " % (
|
||||
multiple, os.path.join(pred_frame_split[-3], pred_frame_split[-2], pred_frame_split[-1])) # per frame
|
||||
|
||||
if "tOF" in keys: # tOF
|
||||
# if (GTinterFrameIdx % 32) == int(32/multiple):
|
||||
# if (frameIndex % multiple) == 1:
|
||||
if frameIndex == 1:
|
||||
# when first predicted frame in each interval
|
||||
pre_out_grey = cv2.cvtColor(cv2.imread(pred_candidate[0]).astype(np.float32),
|
||||
cv2.COLOR_BGR2GRAY) #### CAUTION BRG
|
||||
# pre_tar_grey = cv2.cvtColor(cv2.imread(pred_candidate[0].replace(pred_save_path, test_data_path)), cv2.COLOR_BGR2GRAY) #### CAUTION BRG
|
||||
pre_tar_grey = pre_out_grey #### CAUTION BRG
|
||||
|
||||
# if not H_match_flag or not W_match_flag:
|
||||
# pre_tar_grey = pre_tar_grey[:new_t_H, :new_t_W, :]
|
||||
|
||||
# pre_tar_grey = pre_out_grey
|
||||
|
||||
output_grey = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||
target_grey = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
target_OF = cv2.calcOpticalFlowFarneback(pre_tar_grey, target_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
||||
output_OF = cv2.calcOpticalFlowFarneback(pre_out_grey, output_grey, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
||||
# target_OF, ofy, ofx = crop_8x8(target_OF) #check for size reason
|
||||
# output_OF, ofy, ofx = crop_8x8(output_OF)
|
||||
OF_diff = np.absolute(target_OF - output_OF)
|
||||
if flow_flag:
|
||||
""" motion visualization """
|
||||
flow_path = save_path + '_tOF_flow'
|
||||
check_folder(flow_path)
|
||||
# './test_img_dir/XVFInet_exp1/epoch_00099/Fast/003_TEST_Fast_tOF_flow'
|
||||
tOFpath = os.path.join(flow_path, "tOF_flow_%05d.png" % (GTinterFrameIdx))
|
||||
# ex) "./test_img_dir/epoch_005/Fast/003_TEST_Fast/00008_tOF" when start_idx=0, multiple=4, frameIndex=0
|
||||
hsv = np.zeros_like(output_img) # check for size reason
|
||||
hsv[..., 1] = 255
|
||||
mag, ang = cv2.cartToPolar(OF_diff[..., 0], OF_diff[..., 1])
|
||||
# print("tar max %02.6f, min %02.6f, avg %02.6f" % (mag.max(), mag.min(), mag.mean()))
|
||||
maxV = 0.4
|
||||
mag = np.clip(mag, 0.0, maxV) / maxV
|
||||
hsv[..., 0] = ang * 180 / np.pi / 2
|
||||
hsv[..., 2] = mag * 255.0 #
|
||||
bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
||||
cv2.imwrite(tOFpath, bgr)
|
||||
print("png for motion visualization has been saved in [%s]" %
|
||||
(flow_path))
|
||||
OF_diff_tmp = np.sqrt(np.sum(OF_diff * OF_diff, axis=-1)).mean() # l1 vector norm
|
||||
# OF_diff, ofy, ofx = crop_8x8(OF_diff)
|
||||
total_list_dict["tOF"].append(OF_diff_tmp)
|
||||
per_scene_list_dict["tOF"].append(OF_diff_tmp)
|
||||
msg += "tOF %02.2f, " % (total_list_dict["tOF"][-1])
|
||||
|
||||
pre_out_grey = output_grey
|
||||
pre_tar_grey = target_grey
|
||||
|
||||
# target_img, ofy, ofx = crop_8x8(target_img)
|
||||
# output_img, ofy, ofx = crop_8x8(output_img)
|
||||
|
||||
if "PSNR" in keys: # psnr
|
||||
psnr_tmp = psnr(target_img, output_img)
|
||||
total_list_dict["PSNR"].append(psnr_tmp)
|
||||
per_scene_list_dict["PSNR"].append(psnr_tmp)
|
||||
msg += "PSNR %02.2f" % (total_list_dict["PSNR"][-1])
|
||||
|
||||
if "SSIM" in keys: # ssim
|
||||
ssim_tmp = ssim_bgr(target_img, output_img)
|
||||
total_list_dict["SSIM"].append(ssim_tmp)
|
||||
per_scene_list_dict["SSIM"].append(ssim_tmp)
|
||||
|
||||
msg += ", SSIM %02.2f" % (total_list_dict["SSIM"][-1])
|
||||
|
||||
# msg += ", crop (%d, %d)" % (ofy, ofx) # per frame (not scene)
|
||||
print(msg)
|
||||
|
||||
""" after finishing one scene """
|
||||
per_scene_pd_dict = {} # per scene
|
||||
for cur_key in keys:
|
||||
# save_path = './test_img_dir/XVFInet_exp1/epoch_00099/Fast/003_TEST_Fast'
|
||||
num_data = cur_key + "_[x%d]_[%s]" % (multiple, save_path.split(os.sep)[-2]) # '003_TEST_Fast'
|
||||
# num_data => ex) PSNR_[x8]_[041_TEST_Fast]
|
||||
""" per scene """
|
||||
per_scene_cur_list = np.float32(per_scene_list_dict[cur_key])
|
||||
per_scene_pd_dict[num_data] = pd.Series(per_scene_cur_list) # dictionary
|
||||
per_scene_num_data_sum = per_scene_cur_list.sum()
|
||||
per_scene_num_data_len = per_scene_cur_list.shape[0]
|
||||
per_scene_num_data_mean = per_scene_num_data_sum / per_scene_num_data_len
|
||||
""" accumulation """
|
||||
cur_list = np.float32(total_list_dict[cur_key])
|
||||
num_data_sum = cur_list.sum()
|
||||
num_data_len = cur_list.shape[0] # accum
|
||||
num_data_mean = num_data_sum / num_data_len
|
||||
print(" %s, (per scene) max %02.4f, min %02.4f, avg %02.4f" %
|
||||
(num_data, per_scene_cur_list.max(), per_scene_cur_list.min(), per_scene_num_data_mean)) #
|
||||
|
||||
Total_avg_dict["TotalAvg_" + cur_key] = num_data_mean # accum, update every iteration.
|
||||
|
||||
len_dict[cur_key] = num_data_len # accum, update every iteration.
|
||||
|
||||
# folder_dict["FolderAvg_" + cur_key] += num_data_mean
|
||||
if scene_idx < 5:
|
||||
Type1_dict["Type1Avg_" + cur_key] += per_scene_num_data_mean
|
||||
elif (scene_idx >= 5) and (scene_idx < 10):
|
||||
Type2_dict["Type2Avg_" + cur_key] += per_scene_num_data_mean
|
||||
elif (scene_idx >= 10) and (scene_idx < 15):
|
||||
Type3_dict["Type3Avg_" + cur_key] += per_scene_num_data_mean
|
||||
|
||||
mode = 'w' if scene_idx == 0 else 'a'
|
||||
|
||||
total_csv_path = os.path.join(pred_save_path, "total_metrics.csv")
|
||||
# ex) pred_save_path: './test_img_dir/XVFInet_exp1/epoch_00099' when 'args.epochs=100'
|
||||
pd.DataFrame(per_scene_pd_dict).to_csv(total_csv_path, mode=mode)
|
||||
|
||||
""" combining all results after looping all scenes. """
|
||||
for key in keys:
|
||||
Total_avg_dict["TotalAvg_" + key] = pd.Series(
|
||||
np.float32(Total_avg_dict["TotalAvg_" + key])) # replace key (update)
|
||||
Type1_dict["Type1Avg_" + key] = pd.Series(np.float32(Type1_dict["Type1Avg_" + key] / 5)) # replace key (update)
|
||||
Type2_dict["Type2Avg_" + key] = pd.Series(np.float32(Type2_dict["Type2Avg_" + key] / 5)) # replace key (update)
|
||||
Type3_dict["Type3Avg_" + key] = pd.Series(np.float32(Type3_dict["Type3Avg_" + key] / 5)) # replace key (update)
|
||||
|
||||
print("%s, total frames %d, total avg %02.4f, Type1 avg %02.4f, Type2 avg %02.4f, Type3 avg %02.4f" %
|
||||
(key, len_dict[key], Total_avg_dict["TotalAvg_" + key],
|
||||
Type1_dict["Type1Avg_" + key], Type2_dict["Type2Avg_" + key], Type3_dict["Type3Avg_" + key]))
|
||||
|
||||
pd.DataFrame(Total_avg_dict).to_csv(total_csv_path, mode='a')
|
||||
pd.DataFrame(Type1_dict).to_csv(total_csv_path, mode='a')
|
||||
pd.DataFrame(Type2_dict).to_csv(total_csv_path, mode='a')
|
||||
pd.DataFrame(Type3_dict).to_csv(total_csv_path, mode='a')
|
||||
|
||||
print("csv file of all metrics for all scenes has been saved in [%s]" %
|
||||
(total_csv_path))
|
||||
print("Finished.")
|
||||
|
||||
|
||||
def to_uint8(x, vmin, vmax):
|
||||
##### color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
|
||||
x = x.astype('float32')
|
||||
x = (x - vmin) / (vmax - vmin) * 255 # 0~255
|
||||
return np.clip(np.round(x), 0, 255)
|
||||
|
||||
|
||||
def psnr(img_true, img_pred):
|
||||
##### PSNR with color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
|
||||
"""
|
||||
# img format : [h,w,c], RGB
|
||||
"""
|
||||
# Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255), 255)[:, :, 0]
|
||||
# Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255), 255)[:, :, 0]
|
||||
diff = img_true - img_pred
|
||||
rmse = np.sqrt(np.mean(np.power(diff, 2)))
|
||||
if rmse == 0:
|
||||
return float('inf')
|
||||
return 20 * np.log10(255. / rmse)
|
||||
|
||||
|
||||
def ssim_bgr(img_true, img_pred): ##### SSIM for BGR, not RGB #####
|
||||
"""
|
||||
# img format : [h,w,c], BGR
|
||||
"""
|
||||
Y_true = _rgb2ycbcr(to_uint8(img_true, 0, 255)[:, :, ::-1], 255)[:, :, 0]
|
||||
Y_pred = _rgb2ycbcr(to_uint8(img_pred, 0, 255)[:, :, ::-1], 255)[:, :, 0]
|
||||
# return compare_ssim(Y_true, Y_pred, data_range=Y_pred.max() - Y_pred.min())
|
||||
return structural_similarity(Y_true, Y_pred, data_range=Y_pred.max() - Y_pred.min())
|
||||
|
||||
|
||||
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
|
||||
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
||||
return torch.Tensor((image / factor - cent)
|
||||
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
||||
|
||||
|
||||
# [0,255]2[-1,1]2[1,3,H,W]-shaped
|
||||
|
||||
def denorm255(x):
|
||||
out = (x + 1.0) / 2.0
|
||||
return out.clamp_(0.0, 1.0) * 255.0
|
||||
|
||||
|
||||
def denorm255_np(x):
|
||||
# numpy
|
||||
out = (x + 1.0) / 2.0
|
||||
return out.clip(0.0, 1.0) * 255.0
|
||||
|
||||
|
||||
def _rgb2ycbcr(img, maxVal=255):
|
||||
##### color space transform, originally from https://github.com/yhjo09/VSR-DUF #####
|
||||
O = np.array([[16],
|
||||
[128],
|
||||
[128]])
|
||||
T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
|
||||
[-0.148223529411765, -0.290992156862745, 0.439215686274510],
|
||||
[0.439215686274510, -0.367788235294118, -0.071427450980392]])
|
||||
|
||||
if maxVal == 1:
|
||||
O = O / 255.0
|
||||
|
||||
t = np.reshape(img, (img.shape[0] * img.shape[1], img.shape[2]))
|
||||
t = np.dot(t, np.transpose(T))
|
||||
t[:, 0] += O[0]
|
||||
t[:, 1] += O[1]
|
||||
t[:, 2] += O[2]
|
||||
ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
|
||||
|
||||
return ycbcr
|
||||
|
||||
|
||||
class set_smoothness_loss(nn.Module):
|
||||
def __init__(self, weight=150.0, edge_aware=True):
|
||||
super(set_smoothness_loss, self).__init__()
|
||||
self.edge_aware = edge_aware
|
||||
self.weight = weight ** 2
|
||||
|
||||
def forward(self, flow, img):
|
||||
img_gh = torch.mean(torch.pow((img[:, :, 1:, :] - img[:, :, :-1, :]), 2), dim=1, keepdims=True)
|
||||
img_gw = torch.mean(torch.pow((img[:, :, :, 1:] - img[:, :, :, :-1]), 2), dim=1, keepdims=True)
|
||||
|
||||
weight_gh = torch.exp(-self.weight * img_gh)
|
||||
weight_gw = torch.exp(-self.weight * img_gw)
|
||||
|
||||
flow_gh = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :])
|
||||
flow_gw = torch.abs(flow[:, :, :, 1:] - flow[:, :, :, :-1])
|
||||
if self.edge_aware:
|
||||
return (torch.mean(weight_gh * flow_gh) + torch.mean(weight_gw * flow_gw)) * 0.5
|
||||
else:
|
||||
return (torch.mean(flow_gh) + torch.mean(flow_gw)) * 0.5
|
||||
|
||||
|
||||
def get_batch_images(args, save_img_num, save_images): ## For visualization during training phase
|
||||
width_num = len(save_images)
|
||||
log_img = np.zeros((save_img_num * args.patch_size, width_num * args.patch_size, 3), dtype=np.uint8)
|
||||
pred_frameT, pred_coarse_flow, pred_fine_flow, frameT, simple_mean, occ_map = save_images
|
||||
for b in range(save_img_num):
|
||||
output_img_tmp = denorm255(pred_frameT[b, :])
|
||||
output_coarse_flow_tmp = pred_coarse_flow[b, :2, :, :]
|
||||
output_fine_flow_tmp = pred_fine_flow[b, :2, :, :]
|
||||
gt_img_tmp = denorm255(frameT[b, :])
|
||||
simple_mean_img_tmp = denorm255(simple_mean[b, :])
|
||||
occ_map_tmp = occ_map[b, :]
|
||||
|
||||
output_img_tmp = np.transpose(output_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
|
||||
output_coarse_flow_tmp = flow2img(np.transpose(output_coarse_flow_tmp.detach().cpu().numpy(), [1, 2, 0]))
|
||||
output_fine_flow_tmp = flow2img(np.transpose(output_fine_flow_tmp.detach().cpu().numpy(), [1, 2, 0]))
|
||||
gt_img_tmp = np.transpose(gt_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
|
||||
simple_mean_img_tmp = np.transpose(simple_mean_img_tmp.detach().cpu().numpy(), [1, 2, 0]).astype(np.uint8)
|
||||
occ_map_tmp = np.transpose(occ_map_tmp.detach().cpu().numpy() * 255.0, [1, 2, 0]).astype(np.uint8)
|
||||
occ_map_tmp = np.concatenate([occ_map_tmp, occ_map_tmp, occ_map_tmp], axis=2)
|
||||
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 0 * args.patch_size:1 * args.patch_size,
|
||||
:] = simple_mean_img_tmp
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 1 * args.patch_size:2 * args.patch_size,
|
||||
:] = output_img_tmp
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 2 * args.patch_size:3 * args.patch_size,
|
||||
:] = gt_img_tmp
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 3 * args.patch_size:4 * args.patch_size,
|
||||
:] = output_coarse_flow_tmp
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 4 * args.patch_size:5 * args.patch_size,
|
||||
:] = output_fine_flow_tmp
|
||||
log_img[(b) * args.patch_size:(b + 1) * args.patch_size, 5 * args.patch_size:6 * args.patch_size,
|
||||
:] = occ_map_tmp
|
||||
|
||||
return log_img
|
||||
|
||||
|
||||
def flow2img(flow, logscale=True, scaledown=6, output=False):
|
||||
"""
|
||||
topleft is zero, u is horiz, v is vertical
|
||||
red is 3 o'clock, yellow is 6, light blue is 9, blue/purple is 12
|
||||
"""
|
||||
u = flow[:, :, 1]
|
||||
# u = flow[:, :, 0]
|
||||
v = flow[:, :, 0]
|
||||
# v = flow[:, :, 1]
|
||||
|
||||
colorwheel = makecolorwheel()
|
||||
ncols = colorwheel.shape[0]
|
||||
|
||||
radius = np.sqrt(u ** 2 + v ** 2)
|
||||
if output:
|
||||
print("Maximum flow magnitude: %04f" % np.max(radius))
|
||||
if logscale:
|
||||
radius = np.log(radius + 1)
|
||||
if output:
|
||||
print("Maximum flow magnitude (after log): %0.4f" % np.max(radius))
|
||||
radius = radius / scaledown
|
||||
if output:
|
||||
print("Maximum flow magnitude (after scaledown): %0.4f" % np.max(radius))
|
||||
# rot = np.arctan2(-v, -u) / np.pi
|
||||
rot = np.arctan2(v, u) / np.pi
|
||||
|
||||
fk = (rot + 1) / 2 * (ncols - 1) # -1~1 maped to 0~ncols
|
||||
k0 = fk.astype(np.uint8) # 0, 1, 2, ..., ncols
|
||||
|
||||
k1 = k0 + 1
|
||||
k1[k1 == ncols] = 0
|
||||
|
||||
f = fk - k0
|
||||
|
||||
ncolors = colorwheel.shape[1]
|
||||
img = np.zeros(u.shape + (ncolors,))
|
||||
for i in range(ncolors):
|
||||
tmp = colorwheel[:, i]
|
||||
col0 = tmp[k0]
|
||||
col1 = tmp[k1]
|
||||
col = (1 - f) * col0 + f * col1
|
||||
|
||||
idx = radius <= 1
|
||||
# increase saturation with radius
|
||||
col[idx] = 1 - radius[idx] * (1 - col[idx])
|
||||
# out of range
|
||||
col[~idx] *= 0.75
|
||||
# img[:,:,i] = np.floor(255*col).astype(np.uint8)
|
||||
|
||||
img[:, :, i] = np.clip(255 * col, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
# return img.astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def makecolorwheel():
|
||||
# Create a colorwheel for visualization
|
||||
RY = 15
|
||||
YG = 6
|
||||
GC = 4
|
||||
CB = 11
|
||||
BM = 13
|
||||
MR = 6
|
||||
|
||||
ncols = RY + YG + GC + CB + BM + MR
|
||||
|
||||
colorwheel = np.zeros((ncols, 3))
|
||||
|
||||
col = 0
|
||||
# RY
|
||||
colorwheel[col:col + RY, 0] = 1
|
||||
colorwheel[col:col + RY, 1] = np.arange(0, 1, 1. / RY)
|
||||
col += RY
|
||||
|
||||
# YG
|
||||
colorwheel[col:col + YG, 0] = np.arange(1, 0, -1. / YG)
|
||||
colorwheel[col:col + YG, 1] = 1
|
||||
col += YG
|
||||
|
||||
# GC
|
||||
colorwheel[col:col + GC, 1] = 1
|
||||
colorwheel[col:col + GC, 2] = np.arange(0, 1, 1. / GC)
|
||||
col += GC
|
||||
|
||||
# CB
|
||||
colorwheel[col:col + CB, 1] = np.arange(1, 0, -1. / CB)
|
||||
colorwheel[col:col + CB, 2] = 1
|
||||
col += CB
|
||||
|
||||
# BM
|
||||
colorwheel[col:col + BM, 2] = 1
|
||||
colorwheel[col:col + BM, 0] = np.arange(0, 1, 1. / BM)
|
||||
col += BM
|
||||
|
||||
# MR
|
||||
colorwheel[col:col + MR, 2] = np.arange(1, 0, -1. / MR)
|
||||
colorwheel[col:col + MR, 0] = 1
|
||||
|
||||
return colorwheel
|
Loading…
Reference in New Issue