batched inference by n-poulsen · Pull Request #2708 · DeepLabCut/DeepLabCut
Expand Up
@@ -118,12 +118,11 @@ def video_inference(
if detector_runner is None:
raise ValueError("Must use a detector for top-down video analysis")
print("Running Detector") print(f"Running detector with batch size {detector_runner.batch_size}") bbox_predictions = detector_runner.inference(images=tqdm(video))
video.set_context(bbox_predictions)
print("Running Pose Prediction") print(f"Running pose prediction with batch size {pose_runner.batch_size}") predictions = pose_runner.inference(images=tqdm(video))
if with_identity: Expand Down Expand Up @@ -161,7 +160,8 @@ def analyze_videos( detector_snapshot_index: int | str | None = None, device: str | None = None, destfolder: str | None = None, batchsize: int | None = None, batch_size: int | None = None, detector_batch_size: int | None = None, modelprefix: str = "", transform: A.Compose | None = None, auto_track: bool | None = True, Expand All @@ -171,7 +171,6 @@ def analyze_videos( """Makes prediction based on a trained network.
# TODO: - allow batch size greater than 1 - other options missing options such as shelve - pass detector path or detector runner
Expand Down Expand Up @@ -206,8 +205,10 @@ def analyze_videos( snapshot to use, used in the same way as ``snapshot_index`` modelprefix: directory containing the deeplabcut models to use when evaluating the network. By default, they are assumed to exist in the project folder. batchsize: the batch size to use for inference. Takes the value from the PyTorch config as a default batch_size: the batch size to use for inference. Takes the value from the project config as a default. detector_batch_size: the batch size to use for detector inference. Takes the value from the project config as a default. transform: Optional custom transforms to apply to the video overwrite: Overwrite any existing videos auto_track: By default, tracking and stitching are automatically performed, Expand Down Expand Up @@ -261,6 +262,9 @@ def analyze_videos( if device is not None: model_cfg["device"] = device
if batch_size is None: batch_size = cfg["batch_size"]
snapshot = get_model_snapshots(snapshot_index, train_folder, pose_task)[0] print(f"Analyzing videos with {snapshot.path}") detector_path, detector_snapshot = None, None Expand All @@ -272,6 +276,9 @@ def analyze_videos( "project's configuration file." )
if detector_batch_size is None: detector_batch_size = cfg.get("detector_batch_size", 1)
detector_snapshot = get_model_snapshots( detector_snapshot_index, train_folder, Task.DETECT )[0] Expand All @@ -291,8 +298,10 @@ def analyze_videos( max_individuals=max_num_animals, num_bodyparts=len(bodyparts), num_unique_bodyparts=len(unique_bodyparts), batch_size=batch_size, with_identity=with_identity, transform=transform, detector_batch_size=detector_batch_size, detector_path=detector_path, detector_transform=None, ) Expand Down Expand Up @@ -326,7 +335,7 @@ def analyze_videos( pytorch_config=model_cfg, dlc_scorer=dlc_scorer, train_fraction=train_fraction, batch_size=batchsize, batch_size=batch_size, runtime=(runtime[0], runtime[1]), video=VideoReader(str(video)), ) Expand Down
print("Running Detector") print(f"Running detector with batch size {detector_runner.batch_size}") bbox_predictions = detector_runner.inference(images=tqdm(video))
video.set_context(bbox_predictions)
print("Running Pose Prediction") print(f"Running pose prediction with batch size {pose_runner.batch_size}") predictions = pose_runner.inference(images=tqdm(video))
if with_identity: Expand Down Expand Up @@ -161,7 +160,8 @@ def analyze_videos( detector_snapshot_index: int | str | None = None, device: str | None = None, destfolder: str | None = None, batchsize: int | None = None, batch_size: int | None = None, detector_batch_size: int | None = None, modelprefix: str = "", transform: A.Compose | None = None, auto_track: bool | None = True, Expand All @@ -171,7 +171,6 @@ def analyze_videos( """Makes prediction based on a trained network.
# TODO: - allow batch size greater than 1 - other options missing options such as shelve - pass detector path or detector runner
Expand Down Expand Up @@ -206,8 +205,10 @@ def analyze_videos( snapshot to use, used in the same way as ``snapshot_index`` modelprefix: directory containing the deeplabcut models to use when evaluating the network. By default, they are assumed to exist in the project folder. batchsize: the batch size to use for inference. Takes the value from the PyTorch config as a default batch_size: the batch size to use for inference. Takes the value from the project config as a default. detector_batch_size: the batch size to use for detector inference. Takes the value from the project config as a default. transform: Optional custom transforms to apply to the video overwrite: Overwrite any existing videos auto_track: By default, tracking and stitching are automatically performed, Expand Down Expand Up @@ -261,6 +262,9 @@ def analyze_videos( if device is not None: model_cfg["device"] = device
if batch_size is None: batch_size = cfg["batch_size"]
snapshot = get_model_snapshots(snapshot_index, train_folder, pose_task)[0] print(f"Analyzing videos with {snapshot.path}") detector_path, detector_snapshot = None, None Expand All @@ -272,6 +276,9 @@ def analyze_videos( "project's configuration file." )
if detector_batch_size is None: detector_batch_size = cfg.get("detector_batch_size", 1)
detector_snapshot = get_model_snapshots( detector_snapshot_index, train_folder, Task.DETECT )[0] Expand All @@ -291,8 +298,10 @@ def analyze_videos( max_individuals=max_num_animals, num_bodyparts=len(bodyparts), num_unique_bodyparts=len(unique_bodyparts), batch_size=batch_size, with_identity=with_identity, transform=transform, detector_batch_size=detector_batch_size, detector_path=detector_path, detector_transform=None, ) Expand Down Expand Up @@ -326,7 +335,7 @@ def analyze_videos( pytorch_config=model_cfg, dlc_scorer=dlc_scorer, train_fraction=train_fraction, batch_size=batchsize, batch_size=batch_size, runtime=(runtime[0], runtime[1]), video=VideoReader(str(video)), ) Expand Down