From 1da13688c4658c7c36007a47b4b1191de3aaeeb8 Mon Sep 17 00:00:00 2001 From: N00MKRAD Date: Tue, 18 May 2021 09:21:02 +0200 Subject: [PATCH] Revert broken RIFE v1/v2 code --- Pkgs/rife-cuda/model/IFNet_HD.py | 11 ++++++----- Pkgs/rife-cuda/model/IFNet_HDv2.py | 14 ++++++-------- Pkgs/rife-cuda/model/RIFE_HD.py | 12 +++++++----- Pkgs/rife-cuda/model/RIFE_HDv2.py | 12 +++++++----- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/Pkgs/rife-cuda/model/IFNet_HD.py b/Pkgs/rife-cuda/model/IFNet_HD.py index 6975679..fe315b2 100644 --- a/Pkgs/rife-cuda/model/IFNet_HD.py +++ b/Pkgs/rife-cuda/model/IFNet_HD.py @@ -91,9 +91,12 @@ class IFNet(nn.Module): self.block2 = IFBlock(8, scale=2, c=96) self.block3 = IFBlock(8, scale=1, c=48) - def forward(self, x, scale=1.0): - x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear", - align_corners=False) + def forward(self, x, UHD=False): + if UHD: + x = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False) + else: + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", + align_corners=False) flow0 = self.block0(x) F1 = flow0 warped_img0 = warp(x[:, :3], F1) @@ -108,8 +111,6 @@ class IFNet(nn.Module): warped_img1 = warp(x[:, 3:], -F3) flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1)) F4 = (flow0 + flow1 + flow2 + flow3) - F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", - align_corners=False) / scale return F4, [F1, F2, F3, F4] if __name__ == '__main__': diff --git a/Pkgs/rife-cuda/model/IFNet_HDv2.py b/Pkgs/rife-cuda/model/IFNet_HDv2.py index c7002d3..f9b18cf 100644 --- a/Pkgs/rife-cuda/model/IFNet_HDv2.py +++ b/Pkgs/rife-cuda/model/IFNet_HDv2.py @@ -61,28 +61,26 @@ class IFNet(nn.Module): self.block2 = IFBlock(10, scale=2, c=96) self.block3 = IFBlock(10, scale=1, c=48) - def forward(self, x, scale=1.0): - if scale != 1.0: - x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False) + def forward(self, x, UHD=False): + if UHD: + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) flow0 = self.block0(x) F1 = flow0 - F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 + F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F1_large[:, :2]) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4]) flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1)) F2 = (flow0 + flow1) - F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 + F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F2_large[:, :2]) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4]) flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1)) F3 = (flow0 + flow1 + flow2) - F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 + F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F3_large[:, :2]) warped_img1 = warp(x[:, 3:], F3_large[:, 2:4]) flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1)) F4 = (flow0 + flow1 + flow2 + flow3) - if scale != 1.0: - F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale return F4, [F1, F2, F3, F4] if __name__ == '__main__': diff --git a/Pkgs/rife-cuda/model/RIFE_HD.py b/Pkgs/rife-cuda/model/RIFE_HD.py index 47df49a..b96576f 100644 --- a/Pkgs/rife-cuda/model/RIFE_HD.py +++ b/Pkgs/rife-cuda/model/RIFE_HD.py @@ -135,7 +135,7 @@ class Model: self.optimG = AdamW(itertools.chain( self.flownet.parameters(), self.contextnet.parameters(), - self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4) + self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.schedulerG = optim.lr_scheduler.CyclicLR( self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.epe = EPE() @@ -188,9 +188,11 @@ class Model: torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) - def predict(self, imgs, flow, training=True, flow_gt=None): + def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): img0 = imgs[:, :3] img1 = imgs[:, 3:] + if UHD: + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 c0 = self.contextnet(img0, flow) c1 = self.contextnet(img1, -flow) flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", @@ -207,10 +209,10 @@ class Model: else: return pred - def inference(self, img0, img1, scale=1.0): + def inference(self, img0, img1, UHD=False): imgs = torch.cat((img0, img1), 1) - flow, _ = self.flownet(imgs, scale) - return self.predict(imgs, flow, training=False) + flow, _ = self.flownet(imgs, UHD) + return self.predict(imgs, flow, training=False, UHD=UHD) def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): for param_group in self.optimG.param_groups: diff --git a/Pkgs/rife-cuda/model/RIFE_HDv2.py b/Pkgs/rife-cuda/model/RIFE_HDv2.py index ce5cd56..9f19ae2 100644 --- a/Pkgs/rife-cuda/model/RIFE_HDv2.py +++ b/Pkgs/rife-cuda/model/RIFE_HDv2.py @@ -120,7 +120,7 @@ class Model: self.optimG = AdamW(itertools.chain( self.flownet.parameters(), self.contextnet.parameters(), - self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4) + self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.schedulerG = optim.lr_scheduler.CyclicLR( self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.epe = EPE() @@ -173,9 +173,11 @@ class Model: torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) - def predict(self, imgs, flow, training=True, flow_gt=None): + def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): img0 = imgs[:, :3] img1 = imgs[:, 3:] + if UHD: + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", @@ -192,10 +194,10 @@ class Model: else: return pred - def inference(self, img0, img1, scale=1.0): + def inference(self, img0, img1, UHD=False): imgs = torch.cat((img0, img1), 1) - flow, _ = self.flownet(imgs, scale) - return self.predict(imgs, flow, training=False) + flow, _ = self.flownet(imgs, UHD) + return self.predict(imgs, flow, training=False, UHD=UHD) def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): for param_group in self.optimG.param_groups: