Updated RIFE-CUDA to support new 3.2-3.5 models with fallback for 3.0-3.1

N00MKRAD 2021-06-15 15:41:15 +02:00
parent 532d556d1e
commit 7abf45f673
14 changed files with 753 additions and 285 deletions

.gitignore (vendored): 1 addition
View File

@@ -32,6 +32,7 @@ bld/
 [Ll]ogs/
 Flowframes*.7z
 FF*.7z
+Build/WebInstaller

 # NMKD Python Redist Pkg
 [Pp]y*/

View File

@ -91,12 +91,9 @@ class IFNet(nn.Module):
self.block2 = IFBlock(8, scale=2, c=96) self.block2 = IFBlock(8, scale=2, c=96)
self.block3 = IFBlock(8, scale=1, c=48) self.block3 = IFBlock(8, scale=1, c=48)
def forward(self, x, UHD=False): def forward(self, x, scale=1.0):
if UHD: x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear",
x = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False) align_corners=False)
else:
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
warped_img0 = warp(x[:, :3], F1) warped_img0 = warp(x[:, :3], F1)
@ -111,6 +108,8 @@ class IFNet(nn.Module):
warped_img1 = warp(x[:, 3:], -F3) warped_img1 = warp(x[:, 3:], -F3)
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1)) flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1))
F4 = (flow0 + flow1 + flow2 + flow3) F4 = (flow0 + flow1 + flow2 + flow3)
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear",
align_corners=False) / scale
return F4, [F1, F2, F3, F4] return F4, [F1, F2, F3, F4]
if __name__ == '__main__': if __name__ == '__main__':

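A note on the API change above: the boolean UHD switch meant a fixed quarter working resolution, while the new scale argument composes with the network's built-in 0.5x downsampling. A minimal standalone sketch (not part of the commit) of why scale=0.5 is the drop-in replacement for UHD=True:

import torch
import torch.nn.functional as F

x = torch.randn(1, 6, 64, 64)  # two stacked RGB frames, as IFNet expects

# Old path: UHD=True meant a fixed 0.25x working resolution.
old_uhd = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False)

# New path: the caller's scale composes with the default 0.5x, so scale=0.5
# gives 0.5 * 0.5 = 0.25, and scale=1.0 matches the old UHD=False path.
new_uhd = F.interpolate(x, scale_factor=0.5 * 0.5, mode="bilinear", align_corners=False)

assert torch.allclose(old_uhd, new_uhd)
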
View File

@@ -61,26 +61,28 @@ class IFNet(nn.Module):
         self.block2 = IFBlock(10, scale=2, c=96)
         self.block3 = IFBlock(10, scale=1, c=48)

-    def forward(self, x, UHD=False):
-        if UHD:
-            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
+    def forward(self, x, scale=1.0):
+        if scale != 1.0:
+            x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
         flow0 = self.block0(x)
         F1 = flow0
-        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
+        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
         warped_img0 = warp(x[:, :3], F1_large[:, :2])
         warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
         flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
         F2 = (flow0 + flow1)
-        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
+        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
         warped_img0 = warp(x[:, :3], F2_large[:, :2])
         warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
         flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
         F3 = (flow0 + flow1 + flow2)
-        F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
+        F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
         warped_img0 = warp(x[:, :3], F3_large[:, :2])
         warped_img1 = warp(x[:, 3:], F3_large[:, 2:4])
         flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1))
         F4 = (flow0 + flow1 + flow2 + flow3)
+        if scale != 1.0:
+            F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
         return F4, [F1, F2, F3, F4]

 if __name__ == '__main__':

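Why the added tail divides by scale: resizing a flow field moves where each vector sits but not how far it points, so displacements must be rescaled by the same factor as the spatial resize. A small illustration (standalone, not from the commit):

import torch
import torch.nn.functional as F

def resize_flow(flow, factor):
    # Resize a flow field and rescale its displacement vectors to match.
    return F.interpolate(flow, scale_factor=factor, mode="bilinear",
                         align_corners=False) * factor

flow = torch.randn(1, 4, 32, 32)       # computed at a 0.5x working resolution
full_res = resize_flow(flow, 1 / 0.5)  # equivalent to the F4 tail above
print(full_res.shape)                  # torch.Size([1, 4, 64, 64])
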
View File

@@ -2,23 +2,22 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from model.warplayer import warp
+from model.refine import *

-def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
-    return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
-        nn.PReLU(out_planes)
-    )
-
-def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
-    return nn.Sequential(
-        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
-                  padding=padding, dilation=dilation, bias=True),
-    )
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
     return nn.Sequential(
         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                   padding=padding, dilation=dilation, bias=True),
+        nn.PReLU(out_planes)
+    )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_planes),
         nn.PReLU(out_planes)
     )
@@ -26,56 +25,93 @@ class IFBlock(nn.Module):
     def __init__(self, in_planes, c=64):
         super(IFBlock, self).__init__()
         self.conv0 = nn.Sequential(
-            conv(in_planes, c, 3, 2, 1),
-            conv(c, 2*c, 3, 2, 1),
+            conv(in_planes, c//2, 3, 2, 1),
+            conv(c//2, c, 3, 2, 1),
         )
         self.convblock0 = nn.Sequential(
-            conv(2*c, 2*c),
-            conv(2*c, 2*c),
+            conv(c, c),
+            conv(c, c)
         )
         self.convblock1 = nn.Sequential(
-            conv(2*c, 2*c),
-            conv(2*c, 2*c),
+            conv(c, c),
+            conv(c, c)
         )
         self.convblock2 = nn.Sequential(
-            conv(2*c, 2*c),
-            conv(2*c, 2*c),
+            conv(c, c),
+            conv(c, c)
         )
-        self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)
+        self.convblock3 = nn.Sequential(
+            conv(c, c),
+            conv(c, c)
+        )
+        self.conv1 = nn.Sequential(
+            nn.ConvTranspose2d(c, 4, 4, 2, 1),
+        )
+        self.conv2 = nn.ConvTranspose2d(c, 1, 4, 2, 1)

-    def forward(self, x, flow=None, scale=1):
-        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
-        if flow != None:
-            flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * (1. / scale)
-            x = torch.cat((x, flow), 1)
-        x = self.conv0(x)
-        x = self.convblock0(x) + x
-        x = self.convblock1(x) + x
-        x = self.convblock2(x) + x
-        x = self.conv1(x)
-        flow = x
-        if scale != 1:
-            flow = F.interpolate(flow, scale_factor= scale, mode="bilinear", align_corners=False) * scale
-        return flow
+    def forward(self, x, flow, scale=1):
+        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
+        feat = self.conv0(torch.cat((x, flow), 1))
+        feat = self.convblock0(feat) + feat
+        feat = self.convblock1(feat) + feat
+        feat = self.convblock2(feat) + feat
+        feat = self.convblock3(feat) + feat
+        flow = self.conv1(feat)
+        mask = self.conv2(feat)
+        flow = F.interpolate(flow, scale_factor=scale*2, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale*2
+        mask = F.interpolate(mask, scale_factor=scale*2, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        return flow, mask

 class IFNet(nn.Module):
     def __init__(self):
         super(IFNet, self).__init__()
-        self.block0 = IFBlock(6, c=80)
-        self.block1 = IFBlock(10, c=80)
-        self.block2 = IFBlock(10, c=80)
+        self.block0 = IFBlock(7+4, c=90)
+        self.block1 = IFBlock(7+4, c=90)
+        self.block2 = IFBlock(7+4, c=90)
+        self.block_tea = IFBlock(10+4, c=90)
+        # self.contextnet = Contextnet()
+        # self.unet = Unet()

-    def forward(self, x, scale_list=[4,2,1]):
-        flow0 = self.block0(x, scale=scale_list[0])
-        F1 = flow0
-        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
-        warped_img0 = warp(x[:, :3], F1_large[:, :2])
-        warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
-        flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
-        F2 = (flow0 + flow1)
-        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
-        warped_img0 = warp(x[:, :3], F2_large[:, :2])
-        warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
-        flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
-        F3 = (flow0 + flow1 + flow2)
-        return F3, [F1, F2, F3]
+    def forward(self, x, scale_list=[4, 2, 1], scale=1.0, training=False):
+        x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
+        if training == False:
+            channel = x.shape[1] // 2
+            img0 = x[:, :channel]
+            img1 = x[:, channel:]
+        flow_list = []
+        merged = []
+        mask_list = []
+        warped_img0 = img0
+        warped_img1 = img1
+        flow = torch.zeros_like(x[:, :4]).to(device)
+        mask = torch.zeros_like(x[:, :1]).to(device)
+        loss_cons = 0
+        block = [self.block0, self.block1, self.block2]
+        for i in range(3):
+            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
+            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+            mask = mask + (m0 + (-m1)) / 2
+            mask_list.append(mask)
+            flow_list.append(flow)
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+            merged.append((warped_img0, warped_img1))
+        if scale != 1.0:
+            flow = F.interpolate(flow, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
+            mask_list[2] = F.interpolate(mask_list[2], scale_factor=1 / scale, mode="bilinear", align_corners=False)
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+            merged[2] = (warped_img0, warped_img1)
+        '''
+        c0 = self.contextnet(img0, flow[:, :2])
+        c1 = self.contextnet(img1, flow[:, 2:4])
+        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+        res = tmp[:, 1:4] * 2 - 1
+        '''
+        for i in range(3):
+            mask_list[i] = torch.sigmoid(mask_list[i])
+            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
+            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
+        return flow_list, mask_list[2], merged

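The core of the v3 rewrite above: each refinement block now runs in both temporal directions, and the backward estimate is flipped (its flow channel pairs swapped, its mask negated) before averaging with the forward one. A shape-level sketch of that fusion step (random tensors, purely illustrative):

import torch

f0 = torch.randn(1, 4, 8, 8)  # block output for (img0, img1, mask)
f1 = torch.randn(1, 4, 8, 8)  # block output for (img1, img0, -mask)
m0 = torch.randn(1, 1, 8, 8)
m1 = torch.randn(1, 1, 8, 8)

# Swap the backward flow's channel pairs so both estimates share one frame
# order, then average; the mask is averaged with its sign flipped back.
flow_update = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
mask_update = (m0 + (-m1)) / 2
print(flow_update.shape, mask_update.shape)
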
View File

@ -135,7 +135,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@ -188,11 +188,9 @@ class Model:
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow) c0 = self.contextnet(img0, flow)
c1 = self.contextnet(img1, -flow) c1 = self.contextnet(img1, -flow)
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
@ -209,10 +207,10 @@ class Model:
else: else:
return pred return pred
def inference(self, img0, img1, UHD=False): def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD) flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False, UHD=UHD) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups: for param_group in self.optimG.param_groups:

View File

@ -120,7 +120,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@ -173,11 +173,9 @@ class Model:
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2]) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4]) c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
@ -194,10 +192,10 @@ class Model:
else: else:
return pred return pred
def inference(self, img0, img1, UHD=False): def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD) flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False, UHD=UHD) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups: for param_group in self.optimG.param_groups:

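The same signature change lands in both HD model wrappers above, so for callers the migration is mechanical. A hypothetical compatibility shim (interpolate_pair and its uhd flag are illustrative, not part of this commit):

def interpolate_pair(model, img0, img1, uhd=False):
    # Map the retired boolean onto the new continuous argument:
    # UHD material runs at half working resolution, everything else at full.
    scale = 0.5 if uhd else 1.0
    return model.inference(img0, img1, scale)
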
View File

@@ -11,145 +11,28 @@ import torch.nn.functional as F
 from model.loss import *

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
-    return nn.Sequential(
-        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
-                  padding=padding, dilation=dilation, bias=True),
-        nn.PReLU(out_planes)
-    )
-
-def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
-    return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
-                                 kernel_size=4, stride=2, padding=1, bias=True),
-        nn.PReLU(out_planes)
-    )
-
-def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
-    return nn.Sequential(
-        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
-                  padding=padding, dilation=dilation, bias=True),
-    )
-
-class Conv2(nn.Module):
-    def __init__(self, in_planes, out_planes, stride=2):
-        super(Conv2, self).__init__()
-        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
-        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        return x
-
-c = 32
-
-class ContextNet(nn.Module):
-    def __init__(self):
-        super(ContextNet, self).__init__()
-        self.conv0 = Conv2(3, c)
-        self.conv1 = Conv2(c, c)
-        self.conv2 = Conv2(c, 2*c)
-        self.conv3 = Conv2(2*c, 4*c)
-        self.conv4 = Conv2(4*c, 8*c)
-
-    def forward(self, x, flow):
-        x = self.conv0(x)
-        x = self.conv1(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
-        f1 = warp(x, flow)
-        x = self.conv2(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
-                             align_corners=False) * 0.5
-        f2 = warp(x, flow)
-        x = self.conv3(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
-                             align_corners=False) * 0.5
-        f3 = warp(x, flow)
-        x = self.conv4(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
-                             align_corners=False) * 0.5
-        f4 = warp(x, flow)
-        return [f1, f2, f3, f4]
-
-class FusionNet(nn.Module):
-    def __init__(self):
-        super(FusionNet, self).__init__()
-        self.conv0 = Conv2(10, c)
-        self.down0 = Conv2(c, 2*c)
-        self.down1 = Conv2(4*c, 4*c)
-        self.down2 = Conv2(8*c, 8*c)
-        self.down3 = Conv2(16*c, 16*c)
-        self.up0 = deconv(32*c, 8*c)
-        self.up1 = deconv(16*c, 4*c)
-        self.up2 = deconv(8*c, 2*c)
-        self.up3 = deconv(4*c, c)
-        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)
-
-    def forward(self, img0, img1, flow, c0, c1, flow_gt):
-        warped_img0 = warp(img0, flow[:, :2])
-        warped_img1 = warp(img1, flow[:, 2:4])
-        if flow_gt == None:
-            warped_img0_gt, warped_img1_gt = None, None
-        else:
-            warped_img0_gt = warp(img0, flow_gt[:, :2])
-            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
-        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
-        s0 = self.down0(x)
-        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
-        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
-        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
-        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
-        x = self.up1(torch.cat((x, s2), 1))
-        x = self.up2(torch.cat((x, s1), 1))
-        x = self.up3(torch.cat((x, s0), 1))
-        x = self.conv(x)
-        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
-
 class Model:
     def __init__(self, local_rank=-1):
         self.flownet = IFNet()
-        self.contextnet = ContextNet()
-        self.fusionnet = FusionNet()
         self.device()
-        self.optimG = AdamW(itertools.chain(
-            self.flownet.parameters(),
-            self.contextnet.parameters(),
-            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
-        self.schedulerG = optim.lr_scheduler.CyclicLR(
-            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
+        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
         self.epe = EPE()
-        self.ter = Ternary()
+        # self.vgg = VGGPerceptualLoss().to(device)
         self.sobel = SOBEL()
         if local_rank != -1:
-            self.flownet = DDP(self.flownet, device_ids=[
-                               local_rank], output_device=local_rank)
-            self.contextnet = DDP(self.contextnet, device_ids=[
-                                  local_rank], output_device=local_rank)
-            self.fusionnet = DDP(self.fusionnet, device_ids=[
-                                 local_rank], output_device=local_rank)
+            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

     def train(self):
         self.flownet.train()
-        self.contextnet.train()
-        self.fusionnet.train()

     def eval(self):
         self.flownet.eval()
-        self.contextnet.eval()
-        self.fusionnet.eval()

     def device(self):
         self.flownet.to(device)
-        self.contextnet.to(device)
-        self.fusionnet.to(device)

-    def load_model(self, path, rank):
+    def load_model(self, path, rank=0):
         def convert(param):
             if rank == -1:
                 return {
@@ -160,90 +43,46 @@ class Model:
             else:
                 return param
         if rank <= 0:
-            self.flownet.load_state_dict(
-                convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
-            self.contextnet.load_state_dict(
-                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
-            self.fusionnet.load_state_dict(
-                convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
+            if torch.cuda.is_available():
+                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))
+            else:
+                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu')))

-    def save_model(self, path, rank):
+    def save_model(self, path, rank=0):
         if rank == 0:
-            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
-            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
-            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
+            torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))

-    def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
-        img0 = imgs[:, :3]
-        img1 = imgs[:, 3:]
-        if UHD:
-            flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
-        c0 = self.contextnet(img0, flow[:, :2])
-        c1 = self.contextnet(img1, flow[:, 2:4])
-        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
-                             align_corners=False) * 2.0
-        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
-            img0, img1, flow, c0, c1, flow_gt)
-        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
-        mask = torch.sigmoid(refine_output[:, 3:4])
-        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
-        pred = merged_img + res
-        pred = torch.clamp(pred, 0, 1)
-        if training:
-            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
-        else:
-            return pred
-
-    def inference(self, img0, img1, UHD=False):
+    def inference(self, img0, img1, scale=1.0):
         imgs = torch.cat((img0, img1), 1)
-        scale_list = [8, 4, 2]
-        flow, _ = self.flownet(imgs, scale_list)
-        res = self.predict(imgs, flow, training=False, UHD=False)
-        return res
+        scale_list = [4, 2, 1]
+        flow, mask, merged = self.flownet(imgs, scale_list, scale=scale)
+        return merged[2]

     def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
         for param_group in self.optimG.param_groups:
             param_group['lr'] = learning_rate
-        img0 = imgs[:, :3]
-        img1 = imgs[:, 3:]
         if training:
             self.train()
         else:
             self.eval()
-        flow, flow_list = self.flownet(imgs)
-        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
-            imgs, flow, flow_gt=flow_gt)
-        loss_ter = self.ter(pred, gt).mean()
-        if training:
-            with torch.no_grad():
-                loss_flow = torch.abs(warped_img0_gt - gt).mean()
-                loss_mask = torch.abs(
-                    merged_img - gt).sum(1, True).float().detach()
-                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
-                                          align_corners=False).detach()
-                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
-                                         align_corners=False) * 0.5).detach()
-            loss_cons = 0
-            for i in range(4):
-                loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
-                loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
-            loss_cons = loss_cons.mean() * 0.01
-        else:
-            loss_cons = torch.tensor([0])
-            loss_flow = torch.abs(warped_img0 - gt).mean()
-            loss_mask = 1
-        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
+        scale = [4, 2, 1]
+        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
+        loss_l1 = (merged[2] - gt).abs().mean()
+        loss_smooth = self.sobel(flow[2], flow[2]*0).mean()
+        # loss_vgg = self.vgg(merged[2], gt)
         if training:
             self.optimG.zero_grad()
-            loss_G = loss_l1 + loss_cons + loss_ter
+            loss_G = loss_cons + loss_smooth * 0.1
             loss_G.backward()
             self.optimG.step()
-        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
-
-if __name__ == '__main__':
-    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
-    img1 = torch.tensor(np.random.normal(
-        0, 1, (3, 3, 256, 256))).float().to(device)
-    imgs = torch.cat((img0, img1), 1)
-    model = Model()
-    model.eval()
-    print(model.inference(imgs).shape)
+        else:
+            flow_teacher = flow[2]
+        return merged[2], {
+            'mask': mask,
+            'flow': flow[2][:, :2],
+            'loss_l1': loss_l1,
+            'loss_cons': loss_cons,
+            'loss_smooth': loss_smooth,
+        }

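v3 inference now returns merged[2] directly: the finest pyramid level's two warped frames, blended by the learned occlusion mask. The blend in isolation (random tensors, purely illustrative):

import torch

warped_img0 = torch.rand(1, 3, 8, 8)
warped_img1 = torch.rand(1, 3, 8, 8)
mask_logits = torch.randn(1, 1, 8, 8)

# Sigmoid turns the raw mask into per-pixel blend weights: values near 1
# trust the warp from img0, values near 0 trust the warp from img1.
mask = torch.sigmoid(mask_logits)
merged = warped_img0 * mask + warped_img1 * (1 - mask)
print(merged.shape)  # torch.Size([1, 3, 8, 8])
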
View File

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from model.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )

def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
        nn.PReLU(out_planes)
    )

class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

c = 16

class Contextnet(nn.Module):
    def __init__(self):
        super(Contextnet, self).__init__()
        self.conv1 = Conv2(3, c)
        self.conv2 = Conv2(c, 2*c)
        self.conv3 = Conv2(2*c, 4*c)
        self.conv4 = Conv2(4*c, 8*c)

    def forward(self, x, flow):
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        f1 = warp(x, flow)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        f2 = warp(x, flow)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        f3 = warp(x, flow)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.down0 = Conv2(17, 2*c)
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.Conv2d(c, 4, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return torch.sigmoid(x)

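Contextnet builds a four-level feature pyramid, warping each level by a correspondingly rescaled flow. A quick shape check (a sketch assuming this module and model.warplayer are importable; the tensors are random):

import torch

net = Contextnet()  # defined above; c = 16
feats = net(torch.randn(1, 3, 64, 64), torch.randn(1, 2, 64, 64))
for f in feats:
    print(f.shape)
# Expected: 16, 32, 64, 128 channels at 32, 16, 8, 4 pixels per side.
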
View File

@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_v3_legacy.warplayer import warp

def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
        nn.PReLU(out_planes)
    )

def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 2, 1),
            conv(c, 2*c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.convblock1 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.convblock2 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
        if flow != None:
            flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * (1. / scale)
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock0(x) + x
        x = self.convblock1(x) + x
        x = self.convblock2(x) + x
        x = self.conv1(x)
        flow = x
        if scale != 1:
            flow = F.interpolate(flow, scale_factor= scale, mode="bilinear", align_corners=False) * scale
        return flow

class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(6, c=80)
        self.block1 = IFBlock(10, c=80)
        self.block2 = IFBlock(10, c=80)

    def forward(self, x, scale_list=[4,2,1]):
        flow0 = self.block0(x, scale=scale_list[0])
        F1 = flow0
        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        warped_img0 = warp(x[:, :3], F1_large[:, :2])
        warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
        flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
        F2 = (flow0 + flow1)
        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        warped_img0 = warp(x[:, :3], F2_large[:, :2])
        warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
        flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
        F3 = (flow0 + flow1 + flow2)
        return F3, [F1, F2, F3]

View File

@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model_v3_legacy.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model_v3_legacy.IFNet_HDv3 import *
import torch.nn.functional as F
from model_v3_legacy.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=4, stride=2, padding=1, bias=True),
        nn.PReLU(out_planes)
    )

def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )

class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

c = 32

class ContextNet(nn.Module):
    def __init__(self):
        super(ContextNet, self).__init__()
        self.conv0 = Conv2(3, c)
        self.conv1 = Conv2(c, c)
        self.conv2 = Conv2(c, 2*c)
        self.conv3 = Conv2(2*c, 4*c)
        self.conv4 = Conv2(4*c, 8*c)

    def forward(self, x, flow):
        x = self.conv0(x)
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
        f1 = warp(x, flow)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f2 = warp(x, flow)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f3 = warp(x, flow)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]

class FusionNet(nn.Module):
    def __init__(self):
        super(FusionNet, self).__init__()
        self.conv0 = Conv2(10, c)
        self.down0 = Conv2(c, 2*c)
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        if flow_gt == None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt

class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(itertools.chain(
            self.flownet.parameters(),
            self.contextnet.parameters(),
            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[
                               local_rank], output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[
                                  local_rank], output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[
                                 local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank):
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
            self.contextnet.load_state_dict(
                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
            self.fusionnet.load_state_dict(
                convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))

    def save_model(self, path, rank):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if UHD:
            flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                             align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1, UHD=False):
        imgs = torch.cat((img0, img1), 1)
        scale_list = [8, 4, 2]
        flow, _ = self.flownet(imgs, scale_list)
        res = self.predict(imgs, flow, training=False, UHD=False)
        return res

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
                                          align_corners=False).detach()
                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            loss_cons = 0
            for i in range(4):
                loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
                loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask

if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    imgs = torch.cat((img0, img1), 1)
    model = Model()
    model.eval()
    print(model.inference(imgs).shape)

View File

@@ -0,0 +1,128 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EPE(nn.Module):
    def __init__(self):
        super(EPE, self).__init__()

    def forward(self, flow, gt, loss_mask):
        loss_map = (flow - gt.detach()) ** 2
        loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
        return (loss_map * loss_mask)

class Ternary(nn.Module):
    def __init__(self):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, img):
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)

class SOBEL(nn.Module):
    def __init__(self):
        super(SOBEL, self).__init__()
        self.kernelX = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        self.kernelY = self.kernelX.clone().T
        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)

    def forward(self, pred, gt):
        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
        img_stack = torch.cat(
            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
        pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
        pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]

        L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
        loss = (L1X+L1Y)
        return loss

class MeanShift(nn.Conv2d):
    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        self.requires_grad = False

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, rank=0):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        pretrained = True
        self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X, Y, indices=None):
        X = self.normalize(X)
        Y = self.normalize(Y)
        indices = [2, 7, 12, 21, 30]
        weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
        k = 0
        loss = 0
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            Y = self.vgg_pretrained_features[i](Y)
            if (i+1) in indices:
                loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
                k += 1
        return loss

if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    ternary_loss = Ternary()
    print(ternary_loss(img0, img1).shape)

View File

@@ -0,0 +1,22 @@
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}

def warp(tenInput, tenFlow):
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)

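A cheap sanity check for the backward warp above: flow is normalized into grid_sample's [-1, 1] coordinate space, so a zero flow field samples the identity grid and should return the input (sketch, assumes warp and device from this file):

import torch

x = torch.arange(16.0).view(1, 1, 4, 4).to(device)
zero_flow = torch.zeros(1, 2, 4, 4).to(device)
print(torch.allclose(warp(x, zero_flow), x))  # True
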
View File

@@ -6,23 +6,33 @@
   },
   {
     "name": "RIFE 2.3",
-    "desc": "Updated General Model",
+    "desc": "General Model",
     "dir": "RIFE23"
   },
   {
     "name": "RIFE 2.4",
-    "desc": "Updated General Model (Sometimes worse than 2.3)",
+    "desc": "Latest v2 General Model (Sometimes worse than 2.3)",
     "dir": "RIFE24"
   },
   {
     "name": "RIFE 3.0",
-    "desc": "Updated General Model",
+    "desc": "v3 General Model",
     "dir": "RIFE30",
   },
   {
     "name": "RIFE 3.1",
-    "desc": "Latest General Model",
-    "dir": "RIFE31",
+    "desc": "Updated v3 General Model",
+    "dir": "RIFE31"
+  },
+  {
+    "name": "RIFE 3.4",
+    "desc": "Updated v3 General/Animation Model",
+    "dir": "RIFE34"
+  },
+  {
+    "name": "RIFE 3.5",
+    "desc": "Latest v3 General/Animation Model",
+    "dir": "RIFE35",
     "isDefault": "true"
   }
 ]

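The new entries keep the existing schema, and only one entry carries isDefault, stored as the string "true" rather than a boolean, so consumers compare strings. A hypothetical selection sketch (field names from the JSON above; note the trailing comma kept after the "RIFE30" entry means strict JSON parsers reject the file, while tolerant ones such as Json.NET accept it):

models = [
    {"name": "RIFE 3.1", "desc": "Updated v3 General Model", "dir": "RIFE31"},
    {"name": "RIFE 3.4", "desc": "Updated v3 General/Animation Model", "dir": "RIFE34"},
    {"name": "RIFE 3.5", "desc": "Latest v3 General/Animation Model", "dir": "RIFE35",
     "isDefault": "true"},
]

# Pick the flagged default, falling back to the last (newest) entry.
default = next((m for m in models if m.get("isDefault") == "true"), models[-1])
print(default["dir"])  # RIFE35
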
View File

@@ -35,6 +35,9 @@ parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try
 parser.add_argument('--exp', dest='exp', type=int, default=1)
 args = parser.parse_args()
 assert (not args.input is None)
+if args.UHD and args.scale==1.0:
+    args.scale = 0.5
+assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 torch.set_grad_enabled(False)
@@ -56,20 +59,31 @@ except:
 try:
     try:
-        from model.RIFE_HDv2 import Model
-        model = Model()
-        model.load_model(os.path.join(dname, args.model), -1)
-        print("Loaded v2.x HD model.")
-    except:
+        print(f"Trying to load v3 (new) model from {os.path.join(dname, args.model)}")
         from model.RIFE_HDv3 import Model
         model = Model()
         model.load_model(os.path.join(dname, args.model), -1)
         print("Loaded v3.x HD model.")
+    except:
+        try:
+            print(f"Trying to load v3 (legacy) model from {os.path.join(dname, args.model)}")
+            from model_v3_legacy.RIFE_HDv3 import Model
+            model = Model()
+            model.load_model(os.path.join(dname, args.model), -1)
+            print("Loaded v3.x HD model.")
+        except:
+            print(f"Trying to load v2 model from {os.path.join(dname, args.model)}")
+            from model.RIFE_HDv2 import Model
+            model = Model()
+            model.load_model(os.path.join(dname, args.model), -1)
+            print("Loaded v2.x HD model.")
 except:
+    print(f"Trying to load v1 model from {os.path.join(dname, args.model)}")
     from model.RIFE_HD import Model
     model = Model()
     model.load_model(os.path.join(dname, args.model), -1)
     print("Loaded v1.x HD model")

 model.eval()
 model.device()
@@ -118,7 +132,7 @@ def build_read_buffer(user_args, read_buffer, videogen):
 def make_inference(I0, I1, exp):
     global model
-    middle = model.inference(I0, I1, args.UHD)
+    middle = model.inference(I0, I1, args.scale)
     if exp == 1:
         return [middle]
     first_half = make_inference(I0, middle, exp=exp - 1)
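
make_inference is unchanged apart from the scale argument: it still bisects recursively, so --exp n yields 2**n - 1 in-between frames. A quick standalone check of that count (sketch, mirrors the recursion above):

def midpoint_count(exp):
    # Two half-recursions plus the middle frame, exactly as in make_inference.
    if exp == 1:
        return 1
    return 2 * midpoint_count(exp - 1) + 1

print([midpoint_count(e) for e in (1, 2, 3, 4)])  # [1, 3, 7, 15]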