[feat] support frames packing for minicpmv4_5 video processing by fanqiNO1 · Pull Request #8046 · modelscope/ms-swift

import math
import json
import os
from copy import deepcopy
from decord import VideoReader, cpu

import numpy as np
import torch
from PIL import Image
from scipy.spatial import cKDTree
from transformers import AutoProcessor

from swift.model import get_processor
from swift.template import get_template


MAX_NUM_FRAMES = 180
MAX_NUM_PACKING = 3
TIME_SCALE = 0.1

video_path = "./test_video.mp4"
user_prompt = "Describe the video"
fps = 5
force_packing = None


def map_to_nearest_scale(values, scale):
    tree = cKDTree(np.asarray(scale)[:, None])
    _, indices = tree.query(np.asarray(values)[:, None])
    return np.asarray(scale)[indices]


def group_array(arr, size):
    return [arr[i:i+size] for i in range(0, len(arr), size)]


def encode_video(video_path, choose_fps=3, force_packing=None):
    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]
    vr = VideoReader(video_path, ctx=cpu(0))
    fps = vr.get_avg_fps()
    video_duration = len(vr) / fps
        
    if choose_fps * int(video_duration) <= MAX_NUM_FRAMES:
        packing_nums = 1
        choose_frames = round(min(choose_fps, round(fps)) * min(MAX_NUM_FRAMES, video_duration))
        
    else:
        packing_nums = math.ceil(video_duration * choose_fps / MAX_NUM_FRAMES)
        if packing_nums <= MAX_NUM_PACKING:
            choose_frames = round(video_duration * choose_fps)
        else:
            choose_frames = round(MAX_NUM_FRAMES * MAX_NUM_PACKING)
            packing_nums = MAX_NUM_PACKING

    frame_idx = [i for i in range(0, len(vr))]      
    frame_idx =  np.array(uniform_sample(frame_idx, choose_frames))

    if force_packing:
        packing_nums = min(force_packing, MAX_NUM_PACKING)
    
    print(video_path, ' duration:', video_duration)
    print(f'get video frames={len(frame_idx)}, packing_nums={packing_nums}')
    
    frames = vr.get_batch(frame_idx).asnumpy()

    frame_idx_ts = frame_idx / fps
    scale = np.arange(0, video_duration, TIME_SCALE)

    frame_ts_id = map_to_nearest_scale(frame_idx_ts, scale) / TIME_SCALE
    frame_ts_id = frame_ts_id.astype(np.int32)

    assert len(frames) == len(frame_ts_id)

    frames = [Image.fromarray(v.astype('uint8')).convert('RGB') for v in frames]
    frame_ts_id_group = group_array(frame_ts_id, packing_nums)
    
    return frames, frame_ts_id_group


def minicpmv4_5_official():
    processor = AutoProcessor.from_pretrained("OpenBMB/MiniCPM-V-4_5", trust_remote_code=True)

    frames, frame_ts_id_group = encode_video(video_path, fps, force_packing=force_packing)

    messages_list = [[{'role': 'user', 'content': frames + [user_prompt]}]]
    images_list = [None]

    prompts_lists = []
    input_images_lists = []

    for image, msgs in zip(images_list, messages_list):
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if isinstance(content, str):
                content = [content]
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")
                elif isinstance(c, str):
                    cur_msgs.append(c)
            msg["content"] = "\n".join(cur_msgs)

        prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False))
        input_images_lists.append(images)

    inputs = processor(
        prompts_lists,
        input_images_lists,
        max_slice_nums=1,
        use_image_id=False,
        temporal_ids=frame_ts_id_group,
        return_tensors="pt"
    )

    input_string = processor.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False)
    print("Official Decoded input string:", input_string[0])
    return inputs


def swift_template_test():

    os.environ["VIDEO_MAX_SLICE_NUMS"] = "1"
    os.environ["MAX_NUM_FRAMES"] = str(MAX_NUM_FRAMES)
    os.environ["MAX_NUM_PACKING"] = str(MAX_NUM_PACKING)
    os.environ["TIME_SCALE"] = str(TIME_SCALE)
    os.environ["CHOOSE_FPS"] = str(fps)

    processor = get_processor("OpenBMB/MiniCPM-V-4_5")
    template = get_template(processor, enable_thinking=False)

    inputs = {
        "messages": [
            {"role": "user", "content": f"<video>{user_prompt}"}
        ],
        "videos": [video_path]
    }

    inputs = template.encode(inputs)
    input_string = template.safe_decode(inputs["input_ids"])
    print("Swift Decoded input string:", input_string)
    return inputs


def is_equal(value1, value2):
    if isinstance(value1, list) and isinstance(value2, list):
        if len(value1) != len(value2):
            return False
        for v1, v2 in zip(value1, value2):
            if not is_equal(v1, v2):
                return False
        return True
    elif isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
        if value1.shape != value2.shape:
            print(f"Tensor shapes differ: {value1.shape} vs {value2.shape}")
            return False
        if not torch.equal(value1, value2):
            print(f"Tensor values differ at some positions.")
            return False
        return True
    else:
        return value1 == value2


def main():
    official_inputs = minicpmv4_5_official()
    swift_inputs = swift_template_test()

    print("Official inputs keys:", list(official_inputs.keys()))
    print("Swift inputs keys:", list(swift_inputs.keys()))

    for key in swift_inputs.keys():
        assert key in official_inputs, f"Key '{key}' not found in official inputs"
        print(f"Comparing key: {key}")

        if key == "input_ids":
            official_value = official_inputs[key][0].tolist()
        elif key == "pixel_values":
            official_value = official_inputs[key]
            for i in range(len(official_value[0])):
                official_value[0][i] = official_value[0][i].to(torch.bfloat16)
        else:
            official_value = official_inputs[key]
        swift_value = swift_inputs[key]

        assert isinstance(official_value, list)
        assert isinstance(swift_value, list)
        assert len(official_value) == len(swift_value), f"len(official[{key}])={len(official_value)} vs len(swift[{key}])={len(swift_value)})"

        for i, (o, s) in enumerate(zip(official_value, swift_value)):
            if not is_equal(o, s):
                print(f"❌ Difference found in key '{key}' at index {i}")
                break
                # raise AssertionError(f"Values for key '{key}' at index {i} do not match.")
        else:
            print(f"✔️ Values match for key '{key}'.")

if __name__ == "__main__":
    main()