mirror of https://github.com/n00mkrad/flowframes
flavr.py: removed unnecessary functions, added jpeg input support
This commit is contained in:
parent
c88d190118
commit
e90bce1989
|
@ -33,7 +33,6 @@ parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-preci
|
|||
parser.add_argument('--imgformat', default="png")
|
||||
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
|
||||
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
|
||||
parser.add_argument("--downscale", type=float, help="Downscale input res. for memory", default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
input_ext = args.input_ext
|
||||
|
@ -82,7 +81,7 @@ def make_image(img):
|
|||
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
||||
return im
|
||||
|
||||
def files_to_videoTensor(path, downscale=1.):
|
||||
def files_to_videoTensor(path):
|
||||
from PIL import Image
|
||||
global in_files
|
||||
in_files_fixed = in_files
|
||||
|
@ -93,22 +92,18 @@ def files_to_videoTensor(path, downscale=1.):
|
|||
videoTensor = torch.stack(images)
|
||||
return videoTensor
|
||||
|
||||
def video_transform(videoTensor, downscale=1):
|
||||
def video_transform(videoTensor):
|
||||
T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
|
||||
downscale = int(downscale * 8)
|
||||
resizes = 8*(H//downscale), 8*(W//downscale)
|
||||
transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
|
||||
transforms = torchvision.transforms.Compose([ToTensorVideo()])
|
||||
videoTensor = transforms(videoTensor)
|
||||
|
||||
print("Resizing to %dx%d"%(resizes[0], resizes[1]) )
|
||||
return videoTensor, resizes
|
||||
return videoTensor
|
||||
|
||||
videoTensor = files_to_videoTensor(interp_input_path, args.downscale)
|
||||
videoTensor = files_to_videoTensor(interp_input_path)
|
||||
|
||||
print(f"Video Tensor len: {len(videoTensor)}")
|
||||
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
|
||||
print(f"len(idxs): {len(idxs)}")
|
||||
videoTensor, resizes = video_transform(videoTensor, args.downscale)
|
||||
videoTensor = video_transform(videoTensor)
|
||||
print("Video tensor shape is ", videoTensor.shape)
|
||||
|
||||
frames = torch.unbind(videoTensor, 1)
|
||||
|
|
Loading…
Reference in New Issue