import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet_HD import *
import torch.nn.functional as F
from model.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d (with bias) followed by a per-channel PReLU."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """ConvTranspose2d (with bias) followed by a per-channel PReLU.

    With the defaults (kernel 4, stride 2, padding 1) the output is exactly
    twice the input's spatial size: (H - 1) * 2 - 2 * 1 + 4 == 2 * H.
    """
    # The arguments were previously ignored in favour of hard-coded 4/2/1.
    # They are now honoured; the defaults are the same values, so existing
    # callers build identical layers.
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes)
    )


def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d (with bias) with no activation ("wo act" = without activation)."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )


class ResBlock(nn.Module):
    """Residual block with squeeze-and-excitation style channel gating.

    Main path: conv(3x3, stride) + PReLU -> conv(3x3) -> scaled per channel by
    sigmoid(fc2(PReLU(fc1(global average pool)))).
    Shortcut: identity when shapes already match (same width, stride 1),
    otherwise a bias-free strided 3x3 conv. The sum goes through a PReLU.
    Default stride is 2, i.e. the block halves the spatial size.
    """

    def __init__(self, in_planes, out_planes, stride=2):
        super(ResBlock, self).__init__()
        if in_planes == out_planes and stride == 1:
            self.conv0 = nn.Identity()
        else:
            self.conv0 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
        self.relu1 = nn.PReLU(1)
        self.relu2 = nn.PReLU(out_planes)
        # 1x1 "fully connected" bottleneck (16 hidden channels) for the gate.
        self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        y = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # Global average over W (dim 3) then H (dim 2) -> one value per channel.
        w = x.mean(3, True).mean(2, True)
        w = self.relu1(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        x = self.relu2(x * w + y)
        return x


# Base channel width shared by ContextNet and FusionNet.
c = 32


class ContextNet(nn.Module):
    """Image encoder producing four flow-warped feature pyramids.

    Levels have c, 2c, 4c and 8c channels at 1/4, 1/8, 1/16 and 1/32 of the
    input image resolution (conv0 and every ResBlock use stride 2).
    """

    def __init__(self):
        super(ContextNet, self).__init__()
        self.conv0 = conv(3, c, 3, 2, 1)
        self.conv1 = ResBlock(c, c)
        self.conv2 = ResBlock(c, 2*c)
        self.conv3 = ResBlock(2*c, 4*c)
        self.conv4 = ResBlock(4*c, 8*c)

    def forward(self, x, flow):
        """Return [f1, f2, f3, f4], each level's features warped by ``flow``.

        Before each warp the flow is downsampled by 2 and its values halved, so
        the displacements stay in pixel units of that level. Model.predict
        passes the flow it later upsamples by 2 before FusionNet, so the flow
        here is presumably at half the image resolution — confirm against IFNet.
        """
        x = self.conv0(x)
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f1 = warp(x, flow)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f2 = warp(x, flow)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f3 = warp(x, flow)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]


class FusionNet(nn.Module):
    """U-Net that fuses the two warped frames with their context features.

    The encoder concatenates ContextNet features of both frames at each scale.
    The decoder upsamples with skip connections and ends in a 16-channel conv
    followed by PixelShuffle(2), giving a 4-channel output at input resolution.
    There are five stride-2 stages (conv0 + four ResBlocks), so the skip
    concatenations only line up when H and W are multiples of 32.
    """

    def __init__(self):
        super(FusionNet, self).__init__()
        # 8 input channels: warped img0 + warped img1 + flow (3 + 3 + 2,
        # assuming 3-channel images as split in Model.predict).
        self.conv0 = conv(8, c, 3, 2, 1)
        self.down0 = ResBlock(c, 2*c)
        self.down1 = ResBlock(4*c, 4*c)
        self.down2 = ResBlock(8*c, 8*c)
        self.down3 = ResBlock(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.Conv2d(c, 16, 3, 1, 1)
        self.up4 = nn.PixelShuffle(2)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        """Warp both frames and predict the 4-channel refinement map.

        img0 is warped by ``flow`` and img1 by ``-flow``. If ``flow_gt`` is
        given, img0/img1 are also warped by ``flow_gt[:, :2]`` and
        ``flow_gt[:, 2:4]`` respectively (those outputs are None otherwise).

        Returns:
            (x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)
        """
        warped_img0 = warp(img0, flow)
        warped_img1 = warp(img1, -flow)
        # Identity test: ``==`` on a tensor is elementwise comparison, not a
        # None check.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.up4(self.conv(x))
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt


class Model:
    """Frame interpolation: IFNet flow estimator + ContextNet + FusionNet.

    All three sub-networks are trained by one AdamW optimizer. When
    ``local_rank != -1`` each sub-network is wrapped in DistributedDataParallel
    on that device.
    """

    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(itertools.chain(
            self.flownet.parameters(),
            self.contextnet.parameters(),
            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
        # NOTE(review): this scheduler is never stepped in this file, and
        # update() overwrites the learning rate on every call — confirm whether
        # any caller uses it.
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000,
            cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank],
                               output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[local_rank],
                                  output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank],
                                 output_device=local_rank)

    def train(self):
        """Put all three sub-networks in training mode."""
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        """Put all three sub-networks in evaluation mode."""
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all three sub-networks to the module-level ``device``."""
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank):
        """Load flownet.pkl, contextnet.pkl and unet.pkl from directory ``path``.

        Only ranks <= 0 load. For ``rank == -1`` (no DDP wrapper) a leading
        ``"module."`` prefix, as written by a DDP-wrapped save, is stripped
        from each key. Keys without it are kept, so checkpoints saved from an
        unwrapped model load too.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        prefix = "module."

        def convert(param):
            if rank == -1:
                return {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in param.items()
                }
            return param

        if rank <= 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
            self.contextnet.load_state_dict(
                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
            self.fusionnet.load_state_dict(
                convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))

    def save_model(self, path, rank):
        """Save the three state dicts into directory ``path``.

        Saves on rank 0 (DDP master) and on rank -1 (single process), the same
        ranks load_model() reads on. Previously rank -1 silently saved nothing.
        """
        if rank <= 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None):
        """Synthesize the middle frame from two frames and a flow estimate.

        ``imgs`` is split on the channel axis: channels 0-2 are img0, the rest
        img1. ``flow`` is upsampled by 2 (values doubled) before being used
        at image resolution by FusionNet.

        The output is ``clamp(mask * warped0 + (1 - mask) * warped1 + res, 0, 1)``
        where ``mask = sigmoid(out[:, 3:4])`` and
        ``res = 2 * sigmoid(out[:, :3]) - 1`` (a residual in [-1, 1]).

        Returns:
            ``pred`` when ``training`` is False; otherwise
            ``(pred, mask, merged_img, warped_img0, warped_img1,
            warped_img0_gt, warped_img1_gt)``.
        """
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        c0 = self.contextnet(img0, flow)
        c1 = self.contextnet(img1, -flow)
        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                             align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1, scale=1.0):
        """Interpolate between ``img0`` and ``img1`` and return only the prediction.

        ``scale`` is forwarded unchanged to the IFNet flow estimator.
        """
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs, scale)
        return self.predict(imgs, flow, training=False)

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, if ``training``, one optimizer step.

        Sets every optimizer param group's lr to ``learning_rate``. ``mul`` is
        accepted for interface compatibility but is not used here.

        When ``training`` is True, ``flow_gt`` must be supplied. It is used
        for the monitoring L1 of the gt-warped img0 and for the
        flow-consistency EPE term. The total loss is L1 (Charbonnier form)
        + consistency + ternary.

        Returns:
            (pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter,
            loss_mask)
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            # Monitoring quantities and the downsampled flow target; none of
            # these should receive gradients.
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
                                          align_corners=False).detach()
                # Bring flow_gt to the flow estimator's (half) resolution,
                # halving values to keep pixel units.
                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            # Consistency loss between the first three intermediate flows and
            # the ground-truth flows; kept outside no_grad so it trains flownet.
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
                loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask


if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    model = Model()
    model.eval()
    # inference() takes the two frames separately and concatenates them
    # itself; passing a single pre-concatenated tensor raised TypeError.
    print(model.inference(img0, img1).shape)