mirror of https://github.com/n00mkrad/flowframes
Added FLAVR CUDA package
This commit is contained in:
parent 692f6c9be9
commit 17b4f0fa43
@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob


class Davis(Dataset):
    def __init__(self, data_root, ext="png"):

        super().__init__()

        self.data_root = data_root
        self.images_sets = []

        for label_id in os.listdir(self.data_root):

            ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
            ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
            for start_idx in range(0, len(ctg_imgs_) - 6, 2):
                # Inputs are every other frame in a 7-frame window; the center frame is the ground truth.
                add_files = ctg_imgs_[start_idx : start_idx + 7 : 2]
                add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:]
                self.images_sets.append(add_files)

        self.transforms = transforms.Compose([
            transforms.CenterCrop((480, 840)),
            transforms.ToTensor()
        ])

        print(len(self.images_sets))

    def __getitem__(self, idx):

        imgpaths = self.images_sets[idx]
        images = [Image.open(img) for img in imgpaths]
        images = [self.transforms(img) for img in images]

        return images[:2] + images[3:], [images[2]]

    def __len__(self):

        return len(self.images_sets)

def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):

    dataset = Davis(data_root)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

if __name__ == "__main__":

    dataset = Davis("./Davis_test/")

    print(len(dataset))

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
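    # Editor's sketch (assumption, not part of the commit): verify the window
    # layout Davis.__init__ builds, using dummy path strings in place of files.
    # For a window starting at frame i, the inputs are frames [i, i+2, i+4, i+6]
    # and the ground truth is the center frame i+3.
    ctg_imgs_ = [f"f{i}" for i in range(8)]  # one hypothetical 8-frame sequence
    add_files = ctg_imgs_[0:7:2]                                # ["f0", "f2", "f4", "f6"]
    add_files = add_files[:2] + [ctg_imgs_[3]] + add_files[2:]  # splice in the ground truth
    assert add_files == ["f0", "f2", "f3", "f4", "f6"]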
@ -0,0 +1,82 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob
import pdb


class GoPro(Dataset):
    def __init__(self, data_root, mode="train", interFrames=3, n_inputs=4, ext="png"):

        super().__init__()

        self.interFrames = interFrames
        self.n_inputs = n_inputs
        ## We need this many frames in total to interpolate `interFrames`
        ## intermediate frames from `n_inputs` input frames.
        self.setLength = (n_inputs-1)*(interFrames+1)+1
        self.data_root = os.path.join(data_root, mode)

        video_list = os.listdir(self.data_root)
        self.frames_list = []

        self.file_list = []
        for video in video_list:
            frames = sorted(os.listdir(os.path.join(self.data_root, video)))
            n_sets = (len(frames) - self.setLength)//(interFrames+1) + 1
            videoInputs = [frames[(interFrames+1)*i:(interFrames+1)*i+self.setLength] for i in range(n_sets)]
            videoInputs = [[os.path.join(video, f) for f in group] for group in videoInputs]
            self.file_list.extend(videoInputs)

        self.transforms = transforms.Compose([
            transforms.CenterCrop(512),
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):

        imgpaths = [os.path.join(self.data_root, fp) for fp in self.file_list[idx]]
        if random.random() > 0.5:
            imgpaths = imgpaths[::-1]  ## random temporal flip

        # We could also use compression-based augmentations here.

        pick_idxs = list(range(0, self.setLength, self.interFrames+1))
        rem = self.interFrames%2
        gt_idx = list(range(self.setLength//2 - self.interFrames//2, self.setLength//2 + self.interFrames//2 + rem))

        input_paths = [imgpaths[idx] for idx in pick_idxs]
        gt_paths = [imgpaths[idx] for idx in gt_idx]

        images = [Image.open(pth_) for pth_ in input_paths]
        images = [self.transforms(img_) for img_ in images]

        gt_images = [Image.open(pth_) for pth_ in gt_paths]
        gt_images = [self.transforms(img_) for img_ in gt_images]

        return images, gt_images

    def __len__(self):

        return len(self.file_list)

def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True, interFrames=3, n_inputs=4):

    if test_mode:
        mode = "test"
    else:
        mode = "train"

    dataset = GoPro(data_root, mode, interFrames=interFrames, n_inputs=n_inputs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

if __name__ == "__main__":

    dataset = GoPro("./GoPro", mode="train", interFrames=3, n_inputs=4)

    print(len(dataset))

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
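    # Editor's sketch (assumption, not part of the commit): the index arithmetic
    # of __getitem__ for the defaults interFrames=3, n_inputs=4, where
    # setLength = (4-1)*(3+1)+1 = 13.
    setLength, interFrames = 13, 3
    pick_idxs = list(range(0, setLength, interFrames + 1))  # [0, 4, 8, 12] -> input frames
    rem = interFrames % 2                                   # 1
    gt_idx = list(range(setLength//2 - interFrames//2, setLength//2 + interFrames//2 + rem))
    assert pick_idxs == [0, 4, 8, 12] and gt_idx == [5, 6, 7]  # GT frames sit between inputs 4 and 8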
@ -0,0 +1,46 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import glob


class Middelburry(Dataset):
    def __init__(self, data_root, ext="png"):

        super().__init__()

        self.data_root = data_root
        self.file_list = os.listdir(self.data_root)

        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):

        imgpath = os.path.join(self.data_root, self.file_list[idx])
        name = self.file_list[idx]
        if name == "Teddy":  ## Handle inputs with just two input frames; FLAVR takes at least 4, so duplicate each frame.
            imgpaths = [os.path.join(imgpath, "frame10.png"), os.path.join(imgpath, "frame10.png"), os.path.join(imgpath, "frame11.png"), os.path.join(imgpath, "frame11.png")]
        else:
            imgpaths = [os.path.join(imgpath, "frame09.png"), os.path.join(imgpath, "frame10.png"), os.path.join(imgpath, "frame11.png"), os.path.join(imgpath, "frame12.png")]

        images = [Image.open(img).convert('RGB') for img in imgpaths]
        images = [self.transforms(img) for img in images]

        sizes = images[0].shape

        return images, name

    def __len__(self):

        return len(self.file_list)

def get_loader(data_root, batch_size, shuffle, num_workers, test_mode=True):

    dataset = Middelburry(data_root)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
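
# Editor's sketch (assumption; "./Middlebury/" is a placeholder path): each item
# yields four input frames plus the sequence name; "Teddy" has only two source
# frames, so both are duplicated to satisfy FLAVR's 4-frame input requirement.
if __name__ == "__main__":
    dataset = Middelburry("./Middlebury/")
    images, name = dataset[0]
    print(name, len(images), images[0].shape)  # sequence name, 4 frames, (3, H, W)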
@ -0,0 +1,419 @@
# from https://github.com/facebookresearch/VMZ

import torch
import numbers
import random

from torchvision.transforms import RandomCrop, RandomResizedCrop


__all__ = [
    "RandomCropVideo",
    "RandomResizedCropVideo",
    "CenterCropVideo",
    "NormalizeVideo",
    "ToTensorVideo",
    "RandomHorizontalFlipVideo",
    "Resize",
    "TemporalCenterCrop",
    "RandomTemporalFlipVideo",
    "RandomVerticalFlipVideo"
]


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    assert len(clip.size()) == 4, "clip should be a 4D tensor"
    return clip[..., i : i + h, j : j + w]


def temporal_center_crop(clip, clip_len):
    """
    Args:
        clip (torch.tensor): Video clip to be
        cropped along the temporal axis. Size is (C, T, H, W)
    """
    assert len(clip.size()) == 4, "clip should be a 4D tensor"
    assert clip.size(1) >= clip_len, "clip is shorter than the proposed length"
    middle = int(clip.size(1) // 2)
    start = middle - clip_len // 2
    return clip[:, start : start + clip_len, ...]


def resize(clip, target_size, interpolation_mode):
    assert len(target_size) == 2, "target size should be tuple (height, width)"
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=False
    )


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i,j), i.e. coordinates of the upper left corner.
        j (int): j in (i,j), i.e. coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    assert h >= th and w >= tw, "height and width must be >= crop_size"

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError(
            "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
        )
    return clip.float().permute(3, 0, 1, 2) / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-1))


def vflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-2))

def tflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-3))


class RandomCropVideo(RandomCrop):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (C, T, OH, OW)
        """
        i, j, h, w = self.get_params(clip, self.size)
        return crop(clip, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0})".format(self.size)


class RandomResizedCropVideo(RandomResizedCrop):
    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            assert len(size) == 2, "size should be tuple (height, width)"
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
        return resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

    def __repr__(self):
        return (
            self.__class__.__name__
            + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
                self.size, self.interpolation_mode, self.scale, self.ratio
            )
        )


class CenterCropVideo(object):
    def __init__(self, crop_size):
        if isinstance(crop_size, numbers.Number):
            self.crop_size = (int(crop_size), int(crop_size))
        else:
            self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
            (C, T, crop_size, crop_size)
        """
        return center_crop(clip, self.crop_size)

    def __repr__(self):
        r = self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)
        return r


class TemporalCenterCrop(object):
    def __init__(self, clip_len):
        self.clip_len = clip_len

    def __call__(self, clip):
        return temporal_center_crop(clip, self.clip_len)


class UnfoldClips(object):
    def __init__(self, clip_len, overlap):
        self.clip_len = clip_len
        assert overlap > 0 and overlap <= 1
        self.step = round(clip_len * overlap)

    def __call__(self, clip):
        if clip.size(1) < self.clip_len:
            return clip.unfold(1, clip.size(1), clip.size(1)).permute(1, 0, 4, 2, 3)

        results = clip.unfold(1, self.clip_len, self.clip_len).permute(1, 0, 4, 2, 3)
        return results


class TempPadClip(object):
    def __init__(self, clip_len):
        self.num_frames = clip_len

    def __call__(self, clip):
        if clip.size(1) == 0:
            return clip
        if clip.size(1) < self.num_frames:
            # resample indices to stretch the clip to num_frames
            step = clip.size(1) / self.num_frames
            idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
            idxs = idxs.floor().to(torch.int64)
            return clip[:, idxs, ...]
        step = clip.size(1) / self.num_frames
        if step.is_integer():
            # optimization: if step is integer, don't need to perform
            # advanced indexing
            step = int(step)
            return clip[:, slice(None, None, step), ...]
        idxs = torch.arange(self.num_frames, dtype=torch.float32) * step
        idxs = idxs.floor().to(torch.int64)
        return clip[:, idxs, ...]


class NormalizeVideo(object):
    """
    Normalize the video clip by mean subtraction
    and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether to do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be
                normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(
            self.mean, self.std, self.inplace
        )


class ToTensorVideo(object):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        return to_tensor(clip)

    def __repr__(self):
        return self.__class__.__name__


class RandomHorizontalFlipVideo(object):
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        if random.random() < self.p:
            clip = hflip(clip)
        return clip

    def __repr__(self):
        return self.__class__.__name__ + "(p={0})".format(self.p)


class RandomVerticalFlipVideo(object):
    """
    Flip the video clip along the vertical direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        if random.random() < self.p:
            clip = vflip(clip)
        return clip

    def __repr__(self):
        return self.__class__.__name__ + "(p={0})".format(self.p)


class RandomTemporalFlipVideo(object):
    """
    Flip the video clip along the temporal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        if random.random() < self.p:
            clip = tflip(clip)
        return clip

    def __repr__(self):
        return self.__class__.__name__ + "(p={0})".format(self.p)


class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size, interpolation_mode="bilinear")
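
# Editor's sketch (assumption, not part of the commit): typical use of these clip
# transforms on a dummy uint8 video tensor in (T, H, W, C) layout, mirroring what
# the inference script below does with ToTensorVideo and Resize.
if __name__ == "__main__":
    import torchvision

    clip = torch.randint(0, 256, (8, 240, 424, 3), dtype=torch.uint8)  # T, H, W, C
    pipeline = torchvision.transforms.Compose([
        ToTensorVideo(),              # -> float (C, T, H, W) in [0, 1]
        Resize((240, 416)),           # bilinear resize of every frame
        RandomHorizontalFlipVideo(0.5),
    ])
    out = pipeline(clip)
    print(out.shape)                  # torch.Size([3, 8, 240, 416])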
@ -0,0 +1,100 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random

class VimeoSepTuplet(Dataset):
    def __init__(self, data_root, is_training, input_frames="1357"):
        """
        Creates a Vimeo Septuplet object.
        Inputs:
            data_root: Root path for the Vimeo dataset containing the septuplets.
            is_training: Train/Test.
            input_frames: Which frames to input for the frame interpolation network.
        """
        self.data_root = data_root
        self.image_root = os.path.join(self.data_root, 'sequences')
        self.training = is_training
        self.inputs = input_frames

        train_fn = os.path.join(self.data_root, 'sep_trainlist.txt')
        test_fn = os.path.join(self.data_root, 'sep_testlist.txt')
        with open(train_fn, 'r') as f:
            self.trainlist = f.read().splitlines()
        with open(test_fn, 'r') as f:
            self.testlist = f.read().splitlines()

        if self.training:
            self.transforms = transforms.Compose([
                transforms.RandomCrop(256),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
                transforms.ToTensor()
            ])
        else:
            self.transforms = transforms.Compose([
                transforms.ToTensor()
            ])

    def __getitem__(self, index):
        if self.training:
            imgpath = os.path.join(self.image_root, self.trainlist[index])
        else:
            imgpath = os.path.join(self.image_root, self.testlist[index])

        imgpaths = [imgpath + f'/im{i}.png' for i in range(1, 8)]

        # Load images
        images = [Image.open(pth) for pth in imgpaths]

        ## Select only the relevant inputs; the center frame (im4) is the target.
        inputs = [int(e)-1 for e in list(self.inputs)]
        inputs = inputs[:len(inputs)//2] + [3] + inputs[len(inputs)//2:]
        images = [images[i] for i in inputs]
        imgpaths = [imgpaths[i] for i in inputs]
        # Data augmentation
        if self.training:
            # Re-seed before transforming each frame so the random crop/flips
            # draw identical parameters for the whole septuplet.
            seed = random.randint(0, 2**32)
            images_ = []
            for img_ in images:
                random.seed(seed)
                images_.append(self.transforms(img_))
            images = images_
            # Random temporal flip
            if random.random() >= 0.5:
                images = images[::-1]
                imgpaths = imgpaths[::-1]
        else:
            T = self.transforms
            images = [T(img_) for img_ in images]

        gt = images[len(images)//2]
        images = images[:len(images)//2] + images[len(images)//2+1:]

        return images, [gt]

    def __len__(self):
        if self.training:
            return len(self.trainlist)
        else:
            return len(self.testlist)

def get_loader(mode, data_root, batch_size, shuffle, num_workers, test_mode=None):
    if mode == 'train':
        is_training = True
    else:
        is_training = False
    dataset = VimeoSepTuplet(data_root, is_training=is_training)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)


if __name__ == "__main__":

    dataset = VimeoSepTuplet("./vimeo_septuplet/", is_training=True)
    dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=32, pin_memory=True)
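    # Editor's sketch (assumption, not part of the commit): the mechanism behind
    # the per-frame re-seeding in __getitem__. Identical seeds give identical
    # draws, so torchvision's random transforms (which draw from Python's `random`
    # in the versions this code targets) crop/flip every frame the same way.
    _seed = random.randint(0, 2**32)
    _draws = []
    for _ in range(3):          # one iteration per "frame"
        random.seed(_seed)      # same seed before each transform...
        _draws.append(random.random())  # ...so every frame sees the same draw
    assert _draws[0] == _draws[1] == _draws[2]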
@ -0,0 +1,166 @@
import os
import torch
import cv2
import pdb
import time
import sys
import torchvision
from PIL import Image
import numpy as np
import _thread
from torchvision.io import read_video, write_video
from dataset.transforms import ToTensorVideo, Resize

import torch.nn.functional as F

abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
print("Changing working dir to {0}".format(os.path.dirname(dname)))
os.chdir(os.path.dirname(dname))
print("Added {0} to temporary PATH".format(dname))
sys.path.append(dname)

import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--input', dest='input', type=str, default=None)
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument("--factor", type=int, choices=[2,4,8], help="How much interpolation is needed: 2x/4x/8x.")
parser.add_argument("--model", type=str, help="Path of the stored model")
parser.add_argument("--up_mode", type=str, help="Upsample mode", default="transpose")
parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-precision mode')
parser.add_argument('--imgformat', default="png")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale", type=float, help="Downscale input resolution to save memory", default=1)
args = parser.parse_args()

input_ext = args.input_ext

path = args.input
base = os.path.basename(path)
interp_input_path = os.path.join(dname, args.input)
interp_output_path = os.path.join(dname, args.output)
print("\ninterp_input_path: " + interp_input_path)
print("\ninterp_output_path: " + interp_output_path)

if args.input.endswith("/"):
    video_name = args.input.split("/")[-2].split(input_ext)[0]
else:
    video_name = args.input.split("/")[-1].split(input_ext)[0]

output_video = os.path.join(video_name + f"_{args.factor}x" + str(args.output_ext))


torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    if(args.fp16):
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        print("FLAVR is running in FP16 mode.")
else:
    print("WARNING: CUDA is not available, FLAVR is running on CPU! [ff:nocuda-cpu]")


n_outputs = args.factor - 1

model_name = "unet_18"
nbr_frame = 4
joinType = "concat"

def loadModel(model, checkpoint):

    saved_state_dict = torch.load(checkpoint)['state_dict']
    # Strip a possible "module." prefix left over from DataParallel training.
    saved_state_dict = {k.partition("module.")[-1]:v for k,v in saved_state_dict.items()}
    model.load_state_dict(saved_state_dict)

checkpoint = os.path.join(dname, args.model)
from model.FLAVR_arch import UNet_3D_3D

model = UNet_3D_3D(model_name.lower(), n_inputs=4, n_outputs=n_outputs, joinType=joinType, upmode=args.up_mode)
loadModel(model, checkpoint)
model = model.cuda()

in_files = sorted(os.listdir(interp_input_path))

def make_image(img):
    q_im = img.data.mul(255.).clamp(0,255).round()
    im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    return im

def files_to_videoTensor(path, downscale=1.):
    from PIL import Image
    global in_files
    in_files_fixed = in_files  # aliases in_files, so the padding below is visible everywhere
    in_files_fixed.insert(0, in_files[0])  # Workaround: duplicate the first frame
    in_files_fixed.append(in_files[-1])    # Workaround: duplicate the last frame
    images = [torch.Tensor(np.asarray(Image.open(os.path.join(path, f)))).type(torch.uint8) for f in in_files]
    print(images[0].shape)
    videoTensor = torch.stack(images)
    return videoTensor

def video_transform(videoTensor, downscale=1):
    T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
    downscale = int(downscale * 8)
    resizes = 8*(H//downscale), 8*(W//downscale)  # keep both sides divisible by 8
    transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
    videoTensor = transforms(videoTensor)

    print("Resizing to %dx%d" % (resizes[0], resizes[1]))
    return videoTensor, resizes

videoTensor = files_to_videoTensor(interp_input_path, args.downscale)

print(f"Video Tensor len: {len(videoTensor)}")
# Sliding windows of 4 consecutive frame indices, stride 1.
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1, size=nbr_frame, step=1).squeeze(0)
print(f"len(idxs): {len(idxs)}")
videoTensor, resizes = video_transform(videoTensor, args.downscale)
print("Video tensor shape is ", videoTensor.shape)

frames = torch.unbind(videoTensor, 1)
n_inputs = len(frames)
width = n_outputs + 1


model = model.eval()

frame_num = 1

def load_and_write_img(path_write, path_load):
    cv2.imwrite(path_write, cv2.imread(path_load), [cv2.IMWRITE_PNG_COMPRESSION, 1])

def write_img(path_write, img):
    cv2.imwrite(path_write, img, [cv2.IMWRITE_PNG_COMPRESSION, 1])


for i in range(len(idxs)):
    idxSet = idxs[i]
    inputs = [frames[idx_].cuda().unsqueeze(0) for idx_ in idxSet]
    with torch.no_grad():
        outputFrame = model(inputs)
        outputFrame = [of.squeeze(0).cpu().data for of in outputFrame]
    #outputs.extend(outputFrame)
    #outputs.append(inputs[2].squeeze(0).cpu().data)

    print(f"Frame {i}")

    print(f"Writing source frame {'{:0>8d}.{}'.format(frame_num, args.imgformat)}")
    input_frame_path = os.path.join(interp_input_path, in_files[i+1])
    #cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), cv2.imread(input_frame_path), [cv2.IMWRITE_PNG_COMPRESSION, 2])
    _thread.start_new_thread(load_and_write_img, ('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), input_frame_path))
    frame_num += 1

    for img in outputFrame:
        print(f"Writing interp frame {'{:0>8d}.{}'.format(frame_num, args.imgformat)}")
        # cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), make_image(img), [cv2.IMWRITE_PNG_COMPRESSION, 2])
        _thread.start_new_thread(write_img, ('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), make_image(img)))
        frame_num += 1

print(f"Writing source frame {frame_num} [LAST]")
input_frame_path = os.path.join(interp_input_path, in_files[-1])
cv2.imwrite('{}/{:0>8d}.{}'.format(interp_output_path, frame_num, args.imgformat), cv2.imread(input_frame_path), [cv2.IMWRITE_PNG_COMPRESSION, 2])  # Last input frame

time.sleep(0.5)
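
# Editor's note (hypothetical invocation; the script and model file names are
# placeholders, but every flag exists in the argparse setup above):
#   python interpolate.py --input frames/ --output frames-interpolated --factor 2 --model FLAVR_2x.pth
# With --factor 2 the model emits n_outputs = factor - 1 = 1 intermediate frame per
# sliding 4-frame window, and each source frame is written alongside the new ones.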
@ -0,0 +1,178 @@
import math
import numpy as np
import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet_3D import SEGating

def joinTensors(X1, X2, type="concat"):

    if type == "concat":
        return torch.cat([X1, X2], dim=1)
    elif type == "add":
        return X1 + X2
    else:
        return X1


class Conv_2d(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):

        super().__init__()
        self.conv = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]

        if batchnorm:
            self.conv += [nn.BatchNorm2d(out_ch)]

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):

        return self.conv(x)

class upConv3D(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):

        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            self.upconv = nn.ModuleList(
                [nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                 SEGating(out_ch)
                ]
            )

        else:
            self.upconv = nn.ModuleList(
                [nn.Upsample(mode='trilinear', scale_factor=(1,2,2), align_corners=False),
                 nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                 SEGating(out_ch)
                ]
            )

        if batchnorm:
            self.upconv += [nn.BatchNorm3d(out_ch)]

        self.upconv = nn.Sequential(*self.upconv)

    def forward(self, x):

        return self.upconv(x)

class Conv_3d(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):

        super().__init__()
        self.conv = [nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
                     SEGating(out_ch)
                    ]

        if batchnorm:
            self.conv += [nn.BatchNorm3d(out_ch)]

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):

        return self.conv(x)

class upConv2D(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):

        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            self.upconv = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding)]

        else:
            self.upconv = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)
            ]

        if batchnorm:
            self.upconv += [nn.BatchNorm2d(out_ch)]

        self.upconv = nn.Sequential(*self.upconv)

    def forward(self, x):

        return self.upconv(x)


class UNet_3D_3D(nn.Module):
    def __init__(self, block, n_inputs, n_outputs, batchnorm=False, joinType="concat", upmode="transpose"):
        super().__init__()

        nf = [512, 256, 128, 64]
        out_channels = 3 * n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        unet_3D = importlib.import_module(".resnet_3D", "model")
        if n_outputs > 1:
            unet_3D.useBias = True
        self.encoder = getattr(unet_3D, block)(pretrained=False, bn=batchnorm)

        self.decoder = nn.Sequential(
            Conv_3d(nf[0], nf[1], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1), upmode=upmode, batchnorm=batchnorm),
            upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1), upmode=upmode, batchnorm=batchnorm),
            Conv_3d(nf[3]*growth, nf[3], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[3]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1), upmode=upmode, batchnorm=batchnorm)
        )

        self.feature_fuse = Conv_2d(nf[3]*n_inputs, nf[3], kernel_size=1, stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf[3], out_channels, kernel_size=7, stride=1, padding=0)
        )

    def forward(self, images):

        images = torch.stack(images, dim=2)

        ## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
        mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4, keepdim=True)
        images = images - mean_

        x_0, x_1, x_2, x_3, x_4 = self.encoder(images)

        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = joinTensors(dx_3, x_3, type=self.joinType)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = joinTensors(dx_2, x_2, type=self.joinType)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = joinTensors(dx_1, x_1, type=self.joinType)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = joinTensors(dx_0, x_0, type=self.joinType)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold the temporal axis into channels before the 2D fusion layers.
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        out = torch.split(out, dim=1, split_size_or_sections=3)
        mean_ = mean_.squeeze(2)
        out = [o + mean_ for o in out]

        return out
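
# Editor's sketch (assumption, not part of the commit): a shape smoke test; run as
# `python -m model.FLAVR_arch` from the repo root so the relative import of
# resnet_3D resolves under the assumed "model" package layout.
if __name__ == "__main__":
    model = UNet_3D_3D("unet_18", n_inputs=4, n_outputs=1, joinType="concat", upmode="transpose")
    frames = [torch.randn(1, 3, 64, 64) for _ in range(4)]  # four RGB frames; H and W divisible by 8
    with torch.no_grad():
        out = model(frames)
    print(len(out), out[0].shape)  # 1 interpolated frame of shape (1, 3, 64, 64)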
@ -0,0 +1,288 @@
# Modified from https://github.com/pytorch/vision/tree/master/torchvision/models/video

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

__all__ = ['unet_18', 'unet_34']

useBias = False

class identity(nn.Module):

    def __init__(self, *args, **kwargs):

        super().__init__()

    def forward(self, x):
        return x

class Conv3DSimple(nn.Conv3d):
    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=useBias)

    @staticmethod
    def get_downsample_stride(stride, temporal_stride):
        if temporal_stride:
            return (temporal_stride, stride, stride)
        else:
            return (stride, stride, stride)

class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem
    """
    def __init__(self):
        super().__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                      padding=(1, 3, 3), bias=useBias),
            batchnorm(64),
            nn.ReLU(inplace=False))


class Conv2Plus1D(nn.Sequential):

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        if not isinstance(stride, int):
            temporal_stride, stride, stride = stride
        else:
            temporal_stride = stride

        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                      stride=(1, stride, stride), padding=(0, padding, padding),
                      bias=False),
            # batchnorm(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                      stride=(temporal_stride, 1, 1), padding=(padding, 0, 0),
                      bias=False))

    @staticmethod
    def get_downsample_stride(stride, temporal_stride):
        if temporal_stride:
            return (temporal_stride, stride, stride)
        else:
            return (stride, stride, stride)

class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different from the default one as it uses separated 3D convolution
    """
    def __init__(self):
        super().__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            batchnorm(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            batchnorm(64),
            nn.ReLU(inplace=True))


class SEGating(nn.Module):

    def __init__(self, inplanes, reduction=16):

        super().__init__()

        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):

        # Squeeze-and-excitation style gating: pool globally, then re-weight channels.
        out = self.pool(x)
        y = self.attn_layer(out)
        return x * y

class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            batchnorm(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            batchnorm(planes)
        )
        self.fg = SEGating(planes)  ## Feature Gating
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.fg(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class VideoResNet(nn.Module):

    def __init__(self, block, conv_makers, layers,
                 stem, zero_init_residual=False):
        """Generic resnet video generator.

        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2, temporal_stride=1)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2, temporal_stride=1)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            for m in self.modules():
                # NOTE: Bottleneck blocks are not defined in this trimmed file,
                # so zero_init_residual should stay False here.
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x_0 = self.stem(x)
        x_1 = self.layer1(x_0)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        return x_0, x_1, x_2, x_3, x_4

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                batchnorm(planes * block.expansion)
            )
            stride = ds_stride

        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    model = VideoResNet(**kwargs)
    ## TODO: Other 3D resnet models, like S3D, r(2+1)D.

    if pretrained:
        # NOTE: requires a `model_urls` mapping, which is not defined in this file.
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
    """
    Construct 18 layer Unet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-18 encoder
    """
    global batchnorm
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        batchnorm = identity

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)

def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
    """
    Construct 34 layer Unet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-34 encoder
    """
    global batchnorm
    # bn = False
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        batchnorm = identity

    return _video_resnet('r3d_34',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[3, 4, 6, 3],
                         stem=BasicStem, **kwargs)
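
# Editor's sketch (assumption, not part of the commit): the `bn` flag swaps the
# module-level `batchnorm` factory between nn.BatchNorm3d and the no-op `identity`
# before the network is built, so the default FLAVR encoder carries no batch norm.
if __name__ == "__main__":
    encoder = unet_18(bn=False)          # batchnorm == identity in every block
    x = torch.randn(1, 3, 4, 64, 64)     # (B, C, T, H, W)
    feats = encoder(x)
    print([f.shape for f in feats])      # 5 feature maps with 64/64/128/256/512 channels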
@ -0,0 +1,3 @@
FLAVR 2x - Official model for 2x interpolation
FLAVR 4x - Official model for 4x interpolation
FLAVR 8x - Official model for 8x interpolation