Added FLAVR CUDA package

This commit is contained in:
N00MKRAD 2021-03-11 15:35:39 +01:00
parent 692f6c9be9
commit 17b4f0fa43
10 changed files with 1340 additions and 0 deletions

@@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class Davis(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.images_sets = []
for label_id in os.listdir(self.data_root):
ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root , label_id)))
ctg_imgs_ = [os.path.join(self.data_root , label_id , img_id) for img_id in ctg_imgs_]
for start_idx in range(0,len(ctg_imgs_)-6,2):
add_files = ctg_imgs_[start_idx : start_idx+7 : 2]
add_files = add_files[:2] + [ctg_imgs_[start_idx+3]] + add_files[2:]
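## each set holds frames [i, i+2, i+3, i+4, i+6]: frames i, i+2, i+4, i+6 become the inputs and i+3 the ground truth (see __getitem__)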
self.images_sets.append(add_files)
self.transforms = transforms.Compose([
transforms.CenterCrop((480,840)),
transforms.ToTensor()
])
print(len(self.images_sets))
def __getitem__(self, idx):
imgpaths = self.images_sets[idx]
images = [Image.open(img) for img in imgpaths]
images = [self.transforms(img) for img in images]
return images[:2] + images[3:] , [images[2]]
def __len__(self):
return len(self.images_sets)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = Davis(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = Davis("./Davis_test/")
print(len(dataset))
dataloader = DataLoader(dataset , batch_size=1, shuffle=True, num_workers=0)

@@ -0,0 +1,82 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
import pdb
class GoPro(Dataset):
def __init__(self, data_root , mode="train", interFrames=3, n_inputs=4, ext="png"):
super().__init__()
self.interFrames = interFrames
self.n_inputs = n_inputs
self.setLength = (n_inputs-1)*(interFrames+1)+1 ## Total number of frames needed so that `interFrames`
## intermediate frames lie between each consecutive pair of the `n_inputs` input frames.
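## e.g. with the defaults n_inputs=4 and interFrames=3: setLength = (4-1)*(3+1)+1 = 13 frames per sample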
self.data_root = os.path.join(data_root , mode)
video_list = os.listdir(self.data_root)
self.frames_list = []
self.file_list = []
for video in video_list:
frames = sorted(os.listdir(os.path.join(self.data_root , video)))
n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1
videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength ] for i in range(n_sets)]
videoInputs = [[os.path.join(video , f) for f in group] for group in videoInputs]
self.file_list.extend(videoInputs)
self.transforms = transforms.Compose([
transforms.CenterCrop(512),
transforms.ToTensor()
])
def __getitem__(self, idx):
imgpaths = [os.path.join(self.data_root , fp) for fp in self.file_list[idx]]
if random.random() > 0.5:
imgpaths = imgpaths[::-1] ## random temporal flip
# Compression-based augmentations could also be added here
pick_idxs = list(range(0,self.setLength,self.interFrames+1))
rem = self.interFrames%2
gt_idx = list(range(self.setLength//2-self.interFrames//2 , self.setLength//2+self.interFrames//2+rem))
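## e.g. setLength=13, interFrames=3: pick_idxs = [0, 4, 8, 12] (input frames) and gt_idx = [5, 6, 7] (ground-truth frames)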
input_paths = [imgpaths[idx] for idx in pick_idxs]
gt_paths = [imgpaths[idx] for idx in gt_idx]
images = [Image.open(pth_) for pth_ in input_paths]
images = [self.transforms(img_) for img_ in images]
gt_images = [Image.open(pth_) for pth_ in gt_paths]
gt_images = [self.transforms(img_) for img_ in gt_images]
return images , gt_images
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True, interFrames=3, n_inputs=4):
if test_mode:
mode = "test"
else:
mode = "train"
dataset = GoPro(data_root , mode, interFrames=interFrames, n_inputs=n_inputs)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = GoPro("./GoPro" , mode="train", interFrames=3, n_inputs=4)
print(len(dataset))
dataloader = DataLoader(dataset , batch_size=1, shuffle=True, num_workers=0)

@@ -0,0 +1,46 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
class Middelburry(Dataset):
def __init__(self, data_root , ext="png"):
super().__init__()
self.data_root = data_root
self.file_list = os.listdir(self.data_root)
self.transforms = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, idx):
imgpath = os.path.join(self.data_root , self.file_list[idx])
name = self.file_list[idx]
if name == "Teddy": ## Handle inputs with just two inout frames. FLAVR takes atleast 4.
imgpaths = [os.path.join(imgpath , "frame10.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame11.png") ]
else:
imgpaths = [os.path.join(imgpath , "frame09.png") , os.path.join(imgpath , "frame10.png") ,os.path.join(imgpath , "frame11.png") ,os.path.join(imgpath , "frame12.png") ]
images = [Image.open(img).convert('RGB') for img in imgpaths]
images = [self.transforms(img) for img in images]
sizes = images[0].shape
return images , name
def __len__(self):
return len(self.file_list)
def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):
dataset = Middelburry(data_root)
return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

@@ -0,0 +1,419 @@
# from https://github.com/facebookresearch/VMZ
import torch
import numbers
import random
from torchvision.transforms import RandomCrop, RandomResizedCrop
__all__ = [
"RandomCropVideo",
"RandomResizedCropVideo",
"CenterCropVideo",
"NormalizeVideo",
"ToTensorVideo",
"RandomHorizontalFlipVideo",
"Resize",
"TemporalCenterCrop",
"RandomTemporalFlipVideo",
"RandomVerticalFlipVideo"
]
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
return clip[..., i : i + h, j : j + w]
def temporal_center_crop(clip, clip_len):
"""
Args:
clip (torch.tensor): Video clip to be
cropped along the temporal axis. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
assert clip.size(1) >= clip_len, "clip is shorter than the proposed length"
middle = int(clip.size(1) // 2)
start = middle - clip_len // 2
return clip[:, start : start + clip_len, ...]
def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
return torch.nn.functional.interpolate(
clip, size=target_size, mode=interpolation_mode, align_corners=False
)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
i (int): i in (i, j), i.e. the vertical coordinate of the upper-left corner.
j (int): j in (i, j), i.e. the horizontal coordinate of the upper-left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
assert h >= th and w >= tw, "height and width must be >= than crop_size"
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of the clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError(
"clip tensor should have data type uint8. Got %s" % str(clip.dtype)
)
return clip.float().permute(3, 0, 1, 2) / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-1))
def vflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-2))
def tflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-3))
class RandomCropVideo(RandomCrop):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, OH, OW)
"""
i, j, h, w = self.get_params(clip, self.size)
return crop(clip, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + "(size={0})".format(self.size)
class RandomResizedCropVideo(RandomResizedCrop):
def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
assert len(size) == 2, "size should be tuple (height, width)"
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
self.scale = scale
self.ratio = ratio
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: randomly cropped/resized video clip.
size is (C, T, H, W)
"""
i, j, h, w = self.get_params(clip, self.scale, self.ratio)
return resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
def __repr__(self):
return (
self.__class__.__name__
+ "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
self.size, self.interpolation_mode, self.scale, self.ratio
)
)
class CenterCropVideo(object):
def __init__(self, crop_size):
if isinstance(crop_size, numbers.Number):
self.crop_size = (int(crop_size), int(crop_size))
else:
self.crop_size = crop_size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
Returns:
torch.tensor: central cropping of video clip. Size is
(C, T, crop_size, crop_size)
"""
return center_crop(clip, self.crop_size)
def __repr__(self):
r = self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)
return r
class TemporalCenterCrop(object):
def __init__(self, clip_len):
self.clip_len = clip_len
def __call__(self, clip):
return temporal_center_crop(clip, self.clip_len)
class UnfoldClips(object):
def __init__(self, clip_len, overlap):
self.clip_len = clip_len
assert overlap > 0 and overlap <= 1
self.step = round(clip_len * overlap)
def __call__(self, clip):
if clip.size(1) < self.clip_len:
return clip.unfold(1, clip.size(1), clip.size(1)).permute(1, 0, 4, 2, 3)
results = clip.unfold(1, self.clip_len, self.clip_len).permute(1, 0, 4, 2, 3)
return results
class TempPadClip(object):
def __init__(self, clip_len):
self.num_frames = clip_len
def __call__(self, clip):
if clip.size(1) == 0:
return clip
if clip.size(1) < self.num_frames:
# clip is shorter than requested: stretch it to num_frames by repeating frames (floored fractional indices)
step = clip.size(1) / self.num_frames
idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
idxs = idxs.floor().to(torch.int64)
return clip[:, idxs, ...]
step = clip.size(1) / self.num_frames
if step.is_integer():
# optimization: if step is integer, don't need to perform
# advanced indexing
step = int(step)
return clip[:, slice(None, None, step), ...]
idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
idxs = idxs.floor().to(torch.int64)
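# e.g. 10 input frames, num_frames=4: step = 2.5 -> idxs = [0, 2, 5, 7]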
return clip[:, idxs, ...]
class NormalizeVideo(object):
"""
Normalize the video clip by mean subtraction
and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be
normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(
self.mean, self.std, self.inplace
)
class ToTensorVideo(object):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of the clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
return to_tensor(clip)
def __repr__(self):
return self.__class__.__name__
class RandomHorizontalFlipVideo(object):
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomVerticalFlipVideo(object):
"""
Flip the video clip along the vertical direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = vflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomTemporalFlipVideo(object):
"""
Reverse the temporal order of the video clip with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (C, T, H, W)
Return:
clip (torch.tensor): Size is (C, T, H, W)
"""
if random.random() < self.p:
clip = tflip(clip)
return clip
def __repr__(self):
return self.__class__.__name__ + "(p={0})".format(self.p)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size, interpolation_mode="bilinear")
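
A minimal usage sketch for these video transforms, mirroring how flavr.py composes them (the clip shape below is only an illustration):

import torch
import torchvision
from dataset.transforms import ToTensorVideo, Resize

clip = torch.randint(0, 256, (8, 480, 640, 3), dtype=torch.uint8)  # (T, H, W, C) uint8 frames
tfms = torchvision.transforms.Compose([ToTensorVideo(), Resize((240, 320))])
out = tfms(clip)  # (C, T, H, W) float tensor in [0, 1], bilinearly resized to 240x320
print(out.shape)  # torch.Size([3, 8, 240, 320])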

@@ -0,0 +1,100 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
class VimeoSepTuplet(Dataset):
def __init__(self, data_root, is_training , input_frames="1357"):
"""
Creates a Vimeo Septuplet object.
Inputs:
data_root: Root path of the Vimeo-90K septuplet dataset.
is_training: Train/Test.
input_frames: Which frames to feed to the frame interpolation network.
"""
self.data_root = data_root
self.image_root = os.path.join(self.data_root, 'sequences')
self.training = is_training
self.inputs = input_frames
train_fn = os.path.join(self.data_root, 'sep_trainlist.txt')
test_fn = os.path.join(self.data_root, 'sep_testlist.txt')
with open(train_fn, 'r') as f:
self.trainlist = f.read().splitlines()
with open(test_fn, 'r') as f:
self.testlist = f.read().splitlines()
if self.training:
self.transforms = transforms.Compose([
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomVerticalFlip(0.5),
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
transforms.ToTensor()
])
else:
self.transforms = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, index):
if self.training:
imgpath = os.path.join(self.image_root, self.trainlist[index])
else:
imgpath = os.path.join(self.image_root, self.testlist[index])
imgpaths = [imgpath + f'/im{i}.png' for i in range(1,8)]
pth_ = imgpaths
# Load images
images = [Image.open(pth) for pth in imgpaths]
## Select only relevant inputs
inputs = [int(e)-1 for e in list(self.inputs)]
inputs = inputs[:len(inputs)//2] + [3] + inputs[len(inputs)//2:]
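## e.g. input_frames="1357" -> inputs = [0, 2, 3, 4, 6], i.e. im1/im3/im4/im5/im7; the middle frame (im4) becomes the ground truth below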
images = [images[i] for i in inputs]
imgpaths = [imgpaths[i] for i in inputs]
# Data augmentation
if self.training:
seed = random.randint(0, 2**32)
images_ = []
for img_ in images:
random.seed(seed)
images_.append(self.transforms(img_))
images = images_
# Random Temporal Flip
if random.random() >= 0.5:
images = images[::-1]
imgpaths = imgpaths[::-1]
else:
T = self.transforms
images = [T(img_) for img_ in images]
gt = images[len(images)//2]
images = images[:len(images)//2] + images[len(images)//2+1:]
return images, [gt]
def __len__(self):
if self.training:
return len(self.trainlist)
else:
return len(self.testlist)
def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
if mode == 'train':
is_training = True
else:
is_training = False
dataset = VimeoSepTuplet(data_root, is_training=is_training)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
if __name__ == "__main__":
dataset = VimeoSepTuplet("./vimeo_septuplet/", is_training=True)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=32, pin_memory=True)

Pkgs/flavr-cuda/flavr.py Normal file
@@ -0,0 +1,166 @@
import os
import torch
import cv2
import pdb
import time
import sys
import torchvision
from PIL import Image
import numpy as np
import _thread
from torchvision.io import read_video, write_video
from dataset.transforms import ToTensorVideo, Resize
import torch.nn.functional as F
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
print("Changing working dir to {0}".format(dname))
os.chdir(os.path.dirname(dname))
print("Added {0} to temporary PATH".format(dname))
sys.path.append(dname)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input', dest='input', type=str, default=None)
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument("--factor", type=int, choices=[2,4,8], help="How much interpolation needed. 2x/4x/8x.")
parser.add_argument("--model", type=str, help="path for stored model")
parser.add_argument("--up_mode", type=str, help="Upsample Mode", default="transpose")
parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-precision mode')
parser.add_argument('--imgformat', default="png")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale", type=float, help="Downscale input res. for memory", default=1)
args = parser.parse_args()
input_ext = args.input_ext
path = args.input
base = os.path.basename(path)
interp_input_path = os.path.join(dname, args.input)
interp_output_path = os.path.join(dname, args.output)
print("\interp_input_path: " + interp_input_path)
print("\ninterp_output_path: " + interp_output_path)
if args.input.endswith("/"):
video_name = args.input.split("/")[-2].split(input_ext)[0]
else:
video_name = args.input.split("/")[-1].split(input_ext)[0]
output_video = os.path.join(video_name + f"_{args.factor}x" + str(args.output_ext))
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if(args.fp16):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print("FLAVR is running in FP16 mode.")
else:
print("WARNING: CUDA is not available, FLAVR is running on CPU! [ff:nocuda-cpu]")
n_outputs = args.factor - 1
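# e.g. --factor 2 -> 1 intermediate frame per input pair, --factor 8 -> 7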
model_name = "unet_18"
nbr_frame = 4
joinType = "concat"
def loadModel(model, checkpoint):
saved_state_dict = torch.load(checkpoint)['state_dict']
saved_state_dict = {k.partition("module.")[-1]:v for k,v in saved_state_dict.items()}
model.load_state_dict(saved_state_dict)
checkpoint = os.path.join(dname, args.model)
from model.FLAVR_arch import UNet_3D_3D
model = UNet_3D_3D(model_name.lower(), n_inputs=4, n_outputs=n_outputs, joinType=joinType, upmode=args.up_mode)
loadModel(model, checkpoint)
model = model.cuda()
in_files = sorted(os.listdir(interp_input_path))
def make_image(img):
q_im = img.data.mul(255.).clamp(0,255).round()
im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
return im
def files_to_videoTensor(path, downscale=1.):
from PIL import Image
global in_files
in_files_fixed = in_files
in_files_fixed.insert(0, in_files[0]) # Workaround: duplicate the first frame so it gets a full 4-frame input window
in_files_fixed.append(in_files[-1]) # Workaround: duplicate the last frame for the same reason
images = [torch.Tensor(np.asarray(Image.open(os.path.join(path, f)))).type(torch.uint8) for f in in_files]
print(images[0].shape)
videoTensor = torch.stack(images)
return videoTensor
def video_transform(videoTensor, downscale=1):
T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
downscale = int(downscale * 8)
resizes = 8*(H//downscale), 8*(W//downscale)
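# rounds H and W down to a multiple of 8 (the encoder downsamples 8x spatially); downscale > 1 additionally shrinks the frames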
transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
videoTensor = transforms(videoTensor)
print("Resizing to %dx%d"%(resizes[0], resizes[1]) )
return videoTensor, resizes
videoTensor = files_to_videoTensor(interp_input_path, args.downscale)
print(f"Video Tensor len: {len(videoTensor)}")
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
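# sliding windows of nbr_frame=4 consecutive frame indices with stride 1, e.g. [0,1,2,3], [1,2,3,4], ...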
print(f"len(idxs): {len(idxs)}")
videoTensor, resizes = video_transform(videoTensor, args.downscale)
print("Video tensor shape is ", videoTensor.shape)
frames = torch.unbind(videoTensor, 1)
n_inputs = len(frames)
width = n_outputs + 1
model = model.eval()
frame_num = 1
def load_and_write_img (path_write, path_load):
cv2.imwrite(path_write, cv2.imread(path_load), [cv2.IMWRITE_PNG_COMPRESSION, 1])
def write_img (path_write, img):
cv2.imwrite(path_write, img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
for i in (range(len(idxs))):
idxSet = idxs[i]
inputs = [frames[idx_].cuda().unsqueeze(0) for idx_ in idxSet]
with torch.no_grad():
outputFrame = model(inputs)
outputFrame = [of.squeeze(0).cpu().data for of in outputFrame]
#outputs.extend(outputFrame)
#outputs.append(inputs[2].squeeze(0).cpu().data)
print(f"Frame {i}")
print(f"Writing source frame {'{:0>8d}.{}'.format(frame_num, args.imgformat)}")
input_frame_path = os.path.join(interp_input_path, in_files[i+1])
#cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), cv2.imread(input_frame_path), [cv2.IMWRITE_PNG_COMPRESSION, 2])
_thread.start_new_thread(load_and_write_img, ('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), input_frame_path))
frame_num += 1
for img in outputFrame:
print(f"Writing interp frame {'{:0>8d}.{}'.format(frame_num, args.imgformat)}")
# cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), make_image(img), [cv2.IMWRITE_PNG_COMPRESSION, 2])
_thread.start_new_thread(write_img, ('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), make_image(img)))
frame_num += 1
print(f"Writing source frame {frame_num} [LAST]")
input_frame_path = os.path.join(interp_input_path, in_files[-1])
cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), cv2.imread(input_frame_path), [cv2.IMWRITE_PNG_COMPRESSION, 2]) # Last input frame
time.sleep(0.5)
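
As a usage note, the script is driven entirely by the flags defined above; a hedged example invocation (the frame folder and checkpoint name are illustrative, and the --output directory must already exist since cv2.imwrite will not create it) is: python flavr.py --input frames-in/ --output frames-interpolated --factor 2 --model FLAVR_2x.pth --fp16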

@@ -0,0 +1,178 @@
import math
import numpy as np
import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet_3D import SEGating
def joinTensors(X1 , X2 , type="concat"):
if type == "concat":
return torch.cat([X1 , X2] , dim=1)
elif type == "add":
return X1 + X2
else:
return X1
class Conv_2d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
super().__init__()
self.conv = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
if batchnorm:
self.conv += [nn.BatchNorm2d(out_ch)]
self.conv = nn.Sequential(*self.conv)
def forward(self, x):
return self.conv(x)
class upConv3D(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose" , batchnorm=False):
super().__init__()
self.upmode = upmode
if self.upmode=="transpose":
self.upconv = nn.ModuleList(
[nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
SEGating(out_ch)
]
)
else:
self.upconv = nn.ModuleList(
[nn.Upsample(mode='trilinear', scale_factor=(1,2,2), align_corners=False),
nn.Conv3d(in_ch, out_ch , kernel_size=1 , stride=1),
SEGating(out_ch)
]
)
if batchnorm:
self.upconv += [nn.BatchNorm3d(out_ch)]
self.upconv = nn.Sequential(*self.upconv)
def forward(self, x):
return self.upconv(x)
class Conv_3d(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
super().__init__()
self.conv = [nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
SEGating(out_ch)
]
if batchnorm:
self.conv += [nn.BatchNorm3d(out_ch)]
self.conv = nn.Sequential(*self.conv)
def forward(self, x):
return self.conv(x)
class upConv2D(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose" , batchnorm=False):
super().__init__()
self.upmode = upmode
if self.upmode=="transpose":
self.upconv = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding)]
else:
self.upconv = [
nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
nn.Conv2d(in_ch, out_ch , kernel_size=1 , stride=1)
]
if batchnorm:
self.upconv += [nn.BatchNorm2d(out_ch)]
self.upconv = nn.Sequential(*self.upconv)
def forward(self, x):
return self.upconv(x)
class UNet_3D_3D(nn.Module):
def __init__(self, block , n_inputs, n_outputs, batchnorm=False , joinType="concat" , upmode="transpose"):
super().__init__()
nf = [512 , 256 , 128 , 64]
out_channels = 3*n_outputs
self.joinType = joinType
self.n_outputs = n_outputs
growth = 2 if joinType == "concat" else 1
self.lrelu = nn.LeakyReLU(0.2, True)
unet_3D = importlib.import_module(".resnet_3D" , "model")
if n_outputs > 1:
unet_3D.useBias = True
self.encoder = getattr(unet_3D , block)(pretrained=False , bn=batchnorm)
self.decoder = nn.Sequential(
Conv_3d(nf[0], nf[1] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
Conv_3d(nf[3]*growth, nf[3] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
upConv3D(nf[3]*growth , nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm)
)
self.feature_fuse = Conv_2d(nf[3]*n_inputs , nf[3] , kernel_size=1 , stride=1, batchnorm=batchnorm)
self.outconv = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(nf[3], out_channels , kernel_size=7 , stride=1, padding=0)
)
def forward(self, images):
images = torch.stack(images , dim=2)
## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4,keepdim=True)
images = images-mean_
x_0 , x_1 , x_2 , x_3 , x_4 = self.encoder(images)
dx_3 = self.lrelu(self.decoder[0](x_4))
dx_3 = joinTensors(dx_3 , x_3 , type=self.joinType)
dx_2 = self.lrelu(self.decoder[1](dx_3))
dx_2 = joinTensors(dx_2 , x_2 , type=self.joinType)
dx_1 = self.lrelu(self.decoder[2](dx_2))
dx_1 = joinTensors(dx_1 , x_1 , type=self.joinType)
dx_0 = self.lrelu(self.decoder[3](dx_1))
dx_0 = joinTensors(dx_0 , x_0 , type=self.joinType)
dx_out = self.lrelu(self.decoder[4](dx_0))
dx_out = torch.cat(torch.unbind(dx_out , 2) , 1)
out = self.lrelu(self.feature_fuse(dx_out))
out = self.outconv(out)
out = torch.split(out, dim=1, split_size_or_sections=3)
mean_ = mean_.squeeze(2)
out = [o+mean_ for o in out]
return out
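
A minimal shape-check sketch for UNet_3D_3D (assuming the repository's model/ package is importable, so the relative import of resnet_3D resolves; 256x448 is an arbitrary input size divisible by 8):

import torch
from model.FLAVR_arch import UNet_3D_3D

# 2x interpolation: 4 input frames -> 1 intermediate frame (n_outputs = factor - 1)
net = UNet_3D_3D("unet_18", n_inputs=4, n_outputs=1, joinType="concat", upmode="transpose").eval()
frames = [torch.randn(1, 3, 256, 448) for _ in range(4)]  # four (N, C, H, W) input frames
with torch.no_grad():
    out = net(frames)
print(len(out), out[0].shape)  # 1 frame of shape torch.Size([1, 3, 256, 448])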

@@ -0,0 +1,288 @@
# Modified from https://github.com/pytorch/vision/tree/master/torchvision/models/video
import torch
import torch.nn as nn
__all__ = ['unet_18', 'unet_34']
useBias = False
class identity(nn.Module):
def __init__(self , *args , **kwargs):
super().__init__()
def forward(self , x):
return x
class Conv3DSimple(nn.Conv3d):
def __init__(self,
in_planes,
out_planes,
midplanes=None,
stride=1,
padding=1):
super(Conv3DSimple, self).__init__(
in_channels=in_planes,
out_channels=out_planes,
kernel_size=(3, 3, 3),
stride=stride,
padding=padding,
bias=useBias)
@staticmethod
def get_downsample_stride(stride , temporal_stride):
if temporal_stride:
return (temporal_stride, stride, stride)
else:
return (stride , stride , stride)
class BasicStem(nn.Sequential):
"""The default conv-batchnorm-relu stem
"""
def __init__(self):
super().__init__(
nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
padding=(1, 3, 3), bias=useBias),
batchnorm(64),
nn.ReLU(inplace=False))
class Conv2Plus1D(nn.Sequential):
def __init__(self,
in_planes,
out_planes,
midplanes,
stride=1,
padding=1):
if not isinstance(stride , int):
temporal_stride , stride , stride = stride
else:
temporal_stride = stride
super(Conv2Plus1D, self).__init__(
nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
stride=(1, stride, stride), padding=(0, padding, padding),
bias=False),
# batchnorm(midplanes),
nn.ReLU(inplace=True),
nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
stride=(temporal_stride, 1, 1), padding=(padding, 0, 0),
bias=False))
@staticmethod
def get_downsample_stride(stride , temporal_stride):
if temporal_stride:
return (temporal_stride, stride, stride)
else:
return (stride , stride , stride)
class R2Plus1dStem(nn.Sequential):
"""R(2+1)D stem is different than the default one as it uses separated 3D convolution
"""
def __init__(self):
super().__init__(
nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
stride=(1, 2, 2), padding=(0, 3, 3),
bias=False),
batchnorm(45),
nn.ReLU(inplace=True),
nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
stride=(1, 1, 1), padding=(1, 0, 0),
bias=False),
batchnorm(64),
nn.ReLU(inplace=True))
class SEGating(nn.Module):
def __init__(self , inplanes , reduction=16):
super().__init__()
self.pool = nn.AdaptiveAvgPool3d(1)
self.attn_layer = nn.Sequential(
nn.Conv3d(inplanes , inplanes , kernel_size=1 , stride=1 , bias=True),
nn.Sigmoid()
)
def forward(self , x):
out = self.pool(x)
y = self.attn_layer(out)
return x * y
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
super(BasicBlock, self).__init__()
self.conv1 = nn.Sequential(
conv_builder(inplanes, planes, midplanes, stride),
batchnorm(planes),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
conv_builder(planes, planes, midplanes),
batchnorm(planes)
)
self.fg = SEGating(planes) ## Feature Gating
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
out = self.fg(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class VideoResNet(nn.Module):
def __init__(self, block, conv_makers, layers,
stem, zero_init_residual=False):
"""Generic resnet video generator.
Args:
block (nn.Module): resnet building block
conv_makers (list(functions)): generator function for each layer
layers (List[int]): number of blocks per layer
stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
"""
super(VideoResNet, self).__init__()
self.inplanes = 64
self.stem = stem()
self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1 )
self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2 , temporal_stride=1)
self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2 , temporal_stride=1)
self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)
# init weights
self._initialize_weights()
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
def forward(self, x):
x_0 = self.stem(x)
x_1 = self.layer1(x_0)
x_2 = self.layer2(x_1)
x_3 = self.layer3(x_2)
x_4 = self.layer4(x_3)
return x_0 , x_1 , x_2 , x_3 , x_4
def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
ds_stride = conv_builder.get_downsample_stride(stride , temporal_stride)
downsample = nn.Sequential(
nn.Conv3d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=ds_stride, bias=False),
batchnorm(planes * block.expansion)
)
stride = ds_stride
layers = []
layers.append(block(self.inplanes, planes, conv_builder, stride, downsample ))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, conv_builder ))
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out',
nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
model = VideoResNet(**kwargs)
## TODO: Other 3D ResNet models, like S3D, R(2+1)D. Note: pretrained=True would require `model_urls` and `load_state_dict_from_url`, which are not defined in this file.
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
"""
Construct the 18-layer Unet3D model as in
https://arxiv.org/abs/1711.11248
Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
nn.Module: R3D-18 encoder
"""
global batchnorm
if bn:
batchnorm = nn.BatchNorm3d
else:
batchnorm = identity
return _video_resnet('r3d_18',
pretrained, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] * 4,
layers=[2, 2, 2, 2],
stem=BasicStem, **kwargs)
def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
"""
Construct the 34-layer Unet3D model as in
https://arxiv.org/abs/1711.11248
Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
nn.Module: R3D-34 encoder
"""
global batchnorm
# bn = False
if bn:
batchnorm = nn.BatchNorm3d
else:
batchnorm = identity
return _video_resnet('r3d_34',
pretrained, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] * 4,
layers=[3, 4, 6, 3],
stem=BasicStem, **kwargs)
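
A small shape sketch for the encoder defined above (assuming this file is importable as model.resnet_3D; the input size is an arbitrary multiple of 8):

import torch
from model.resnet_3D import unet_18

encoder = unet_18(bn=False)          # batch norm replaced by the identity module, as in FLAVR
clip = torch.randn(1, 3, 4, 64, 96)  # (N, C, T, H, W): four input frames
x_0, x_1, x_2, x_3, x_4 = encoder(clip)
print(x_0.shape, x_4.shape)  # torch.Size([1, 64, 4, 32, 48]) and torch.Size([1, 512, 4, 8, 12])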

@@ -0,0 +1,3 @@
FLAVR 2x - Official model for 2x interpolation
FLAVR 4x - Official model for 4x interpolation
FLAVR 8x - Official model for 8x interpolation