# flowframes/Pkgs/flavr-cuda/flavr.py

import os
import sys
import time
import _thread
import argparse

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
print("Changing working dir to {0}".format(dname))
os.chdir(dname)
print("Added {0} to sys.path".format(dname))
sys.path.append(dname)
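# The bundled dataset/ and model/ packages live next to this script, so the
# script dir has to be on sys.path before the imports below can resolve.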
from dataset.transforms import ToTensorVideo
parser = argparse.ArgumentParser()
parser.add_argument('--input', dest='input', type=str, default=None, help="folder of input frames")
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument("--factor", type=int, choices=[2, 4, 8], help="interpolation factor (2x/4x/8x)")
parser.add_argument("--model", type=str, help="path to the trained model checkpoint")
parser.add_argument("--up_mode", type=str, help="Upsample Mode", default="transpose")
parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-precision mode')
parser.add_argument('--imgformat', default="png")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
args = parser.parse_args()
interp_input_path = os.path.join(dname, args.input)
interp_output_path = os.path.join(dname, args.output)
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    if args.fp16:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        print("FLAVR is running in FP16 mode.")
else:
    print("WARNING: CUDA is not available, FLAVR is running on CPU! [ff:nocuda-cpu]")
# Pick the device once so the CPU fallback warned about above actually works.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_outputs = args.factor - 1  # interpolated frames per input frame pair
model_name = "unet_18"       # FLAVR backbone
nbr_frame = 4                # FLAVR always takes 4 context frames as input
joinType = "concat"
def loadModel(model, checkpoint):
    # Checkpoints saved with DataParallel prefix parameter names with "module."; strip it.
    saved_state_dict = torch.load(checkpoint, map_location=device)['state_dict']
    saved_state_dict = {k.partition("module.")[-1]: v for k, v in saved_state_dict.items()}
    model.load_state_dict(saved_state_dict)
checkpoint = os.path.join(dname, args.model)
from model.FLAVR_arch import UNet_3D_3D
model = UNet_3D_3D(model_name.lower(), n_inputs=4, n_outputs=n_outputs, joinType=joinType, upmode=args.up_mode)
loadModel(model, checkpoint)
model = model.to(device)
in_files = sorted(os.listdir(interp_input_path))
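# Assumes the input dir contains only frame images, named so a lexical sort matches play order.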
def make_image(img):
    # [0,1] float CHW RGB tensor -> uint8 HWC BGR image for OpenCV.
    q_im = img.data.mul(255.).clamp(0, 255).round()
    im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    return im
def files_to_videoTensor(path):
    global in_files
    # Workaround: duplicate the first and last frame so the 4-frame sliding
    # window also yields output around the clip boundaries.
    in_files.insert(0, in_files[0])
    in_files.append(in_files[-1])
    images = [torch.Tensor(np.asarray(Image.open(os.path.join(path, f)))).type(torch.uint8) for f in in_files]
    print(images[0].shape)
    videoTensor = torch.stack(images)
    return videoTensor
def video_transform(videoTensor):
    # THWC uint8 frames -> normalized CTHW float tensor.
    transforms = torchvision.transforms.Compose([ToTensorVideo()])
    return transforms(videoTensor)
videoTensor = files_to_videoTensor(interp_input_path)
print(f"Video Tensor len: {len(videoTensor)}")
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
print(f"len(idxs): {len(idxs)}")
videoTensor = video_transform(videoTensor)
print("Video tensor shape is ", videoTensor.shape)
frames = torch.unbind(videoTensor, 1)  # split along T into per-frame CHW tensors
model = model.eval()
frame_num = 1
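# Main loop: for each window, write its source frame (the second of the four),
# then the n_outputs interpolated frames. Writes go to worker threads so disk
# I/O overlaps with inference.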
def load_and_write_img(writedir, writename, path_load):
    # np.fromfile + cv2.imdecode (and writing from inside the target dir) keeps
    # non-ASCII paths working on Windows, where cv2.imread/imwrite fail on them.
    os.chdir(writedir)
    cv2.imwrite(writename, cv2.imdecode(np.fromfile(path_load, dtype=np.uint8), cv2.IMREAD_UNCHANGED), [cv2.IMWRITE_PNG_COMPRESSION, 1])

def write_img(writedir, writename, img):
    os.chdir(writedir)
    cv2.imwrite(writename, img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
for i in range(len(idxs)):
    idxSet = idxs[i]
    inputs = [frames[idx_].to(device).unsqueeze(0) for idx_ in idxSet]
    with torch.no_grad():
        outputFrame = model(inputs)
    outputFrame = [of.squeeze(0).cpu().data for of in outputFrame]
    print(f"Frame {i}")
    out_name = '{:0>8d}.{}'.format(frame_num, args.imgformat)
    print(f"Writing source frame {out_name}")
    input_frame_path = os.path.join(interp_input_path, in_files[i + 1])
    _thread.start_new_thread(load_and_write_img, (interp_output_path, out_name, input_frame_path))
    frame_num += 1
    for img in outputFrame:
        out_name = '{:0>8d}.{}'.format(frame_num, args.imgformat)
        print(f"Writing interp frame {out_name}")
        _thread.start_new_thread(write_img, (interp_output_path, out_name, make_image(img)))
        frame_num += 1
print(f"Writing source frame {frame_num} [LAST]")
input_frame_path = os.path.join(interp_input_path, in_files[-1])
os.chdir(interp_output_path)
cv2.imwrite('{:0>8d}.{}'.format(frame_num, args.imgformat), cv2.imdecode(np.fromfile(input_frame_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED), [cv2.IMWRITE_PNG_COMPRESSION, 2]) # Last input frame
time.sleep(0.5)
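# Example invocation (paths and checkpoint name are illustrative):
#   python flavr.py --input frames-in --output frames-interpolated --factor 2 --model FLAVR_2x.pth --fp16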