Masking video backgrounds with Apple DepthPro and PyTorch
By Robert Russell
- 6 minutes read - 1159 words

Last year I wrote about loading and saving video in a Jupyter Notebook for frame-by-frame processing with PyTorch. Today I’d like to explain more of the actual image processing that motivated me. It started from tinkering with Apple’s Depth Pro model. I just wanted to see how it performed with some arbitrary video and maybe use it to separate the background from the foreground.
Today I’ll focus on just two main tasks that differ from the last notebook:
- Using the Apple DepthPro model
- Relevant UI controls
Using the Apple DepthPro model
DepthPro is a monocular depth model, meaning it infers depth for each pixel from a single still image. The repo includes example code which loads an image, runs it through the model, and then saves the output as a depth map. There’s also a paper I haven’t read yet: “Depth Pro: Sharp Monocular Metric Depth in Less Than a Second.”
The example code is fine for testing out a few still images. I’m interested in processing video though. And as I described in my previous post, I also really want an interface that encourages interactive exploration. A responsive environment to play with the code is just more fun. So I spent some time on interactive UI elements and, to keep things speedy, caching function results.
Throughout the rest of the post I’ll explain my notebook and highlight some code snippets.
About caching
The Depth Pro model produces results in a fraction of a second, so why cache? Because a cache lookup is just a memory access, and that is still much faster than reprocessing a frame, especially when scrubbing back and forth over the same frames. In my last notebook I wrote custom caching for each method that took more than a few milliseconds. By the time I finished, though, I recognized the pattern and guessed there might be built-in support for memoizing function return values. And of course there is: the `@cache` decorator from the `functools` library does the trick quite nicely.
Restart the Jupyter kernel after editing the body of any function decorated with `@cache`. Otherwise, results cached before the edit can linger, particularly in other cached functions that call the edited one.
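A minimal sketch, with made-up functions `scale()` and `describe()`, of how a stale result can outlive an edit: re-running a cell gives the redefined function a fresh cache, but a cached caller keeps whatever it computed earlier. `cache_clear()`, which `functools` attaches to every cached function, is a lighter alternative to restarting the kernel.

```python
from functools import cache

@cache
def scale(x: int) -> int:
    return x * 2

@cache
def describe(x: int) -> str:
    # Looks up the global name `scale` each time the body actually runs.
    return f"value={scale(x)}"

print(describe(3))  # value=6

# "Edit" the helper, as re-running a notebook cell would.
@cache
def scale(x: int) -> int:
    return x * 10

print(describe(3))  # value=6: describe() still returns its cached result
describe.cache_clear()  # drop describe()'s stale entries
print(describe(3))  # value=30
```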
Load the model
One more aside before we really get into it. My notebook lives in a working copy of the git repo, but not in its root. The default `DepthProConfig` assumes the model is at `./checkpoints/depth_pro.pt`, a path resolved relative to the current working directory. So I needed a custom configuration, which I wrapped in `load_model()`:
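```python
import torch
from depth_pro import create_model_and_transforms
from depth_pro.depth_pro import DepthProConfig

def load_model():
    # get_torch_device() is defined elsewhere in the notebook.
    MONODEPTH_CONFIG_DICT = DepthProConfig(
        patch_encoder_preset="dinov2l16_384",
        image_encoder_preset="dinov2l16_384",
        # The repo default, resolved relative to the working directory:
        # checkpoint_uri="./checkpoints/depth_pro.pt",
        checkpoint_uri="/home/my-stuff/ml-depth-pro/checkpoints/depth_pro.pt",
        decoder_features=256,
        use_fov_head=True,
        fov_encoder_preset="dinov2l16_384",
    )
    model, transform = create_model_and_transforms(
        config=MONODEPTH_CONFIG_DICT,
        device=get_torch_device(),
        precision=torch.half,
    )
    model.eval()  # inference mode: no dropout, frozen batch-norm statistics
    return model, transform
```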
Another option would be to create a symlink so the default configuration finds the file where it’s expected.
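A minimal sketch of the symlink approach using `pathlib`. The helper name `link_checkpoints()` and the directory layout are assumptions; the idea is simply that `./checkpoints`, resolved from the notebook’s working directory, points at the repo’s real checkpoints folder. From a shell, `ln -s` does the same job.

```python
from pathlib import Path

def link_checkpoints(repo_root: Path, notebook_dir: Path) -> Path:
    """Hypothetical helper: make notebook_dir/checkpoints point at repo_root/checkpoints."""
    link = notebook_dir / "checkpoints"
    target = (repo_root / "checkpoints").resolve()
    # exists() follows links, so also check is_symlink() to catch a broken link.
    if not link.exists() and not link.is_symlink():
        link.symlink_to(target, target_is_directory=True)
    return link
```

Running it once from the notebook directory, for example `link_checkpoints(Path(".."), Path("."))` if the notebook sits one level below the repo root, should let the default configuration find `depth_pro.pt` unchanged.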
Masks
I want to show which pixels are in the background and which are in the foreground by applying a semitransparent mask to pixels beyond a chosen depth threshold. It turns out that this example from the PyTorch docs gives a clear and straightforward way to do exactly that.
The core functions are `draw_segmentation_masks()` from `torchvision.utils` and `softmax()` from `torch.nn.functional`.
```python
from functools import cache
import torch
from torchvision.utils import draw_segmentation_masks
# frame_tensors, get_depth_tensor() and tensor_to_image() come from other cells.

def get_frame_tensor(frame_number: int) -> torch.Tensor:
    return frame_tensors[frame_number]

@cache
def get_frame_image(frame_number: int) -> bytes:
    return tensor_to_image(get_frame_tensor(frame_number))

@cache
def get_mask_tensor_bool(frame_number: int, threshold: float = 0.0001) -> torch.Tensor:
    # softmax rescales each slice along dim=1 to sum to 1; larger depth -> larger score.
    return torch.nn.functional.softmax(get_depth_tensor(frame_number), dim=1) > threshold

@cache
def get_mask_tensor_chw(frame_number: int, threshold: float = 0.0001, alpha: float = 0.8) -> torch.Tensor:
    return draw_segmentation_masks(get_frame_tensor(frame_number), get_mask_tensor_bool(frame_number, threshold), alpha=alpha)

@cache
def get_mask_image(frame_number: int, threshold: float = 0.0001, alpha: float = 0.8) -> bytes:
    return tensor_to_image(get_mask_tensor_chw(frame_number, threshold, alpha))
```
UI controls
I went pretty deep on making some UI controls that work well for the repetitive pattern of tweaking values and just seeing what happens. The UI controls simply adjust parameter values for frame selection, mask threshold and alpha blending. UI events trigger functions which update the displayed images based on these parameters. In the end this lets me tweak values, scrub through frames, and get a feel for how things interact.
Since I want to see how well I can mask out the background, I started by adding a slider to set the depth limit. This will be used to choose which pixels are masked. A second slider lets me set an opacity or alpha value rather than completely blacking out the background.
```python
import ipywidgets as widgets

# frame_slider and get_mask_image() are defined in other cells.
depth_slider = widgets.FloatSlider(
    orientation='horizontal',
    description='Depth:',
    value=0.001,
    min=0.00,
    max=1.0,
    step=0.0000001,
    readout=True,
    readout_format='.4f',
)

alpha_slider = widgets.FloatSlider(
    orientation='horizontal',
    description='Alpha:',
    value=0.8,
    min=0.0,
    max=1.0,
    step=0.02
)

masked_image_widget = widgets.Image(
    value=get_mask_image(
        frame_slider.value, threshold=depth_slider.value, alpha=alpha_slider.value
    )
)

def update_depth(change):
    # change.new is the depth slider's new value
    masked_image_widget.set_trait(
        "value",
        get_mask_image(
            frame_slider.value, threshold=change.new, alpha=alpha_slider.value
        ),
    )

depth_slider.observe(update_depth, names='value')

def update_alpha(change):
    # change.new is the alpha slider's new value
    masked_image_widget.set_trait(
        "value",
        get_mask_image(
            frame_slider.value, threshold=depth_slider.value, alpha=change.new
        ),
    )

alpha_slider.observe(update_alpha, names='value')
```
As you can see, each slider’s `.observe()` method hooks up its update function, so `masked_image_widget` gets a new image whenever either parameter changes.
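The two update functions differ only in which value they take from `change.new`. Since traitlets stores the new value before it calls observers, a single handler could read every slider directly and serve both. A minimal sketch of that variation, using hypothetical stand-ins for the notebook’s sliders and a `fake_mask_image()` in place of `get_mask_image()`:

```python
import ipywidgets as widgets

# Hypothetical stand-ins for the notebook's widgets and get_mask_image().
frame_slider = widgets.IntSlider(value=0, min=0, max=10)
depth_slider = widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.1)
alpha_slider = widgets.FloatSlider(value=0.8, min=0.0, max=1.0, step=0.1)
masked_image_widget = widgets.Image()

def fake_mask_image(frame: int, threshold: float, alpha: float) -> bytes:
    return f"{frame}:{threshold}:{alpha}".encode()

def refresh_mask(change=None):
    # Read current values from all sliders, so one handler fits any of them.
    masked_image_widget.value = fake_mask_image(
        frame_slider.value, depth_slider.value, alpha_slider.value
    )

for slider in (depth_slider, alpha_slider):
    slider.observe(refresh_mask, names="value")
```

The trade-off is small: one function to maintain instead of two, at the cost of not being able to see at a glance which slider triggered the update.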
Here’s what this depth slider looks like:
[Screenshot: the depth slider widget]
It works fine in principle, but it’s hard to land on a value with four decimal places by dragging a slider. The widget does allow typing a value, but that defeats the purpose.
Enter the depth range selector.
This UI control sets the bounds of another UI control.
ipywidgets includes a range slider which lets the user easily pick both ends of a range. I want a quick way to set the minimum and maximum values of my depth slider, so I added a range selector and used its values to set the endpoints of the depth slider.
```python
import ipywidgets as widgets

# depth_range (a range slider) and depth_slider are defined in other cells.
# Each pair is (range value -> slider bound, slider bound -> range value).
transforms_lower = [
    lambda value, r=depth_range: value[0],
    lambda value, r=depth_range: [min(value, r.value[0]), r.value[1]],
]
transforms_upper = [
    lambda value, r=depth_range: value[1],
    lambda value, r=depth_range: [r.value[0], max(value, r.value[1])],
]

widgets.link((depth_range, "value"), (depth_slider, "min"), transform=transforms_lower)
widgets.link((depth_range, "value"), (depth_slider, "max"), transform=transforms_upper)
```
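A simpler alternative, swapped in here rather than taken from the notebook, is a one-way `dlink` from traitlets (ipywidgets re-exports it alongside `link`): the range drives the slider’s bounds and nothing flows back, so only a forward transform is needed. The slider settings below are placeholders.

```python
import ipywidgets as widgets

# Placeholder widgets standing in for the notebook's depth_range and depth_slider.
depth_range = widgets.FloatRangeSlider(value=(0.0, 1.0), min=0.0, max=1.0, step=0.0001)
depth_slider = widgets.FloatSlider(value=0.001, min=0.0, max=1.0, step=0.0000001)

# One-way links: each change to the range rewrites one bound of the slider.
widgets.dlink((depth_range, "value"), (depth_slider, "min"), transform=lambda v: v[0])
widgets.dlink((depth_range, "value"), (depth_slider, "max"), transform=lambda v: v[1])
```

One caveat worth checking with either approach: `FloatSlider` rejects a `min` above its current `max`, so moving the whole range past the slider’s old maximum in a single step may raise a `TraitError`.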
It would be logical to compute a good range from the depth values in the frame, but the range could vary from frame to frame. Finding a valid range across all frames would require processing every frame of the video. That’s a reasonable idea, but it could take a very long time, and I don’t want to assume the user always wants to wait for it. So a range slider is my quick solution for this sort of exploration.
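A cheaper middle ground, sketched here as an assumption rather than something the notebook does, is to compute the range for the current frame only. The hypothetical helper `frame_score_range()` applies the same `softmax(..., dim=1)` as `get_mask_tensor_bool()`, so its bounds are in threshold units rather than raw depth.

```python
import torch

def frame_score_range(depth: torch.Tensor, dim: int = 1) -> tuple[float, float]:
    """Hypothetical helper: smallest and largest softmax scores for one depth map."""
    # Same normalization as the mask code, so the result can bound the threshold.
    scores = torch.nn.functional.softmax(depth, dim=dim)
    return scores.min().item(), scores.max().item()
```

In the notebook this might be used as `depth_range.value = frame_score_range(get_depth_tensor(frame_slider.value))`, assuming the range slider’s own `min` and `max` span the returned values; it costs one frame’s worth of work instead of the whole video’s.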
Frame selection
Frames are selected by the animation widget, the same as before. But I’ve added a little text box to show how long it took to do the latest update. This info gets written out into a simple text widget on each frame update.
```python
import time

import ipywidgets as widgets

# frame_slider, the image widgets, and the get_*_image() helpers come from other cells.
message_view = widgets.Text(
    value='Ready',
    description='Status:',
    disabled=True
)

def update_frame(change):
    # Time the work two ways: CPU time for this process and wall-clock time.
    start_process_ns = time.process_time_ns()
    start_perf_ns = time.perf_counter_ns()
    mi = get_mask_image(
        change.new, threshold=depth_slider.value, alpha=alpha_slider.value
    )
    di = get_depth_image(change.new)
    fi = get_frame_image(change.new)
    end_process_ns = time.process_time_ns()
    end_perf_ns = time.perf_counter_ns()
    process_ms = (end_process_ns - start_process_ns) // 1_000_000
    perf_ms = (end_perf_ns - start_perf_ns) // 1_000_000
    message_view.set_trait(
        "value",
        f"Frame update {change.new} took {process_ms}ms (process) {perf_ms}ms (wall clock)",
    )
    # Update after getting potentially slow resources
    frame_image_widget.set_trait("value", fi)
    masked_image_widget.set_trait("value", mi)
    depth_image_widget.set_trait("value", di)

frame_slider.observe(update_frame, names="value")
```
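The two clocks in `update_frame()` measure different things: `time.process_time_ns()` counts CPU time used by the current process, summed over its threads and excluding sleep, while `time.perf_counter_ns()` is a wall-clock timer. A minimal sketch that wraps the same pattern in a hypothetical `timed()` context manager, with a sleep standing in for time spent waiting:

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(results: dict):
    """Hypothetical helper: record CPU and wall-clock milliseconds into `results`."""
    start_process = time.process_time_ns()
    start_perf = time.perf_counter_ns()
    try:
        yield results
    finally:
        results["process_ms"] = (time.process_time_ns() - start_process) // 1_000_000
        results["perf_ms"] = (time.perf_counter_ns() - start_perf) // 1_000_000

timing = {}
with timed(timing):
    time.sleep(0.05)  # waiting uses wall-clock time but almost no CPU
print(timing)
```

Multithreaded CPU work can push the process number above the wall-clock number, while waiting on I/O, and plausibly on a GPU, tends to show up mostly in wall-clock time.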
The full notebook is available here, along with the sample video. You’ll also need to pull the Apple DepthPro repo in order to use it. Try it out and see what you think.