return to posts

Video Frame Extraction Using Decord

Sep 14, 2024

Optimized Video Frame Extraction with the VideoFrameExtractor Class

The VideoFrameExtractor class, built on top of Decord, is a powerful tool for optimized video frame extraction. Whether you're processing videos for machine learning, video analysis, or creative projects, this class makes it easy to extract frames with precision and efficiency.


Key Features

1. Flexible Sampling

  • Extract frames using a specified stride or based on frames per second (FPS).

2. Batch Processing

  • Retrieve frames in batches to reduce file access overhead and improve performance.

3. Custom Time Ranges

  • Focus on specific parts of the video by specifying start and end times.

4. Output Formats

  • Choose between Pillow images or NumPy arrays for your extracted frames.

Getting Started

Install Dependencies

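Install the required Decord library using pip (the class also imports Pillow, which is needed for the default "pillow" output format; add it with pip install pillow if it is not already installed):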

pip install decord

Code Implementation

import decord
from PIL import Image
import numpy as np
import datetime

class VideoFrameExtractor:
    """
    Iterator that extracts frames from a video in batches using Decord.

    Frames are sampled either at a fixed frame stride ("sample_by_stride") or at an approximate
    number of frames per second ("sample_by_fps"), optionally restricted to a time window and
    capped at a maximum number of frames.

    Every frame index to extract is precomputed in ``__init__``; each ``next()`` call then decodes
    one batch with a single ``get_batch`` call. The object is a one-shot iterator: ``frame_indices``
    is consumed as batches are returned, so a second pass requires a new instance.
    """

    def __init__(
        self,
        video_path,
        num_frames=-1,
        stride=1,
        sample_by_fps=1,
        sampling_mode="sample_by_fps",
        batch_size=1,
        output_format="pillow",
        start_time="00:00:00",
        end_time=None,
        ctx=None,
    ):
        """
        Initializes the VideoFrameExtractor with the specified parameters.

        Args:
            video_path (str): Path to the video file.
            num_frames (int, optional): Maximum number of frames to extract. Defaults to -1 (all frames).
            stride (int, optional): Frame stride used when sampling_mode is not 'sample_by_fps'.
                Defaults to 1.
            sample_by_fps (float, optional): Frames to sample per second when sampling_mode is
                'sample_by_fps'. A falsy value (None or 0) falls back to `stride`. Defaults to 1.
            sampling_mode (str, optional): Either 'sample_by_stride' or 'sample_by_fps'.
                Defaults to 'sample_by_fps'.
            batch_size (int, optional): Number of frames to return per batch. Defaults to 1.
            output_format (str, optional): 'pillow' returns PIL images; any other value returns
                NumPy arrays. Defaults to 'pillow'.
            start_time (str, optional): Extraction start time (HH:MM:SS). None or "" means the
                beginning of the video. Defaults to "00:00:00".
            end_time (str, optional): Extraction end time (HH:MM:SS). If None, process until the end.
                Times past the end of the video are clamped to the last frame. Defaults to None.
            ctx (optional): Decord context (CPU or GPU). Defaults to None, which means decord.cpu().

        Raises:
            ValueError: If batch_size is less than 1, if the effective stride is less than 1,
                or if a time string is not in HH:MM:SS format.
        """
        # A batch size below 1 would make __next__ return empty batches forever,
        # because nothing would ever be consumed from frame_indices.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Resolve the default context here rather than in the signature, so the
        # context is not created once at class-definition time.
        if ctx is None:
            ctx = decord.cpu()

        self.video_path = video_path
        self.decord_vr = decord.VideoReader(video_path, ctx=ctx)
        self.num_frames = num_frames
        self.stride = stride
        self.sample_by_fps = sample_by_fps
        self.sampling_mode = sampling_mode
        self.batch_size = batch_size
        self.output_format = output_format
        self.fps = self.decord_vr.get_avg_fps()
        self.frame_count = len(self.decord_vr)
        self.shape = self.decord_vr[0].shape
        self.total_frames_output = 0  # Tracks total frames extracted

        # Convert start and end times to frame indices (clamped to the video length)
        self.start_frame, self.end_frame = self._resolve_frame_range(start_time, end_time)

        # Set the initial frame index to the start frame
        self.current_index = self.start_frame

        # In "sample_by_fps" mode the stride is derived from the video FPS
        self.stride = self._resolve_stride()

        # Precompute the frame indices to extract
        self.frame_indices = self._build_frame_indices()

    def _resolve_frame_range(self, start_time, end_time):
        """
        Converts the requested time window to a (start_frame, end_frame) pair.

        Args:
            start_time (str or None): Start time "HH:MM:SS"; falsy means frame 0.
            end_time (str or None): End time "HH:MM:SS"; falsy means the end of the video.

        Returns:
            tuple: (start_frame, end_frame), both clamped to at most `self.frame_count`.
        """
        # _time_to_frame maps a falsy value to frame_count, which is right for an
        # open-ended end time but would yield no frames if used as the start.
        start_frame = self._time_to_frame(start_time) if start_time else 0
        end_frame = self._time_to_frame(end_time) if end_time else self.frame_count

        # Clamp so a time past the end of the video cannot produce indices that
        # do not exist in the reader.
        return min(start_frame, self.frame_count), min(end_frame, self.frame_count)

    def _resolve_stride(self):
        """
        Returns the frame stride actually used for sampling.

        Returns:
            int: The FPS-derived stride in 'sample_by_fps' mode (when `sample_by_fps` is truthy),
                otherwise `self.stride`.

        Raises:
            ValueError: If the resulting stride is less than 1.
        """
        if self.sampling_mode == "sample_by_fps" and self.sample_by_fps:
            # Round instead of truncating: e.g. 23.976 fps with 12 samples/s gives 1.998,
            # which should be a stride of 2, not 1 (which would double the sampling rate).
            return max(1, round(self.fps / self.sample_by_fps))
        if self.stride < 1:
            raise ValueError(f"stride must be >= 1, got {self.stride}")
        return self.stride

    def _build_frame_indices(self):
        """
        Builds the list of frame indices to extract from the resolved range and stride.

        Returns:
            list: Frame indices from `start_frame` (inclusive) to `end_frame` (exclusive) in steps
                of `stride`, truncated to `num_frames` entries unless `num_frames` is -1.
        """
        indices = list(range(self.start_frame, self.end_frame, self.stride))
        if self.num_frames != -1:
            indices = indices[:self.num_frames]
        return indices

    def _time_to_frame(self, time_str):
        """
        Converts a time string in 'HH:MM:SS' format to the corresponding frame index.

        Args:
            time_str (str): Time string in the format "HH:MM:SS". A falsy value maps to
                the total frame count (i.e. the end of the video).

        Returns:
            int: Corresponding frame index for the given time (not clamped to the video length).

        Raises:
            ValueError: If time_str does not match "%H:%M:%S".
        """
        if not time_str:
            return self.frame_count
        parsed = datetime.datetime.strptime(time_str, "%H:%M:%S")
        seconds = parsed.hour * 3600 + parsed.minute * 60 + parsed.second
        return int(seconds * self.fps)

    def _frame_to_time(self, frame_index):
        """
        Converts a frame index to a time string in 'HH:MM:SS:ms' format.

        Args:
            frame_index (int): The index of the frame.

        Returns:
            str: The time corresponding to the given frame in "HH:MM:SS:ms" format.
        """
        total_seconds = frame_index / self.fps
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = int(total_seconds % 60)
        milliseconds = int((total_seconds * 1000) % 1000)
        return f"{hours:02}:{minutes:02}:{seconds:02}:{milliseconds:03}"

    def __iter__(self):
        """
        Returns the iterator object itself.

        Returns:
            VideoFrameExtractor: The iterator object.
        """
        return self

    def __next__(self):
        """
        Retrieves the next batch of frames, their timestamps, and frame indices.

        Returns:
            tuple: A tuple containing:
                - frames (list): A list of frames in the specified format (Pillow or NumPy).
                - timestamps (list): Corresponding timestamps for the frames.
                - frame_idxs (list): Corresponding frame indices for the frames.

        Raises:
            StopIteration: When no frame indices remain or the specified number of frames
                has been returned.
        """
        limit_reached = self.num_frames != -1 and self.total_frames_output >= self.num_frames
        if not self.frame_indices or limit_reached:
            raise StopIteration

        # Take the next batch off the front of the pending indices
        batch_indices = self.frame_indices[:self.batch_size]
        self.frame_indices = self.frame_indices[self.batch_size:]

        # Fetch the whole batch from Decord in one call
        frames_batch = self.decord_vr.get_batch(batch_indices).asnumpy()

        # Convert frames to the desired output format
        if self.output_format == "pillow":
            frames = [Image.fromarray(frame) for frame in frames_batch]
        else:
            frames = list(frames_batch)

        timestamps = [self._frame_to_time(frame_idx) for frame_idx in batch_indices]
        self.total_frames_output += len(batch_indices)

        return (frames, timestamps, batch_indices)

    def get_video_info(self):
        """
        Returns basic information about the video.

        Returns:
            dict: A dictionary containing:
                - 'fps': Frames per second of the video.
                - 'frame_count': Total number of frames in the video.
                - 'duration': Duration of the video in seconds.
        """
        return {
            "fps": self.fps,
            "frame_count": self.frame_count,
            "duration": self.frame_count / self.fps,
        }

How to Use the VideoFrameExtractor

Example: Extract 10 Frames in a Specific Range

sampler = VideoFrameExtractor(
    video_path=r"c:\path\to\video.mp4",
    num_frames=10,  # Extract at most 10 frames in total
    stride=25,  # Ignored here: in "sample_by_fps" mode the stride is derived from the video FPS
    sample_by_fps=4,  # Sample 4 frames per second
    sampling_mode="sample_by_fps",  # Sample by frames per second rather than by stride
    batch_size=4,  # Each next() call returns up to 4 frames
    output_format="pillow",
    start_time="00:00:00",  # Start at 0 seconds
    end_time="00:02:00",  # End at 2 minutes
)

# Get video information
video_info = sampler.get_video_info()
print("Video Info:", video_info)

# Fetch the first batch only; loop over `sampler` to receive every batch
frames, timestamps, frame_idxes = next(sampler)
print("Timestamps:", timestamps)
print("Frame Indices:", frame_idxes)

Output

  • Frames: A list of frames in the specified format (Pillow images or NumPy arrays).
  • Timestamps: Corresponding timestamps in HH:MM:SS:ms format.
  • Frame Indices: Indices of the extracted frames.