Frame-by-frame video processing in a Jupyter Notebook with PyTorch
By Robert Russell
- 5 minutes read - 987 words
Today’s goal is just to load a video, display individual frames in the output from a Jupyter notebook cell, and write the video back out to a new file. In the middle I’ll do a little processing on the video frames. The processing is beside the point today - I just want to make the input, interaction, and output work really well so that later I can focus more on that processing step in the middle. Think of this as laying the groundwork for future experiments.
The torchvision.io library has warnings about being replaced by torchcodec, but I chose to stick with torchvision for this experiment. I want to write out my output as a video file and torchcodec doesn’t currently support that, as far as I can tell.
Load all frames from the video
import torch
from torchvision.io import read_video, write_video

# vid2 is the path to the input video file.
frame_tensors, _, _ = read_video(vid2, output_format="TCHW")
frame_count = len(frame_tensors)

def get_frame_tensor(frame_number: int) -> torch.Tensor:
    return frame_tensors[frame_number]
The two ignored return values from torchvision.io.read_video() are audio samples and video metadata. In my notebooks I wrap global variables with an accessor. The get_frame_tensor() function doesn’t hide the variable, but using the accessor consistently makes it easier for me to change the way frames are loaded later on. That migration to torchcodec, for example, should be a little easier. We’ll also see that the accessor lets me use the global as a cache or lazy-load expensive data.
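As a taste of that, here’s a minimal sketch (not from the original notebook) of how the same accessor could lazy-load the frames instead of reading the whole video up front. It assumes the vid2 path and imports from the loading cell above; the _frame_tensors name is just an illustration.

_frame_tensors = None

def get_frame_tensor(frame_number: int) -> torch.Tensor:
    # Read the video only on the first call, then keep reusing the result.
    global _frame_tensors
    if _frame_tensors is None:
        _frame_tensors, _, _ = read_video(vid2, output_format="TCHW")
    return _frame_tensors[frame_number]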
Show the frames in the notebook.
There are a lot of different approaches for showing an image in a notebook. If you’re already relying on matplotlib then you might show it as an image attached to an axis. Those are hard to animate in my experience. I want to use Jupyter widgets to build up a minimal interactive UI. There is a video player widget but that seems to wrap an html5 video element. I don’t just want to play the video though - I want to be able to perform operations on frames and display those derived frames along with the original. Instead I’ll combine a few controls so that I can scrub through frames and work with the data in more depth.
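For reference, the matplotlib route I’m passing over looks roughly like this; it assumes the frame_tensors loaded above and just draws a single frame on an axis.

import matplotlib.pyplot as plt

# imshow expects HWC, so permute the CHW frame before drawing it.
fig, ax = plt.subplots()
ax.imshow(get_frame_tensor(0).permute(1, 2, 0))
ax.set_axis_off()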
A custom display is straightforward by combining an Image widget and a way to select the frame number, like an IntSlider. Adding an animation widget also allows playing frames in sequence. I’ve included a “Save” button which delegates to a function I’ll show later, and the get_frame_image() helper used here is defined in the aside below.
%matplotlib widget
import ipywidgets as widgets

def update_frame(change):
    # Show whichever frame the slider now points at.
    frame_image_widget.set_trait('value', get_frame_image(change.new))

def save_button_click(event):
    # Disable the button while saving so a second click can't overlap.
    save_button.disabled = True
    try:
        save_all(get_frame_tensor)
    finally:
        save_button.disabled = False

frame_slider = widgets.IntSlider(
    description='Frame:',
    value=0,
    min=0,
    max=frame_count-1)
frame_slider.observe(update_frame, names='value')
frame_image_widget = widgets.Image(value=get_frame_image(frame_slider.value))
play = widgets.Play(
    interval=30,
    value=1,
    min=0,
    max=frame_count-1,
    step=1)
widgets.jslink((play, 'value'), (frame_slider, 'value'))
save_button = widgets.Button(description="Save")
save_button.on_click(save_button_click)
control_box = widgets.VBox([frame_slider, play, save_button])
widgets.AppLayout(
    center=frame_image_widget,
    footer=control_box)
Aside - making images out of tensors
The images loaded from the video are stored as a tensor. The Image widget wraps an html img tag and will only show images that are supported in browsers, so we need to convert to PNG, JPEG, or something like that. There’s a popular convoluted snippet of code for doing that with Pillow that I’ve isolated into this function.
import torchvision.transforms.functional as F
from PIL.Image import Image as PILImage
from io import BytesIO

def tensor_to_image(tensor_frame: torch.Tensor, format: str = 'PNG', reduction_factor: int = 4) -> bytes:
    # Shrink the frame so the encoded image stays small, then encode it in memory.
    height_width = tensor_frame.size()[-2:]
    reduced_height_width = [i // reduction_factor for i in height_width]
    pil_image: PILImage = F.to_pil_image(F.resize(tensor_frame, reduced_height_width))
    with BytesIO() as image_binary:
        pil_image.save(image_binary, format)
        return image_binary.getvalue()
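To sanity-check the conversion, here’s a quick snippet (not in the original post) that renders one converted frame directly in the notebook output using IPython’s display utilities.

from IPython.display import Image as IPyImage, display

# Render the first converted frame straight into the notebook output.
display(IPyImage(data=tensor_to_image(get_frame_tensor(0))))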
And to avoid repeating the conversion from tensor to image, I spend some more memory to cache the results at the locations where I call tensor_to_image().
frame_images: dict[int, bytes] = {}

def get_frame_tensor(frame_number: int) -> torch.Tensor:
    return frame_tensors[frame_number]

def get_frame_image(frame_number: int) -> bytes:
    # Convert the frame only once, then serve the cached image bytes.
    if frame_number not in frame_images:
        frame_images[frame_number] = tensor_to_image(get_frame_tensor(frame_number))
    return frame_images[frame_number]
Processing
At this point you can apply some processing to the image. Add a second set of controls to show the processed image. Link the frame selection between the original and the processed views. And, if the processing takes more than a fraction of a second, cache the result.
Today’s example just calls torchvision.transforms.functional.invert() with the result cached in a dictionary, derived_frame_images.
derived_frame_tensors: dict[int, torch.Tensor] = {}
derived_frame_images: dict[int, bytes] = {}

def get_derived_frame_tensor(frame_number: int) -> torch.Tensor:
    # The "processing" step: invert the frame, computed at most once per frame.
    if frame_number not in derived_frame_tensors:
        derived_frame_tensors[frame_number] = F.invert(get_frame_tensor(frame_number))
    return derived_frame_tensors[frame_number]

def get_derived_frame_image(frame_number: int) -> bytes:
    if frame_number not in derived_frame_images:
        derived_frame_images[frame_number] = tensor_to_image(
            get_derived_frame_tensor(frame_number))
    return derived_frame_images[frame_number]
To show the original and the result side-by-side I’ve just added a second Image widget whose value is updated at the same time as the original. The AppLayout puts the two images in the left and right sidebars rather than just one in the center.
%matplotlib widget
def update_frame(change):
    # Update both the original and the processed view for the selected frame.
    frame_image_widget.set_trait('value', get_frame_image(change.new))
    derived_frame_widget.set_trait('value', get_derived_frame_image(change.new))

def save_button_click(event):
    # Disable the button while saving so a second click can't overlap.
    save_button.disabled = True
    try:
        save_all(get_derived_frame_tensor)
    finally:
        save_button.disabled = False

frame_slider = widgets.IntSlider(
    description='Frame:',
    value=0,
    min=0,
    max=frame_count-1)
frame_slider.observe(update_frame, names='value')
frame_image_widget = widgets.Image(value=get_frame_image(frame_slider.value))
derived_frame_widget = widgets.Image(value=get_derived_frame_image(frame_slider.value))
play = widgets.Play(
    interval=30,
    value=1,
    min=0,
    max=frame_count-1,
    step=1)
widgets.jslink((play, 'value'), (frame_slider, 'value'))
save_button = widgets.Button(description="Save")
save_button.on_click(save_button_click)
control_box = widgets.VBox([frame_slider, play, save_button])
widgets.AppLayout(
    left_sidebar=frame_image_widget,
    right_sidebar=derived_frame_widget,
    footer=control_box,
    pane_widths=[1, 0, 1]
)
Save the result
Finally, to share the results we need a way to save the output as video. As I mentioned earlier, the “Save” button click handler is hooked up to call save_all(). In order to write out the video we want to call torchvision.io.write_video(). Note that it needs a tensor in THWC format, but when we loaded with read_video() the tensors were created in TCHW order. These shorthands give the order of the dimensions in the tensor: Time (i.e. frame count), Height, Width, and Channels (i.e. colours or similar). Each frame is a tensor with C, H, and W dimensions. Stacking the frames gives us the time or frame count dimension. All that’s left is to reorder the dimensions of the resulting tensor with permute().
def tchw_to_thwc(tensor_tchw: torch.Tensor) -> torch.Tensor:
    # Move the channel dimension last: (T, C, H, W) -> (T, H, W, C).
    return tensor_tchw.permute(0, 2, 3, 1)

def save_video_frames(filename: str, tensor_tchw: torch.Tensor):
    write_video(filename, tchw_to_thwc(tensor_tchw), 30)

def save_all(get_tensor):
    # get_tensor is one of the frame accessors, e.g. get_derived_frame_tensor.
    all_frames = torch.stack([get_tensor(f) for f in range(frame_count)])
    save_video_frames('output-masked.mp4', all_frames)
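As a rough sanity check (not part of the original notebook), you can call save_all() directly with one of the accessors and read the file back to confirm its shape.

# Write the inverted frames out, then reload them to confirm the dimensions.
save_all(get_derived_frame_tensor)
check, _, _ = read_video('output-masked.mp4', output_format="TCHW")
print(check.shape)  # expect (frame_count, 3, height, width)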
The full notebook is available here, along with the sample video. Try it out and see what you think.