mirror of https://github.com/n00mkrad/flowframes
rife-cuda updated to include RIFE 3.9 code and model
This commit is contained in:
parent
ed15cb5b0a
commit
5719d779e6
|
@ -28,7 +28,7 @@
|
|||
"desc": "Updated v3 General Model",
|
||||
"dir": "RIFE31",
|
||||
"supportsAlpha": "false",
|
||||
"isDefault": "true"
|
||||
"isDefault": "false"
|
||||
},
|
||||
{
|
||||
"name": "RIFE 3.8",
|
||||
|
@ -37,4 +37,11 @@
|
|||
"supportsAlpha": "true",
|
||||
"isDefault": "false"
|
||||
},
|
||||
{
|
||||
"name": "RIFE 3.9",
|
||||
"desc": "Latest General Model",
|
||||
"dir": "RIFE39",
|
||||
"supportsAlpha": "true",
|
||||
"isDefault": "true"
|
||||
},
|
||||
]
|
|
@ -1,9 +0,0 @@
|
|||
RIFE 1.5 - Old general model
|
||||
RIFE 1.6 - Updated old general model
|
||||
RIFE 1.7 - Optimized for 2D animation
|
||||
RIFE 1.8 - Updated 2D animation model
|
||||
RIFE 2.0 - New general model
|
||||
RIFE 2.1 - Updated general model
|
||||
RIFE 2.2 - Updated general model
|
||||
RIFE 2.3 - Updated general model (Recommended)
|
||||
RIFE 2.4 - Updated general model (Experimental)
|
|
@ -0,0 +1,189 @@
|
|||
import sys
|
||||
import io
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
import warnings
|
||||
import _thread
|
||||
import skvideo.io
|
||||
from queue import Queue, Empty
|
||||
import shutil
|
||||
import base64
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
abspath = os.path.abspath(__file__)
|
||||
dname = os.path.dirname(abspath)
|
||||
print("Changing working dir to {0}".format(dname))
|
||||
os.chdir(os.path.dirname(dname))
|
||||
print("Added {0} to temporary PATH".format(dname))
|
||||
sys.path.append(dname)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
||||
parser.add_argument('--input', dest='input', type=str, default=None)
|
||||
parser.add_argument('--output', required=False, default='frames-interpolated')
|
||||
parser.add_argument('--model', required=False, default='models')
|
||||
parser.add_argument('--imgformat', default="png")
|
||||
parser.add_argument('--rbuffer', dest='rbuffer', type=int, default=200)
|
||||
parser.add_argument('--wthreads', dest='wthreads', type=int, default=4)
|
||||
parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-precision mode')
|
||||
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
|
||||
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
|
||||
parser.add_argument('--exp', dest='exp', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
assert (not args.input is None)
|
||||
if args.UHD and args.scale==1.0:
|
||||
args.scale = 0.5
|
||||
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
torch.set_grad_enabled(False)
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if(args.fp16):
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
print("RIFE is running in FP16 mode.")
|
||||
else:
|
||||
print("WARNING: CUDA is not available, RIFE is running on CPU! [ff:nocuda-cpu]")
|
||||
|
||||
try:
|
||||
print("\nSystem Info:")
|
||||
print("Python: {} - Pytorch: {} - cuDNN: {}".format(sys.version, torch.__version__, torch.backends.cudnn.version()))
|
||||
print("Hardware Acceleration: Using {} device(s), first is {}".format( torch.cuda.device_count(), torch.cuda.get_device_name(0)))
|
||||
except:
|
||||
print("Failed to get hardware info!")
|
||||
|
||||
try:
|
||||
try:
|
||||
print(f"Trying to load v3 (new) model using arch files from {os.path.join(dname, args.model)}")
|
||||
from arch.RIFE_HDv3 import Model
|
||||
model = Model()
|
||||
model.load_model(os.path.join(dname, args.model), -1)
|
||||
print("Loaded v3.x HD model.")
|
||||
except:
|
||||
try:
|
||||
print(f"Trying to load v3 (legacy) model from {os.path.join(dname, args.model)}")
|
||||
from model.RIFE_HDv3 import Model
|
||||
model = Model()
|
||||
model.load_model(os.path.join(dname, args.model), -1)
|
||||
print("Loaded v3.x HD model.")
|
||||
except:
|
||||
print(f"Trying to load v2 model from {os.path.join(dname, args.model)}")
|
||||
from model.RIFE_HDv2 import Model
|
||||
model = Model()
|
||||
model.load_model(os.path.join(dname, args.model), -1)
|
||||
print("Loaded v2.x HD model.")
|
||||
except:
|
||||
print(f"Trying to load v1 model from {os.path.join(dname, args.model)}")
|
||||
from model.RIFE_HD import Model
|
||||
model = Model()
|
||||
model.load_model(os.path.join(dname, args.model), -1)
|
||||
print("Loaded v1.x HD model")
|
||||
|
||||
model.eval()
|
||||
model.device()
|
||||
|
||||
path = args.input
|
||||
name = os.path.basename(path)
|
||||
interp_output_path = (args.output).join(path.rsplit(name, 1))
|
||||
print("interp_output_path: " + interp_output_path)
|
||||
|
||||
cnt = 1
|
||||
|
||||
videogen = []
|
||||
for f in os.listdir(args.input):
|
||||
if 'png' in f or 'jpg' in f:
|
||||
videogen.append(f)
|
||||
tot_frame = len(videogen)
|
||||
videogen.sort(key= lambda x:int(x[:-4]))
|
||||
img_path = os.path.join(args.input, videogen[0])
|
||||
lastframe = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
||||
videogen = videogen[1:]
|
||||
h, w, _ = lastframe.shape
|
||||
vid_out = None
|
||||
if not os.path.exists(interp_output_path):
|
||||
os.mkdir(interp_output_path)
|
||||
|
||||
|
||||
def clear_write_buffer(user_args, write_buffer, thread_id):
|
||||
os.chdir(interp_output_path)
|
||||
while True:
|
||||
item = write_buffer.get()
|
||||
if item is None:
|
||||
break
|
||||
frameNum = item[0]
|
||||
img = item[1]
|
||||
print('[T{}] => {:0>8d}.{}'.format(thread_id, frameNum, args.imgformat))
|
||||
#imgBytes = base64.b64encode(cv2.imencode(f'.{args.imgformat}', img[:, :, ::-1], [cv2.IMWRITE_PNG_COMPRESSION, 2])[1].tostring())
|
||||
#print(f"{frameNum:08}:"+ imgBytes.decode('utf-8') + "\n\n\n\n")
|
||||
cv2.imwrite('{:0>8d}.{}'.format(frameNum, args.imgformat), img[:, :, ::-1], [cv2.IMWRITE_PNG_COMPRESSION, 2])
|
||||
|
||||
def build_read_buffer(user_args, read_buffer, videogen):
|
||||
for frame in videogen:
|
||||
if not user_args.input is None:
|
||||
img_path = os.path.join(user_args.input, frame)
|
||||
frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
||||
read_buffer.put(frame)
|
||||
read_buffer.put(None)
|
||||
|
||||
def make_inference(I0, I1, exp):
|
||||
global model
|
||||
middle = model.inference(I0, I1, args.scale)
|
||||
if exp == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(I0, middle, exp=exp - 1)
|
||||
second_half = make_inference(middle, I1, exp=exp - 1)
|
||||
return [*first_half, middle, *second_half]
|
||||
|
||||
def pad_image(img):
|
||||
if(args.fp16):
|
||||
return F.pad(img, padding).half()
|
||||
else:
|
||||
return F.pad(img, padding)
|
||||
|
||||
if args.UHD:
|
||||
print("UHD mode enabled.")
|
||||
ph = ((h - 1) // 64 + 1) * 64
|
||||
pw = ((w - 1) // 64 + 1) * 64
|
||||
else:
|
||||
ph = ((h - 1) // 32 + 1) * 32
|
||||
pw = ((w - 1) // 32 + 1) * 32
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
|
||||
write_buffer = Queue(maxsize=args.rbuffer)
|
||||
read_buffer = Queue(maxsize=args.rbuffer)
|
||||
_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
|
||||
|
||||
for x in range(args.wthreads):
|
||||
_thread.start_new_thread(clear_write_buffer, (args, write_buffer, x))
|
||||
|
||||
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
||||
I1 = pad_image(I1)
|
||||
|
||||
while True:
|
||||
frame = read_buffer.get()
|
||||
if frame is None:
|
||||
break
|
||||
I0 = I1
|
||||
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
||||
I1 = pad_image(I1)
|
||||
|
||||
output = make_inference(I0, I1, args.exp)
|
||||
write_buffer.put([cnt, lastframe])
|
||||
cnt += 1
|
||||
for mid in output:
|
||||
mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
|
||||
# print(f"Adding #{cnt} to buffer.")
|
||||
write_buffer.put([cnt, mid[:h, :w]])
|
||||
cnt += 1
|
||||
|
||||
lastframe = frame
|
||||
write_buffer.put([cnt, lastframe])
|
||||
import time
|
||||
while(not write_buffer.empty()):
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.5)
|
|
@ -33,10 +33,17 @@ parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-preci
|
|||
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
|
||||
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
|
||||
parser.add_argument('--exp', dest='exp', type=int, default=1)
|
||||
parser.add_argument('--multi', dest='multi', type=int, default=2)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.exp != 1:
|
||||
args.multi = (2 ** args.exp)
|
||||
|
||||
assert (not args.input is None)
|
||||
|
||||
if args.UHD and args.scale==1.0:
|
||||
args.scale = 0.5
|
||||
|
||||
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
@ -62,6 +69,12 @@ try:
|
|||
print(f"Trying to load v3 (new) model using arch files from {os.path.join(dname, args.model)}")
|
||||
from arch.RIFE_HDv3 import Model
|
||||
model = Model()
|
||||
|
||||
if not hasattr(model, 'version'):
|
||||
model.version = 0
|
||||
else:
|
||||
print("Using >= 3.9 model.")
|
||||
|
||||
model.load_model(os.path.join(dname, args.model), -1)
|
||||
print("Loaded v3.x HD model.")
|
||||
except:
|
||||
|
@ -130,14 +143,23 @@ def build_read_buffer(user_args, read_buffer, videogen):
|
|||
read_buffer.put(frame)
|
||||
read_buffer.put(None)
|
||||
|
||||
def make_inference(I0, I1, exp):
|
||||
def make_inference(I0, I1, n):
|
||||
global model
|
||||
middle = model.inference(I0, I1, args.scale)
|
||||
if exp == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(I0, middle, exp=exp - 1)
|
||||
second_half = make_inference(middle, I1, exp=exp - 1)
|
||||
return [*first_half, middle, *second_half]
|
||||
if hasattr(model, 'version') and model.version >= 3.9:
|
||||
res = []
|
||||
for i in range(n):
|
||||
res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale))
|
||||
return res
|
||||
else:
|
||||
middle = model.inference(I0, I1, args.scale)
|
||||
if n == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(I0, middle, n=n//2)
|
||||
second_half = make_inference(middle, I1, n=n//2)
|
||||
if n%2:
|
||||
return [*first_half, middle, *second_half]
|
||||
else:
|
||||
return [*first_half, *second_half]
|
||||
|
||||
def pad_image(img):
|
||||
if(args.fp16):
|
||||
|
@ -172,7 +194,7 @@ while True:
|
|||
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
||||
I1 = pad_image(I1)
|
||||
|
||||
output = make_inference(I0, I1, args.exp)
|
||||
output = make_inference(I0, I1, args.multi-1)
|
||||
write_buffer.put([cnt, lastframe])
|
||||
cnt += 1
|
||||
for mid in output:
|
||||
|
@ -185,5 +207,5 @@ while True:
|
|||
write_buffer.put([cnt, lastframe])
|
||||
import time
|
||||
while(not write_buffer.empty()):
|
||||
time.sleep(0.2)
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.5)
|
||||
|
|
Loading…
Reference in New Issue