flavr.py: removed unnecessary functions, added jpeg input support

This commit is contained in:
N00MKRAD 2021-04-25 21:52:40 +02:00
parent c88d190118
commit e90bce1989
1 changed files with 6 additions and 11 deletions

View File

@ -33,7 +33,6 @@ parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-preci
parser.add_argument('--imgformat', default="png")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale", type=float, help="Downscale input res. for memory", default=1)
args = parser.parse_args()
input_ext = args.input_ext
@ -82,7 +81,7 @@ def make_image(img):
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
return im
def files_to_videoTensor(path, downscale=1.):
def files_to_videoTensor(path):
from PIL import Image
global in_files
in_files_fixed = in_files
@ -93,22 +92,18 @@ def files_to_videoTensor(path, downscale=1.):
videoTensor = torch.stack(images)
return videoTensor
def video_transform(videoTensor, downscale=1):
def video_transform(videoTensor):
T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
downscale = int(downscale * 8)
resizes = 8*(H//downscale), 8*(W//downscale)
transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
transforms = torchvision.transforms.Compose([ToTensorVideo()])
videoTensor = transforms(videoTensor)
print("Resizing to %dx%d"%(resizes[0], resizes[1]) )
return videoTensor, resizes
return videoTensor
videoTensor = files_to_videoTensor(interp_input_path, args.downscale)
videoTensor = files_to_videoTensor(interp_input_path)
print(f"Video Tensor len: {len(videoTensor)}")
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
print(f"len(idxs): {len(idxs)}")
videoTensor, resizes = video_transform(videoTensor, args.downscale)
videoTensor = video_transform(videoTensor)
print("Video tensor shape is ", videoTensor.shape)
frames = torch.unbind(videoTensor, 1)