improved handling of animal names when analyzing videos by n-poulsen · Pull Request #2884 · DeepLabCut/DeepLabCut

Expand Up @@ -8,6 +8,8 @@ # # Licensed under GNU Lesser General Public License v3.0 # from typing import List, Optional
import matplotlib.pyplot as plt import networkx as nx import numpy as np Expand Down Expand Up @@ -891,8 +893,11 @@ def concatenate_data(self):
def format_df(self, animal_names=None): data = self.concatenate_data() if not animal_names or len(animal_names) != self.n_tracks: if not animal_names or len(animal_names) < self.n_tracks: animal_names = [f"ind{i}" for i in range(1, self.n_tracks + 1)] elif len(animal_names) > self.n_tracks: animal_names = animal_names[:self.n_tracks]
coords = ["x", "y", "likelihood"] n_multi_bpts = data.shape[1] // (len(animal_names) * len(coords)) n_unique_bpts = 0 if self.single is None else self.single.data.shape[1] Expand Down Expand Up @@ -1031,6 +1036,7 @@ def stitch_tracklets( shuffle=1, trainingsetindex=0, n_tracks=None, animal_names: Optional[List[str]] = None, min_length=10, split_tracklets=True, prestitch_residuals=True, Expand Down Expand Up @@ -1071,6 +1077,13 @@ def stitch_tracklets( passed if the number of animals in the video is different from the number of animals the model was trained on.
animal_names: list, optional If you want the names given to individuals in the labeled data file, you can specify those names as a list here. If given and `n_tracks` is None, `n_tracks` will be set to `len(animal_names)`. If `n_tracks` is not None, then it must be equal to `len(animal_names)`. If it is not given, then `animal_names` will be loaded from the `individuals` in the project config.yaml file.
min_length : int, optional Tracklets less than `min_length` frames of length are considered to be residuals; i.e., they do not participate Expand Down Expand Up @@ -1107,8 +1120,8 @@ def stitch_tracklets( tracklets should be stitched together, the lower the returned value.
destfolder: string, optional Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this folder also needs to be passed. Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this folder also needs to be passed.
track_method: string, optional Specifies the tracker used to generate the pose estimation data. Expand All @@ -1135,7 +1148,14 @@ def stitch_tracklets( cfg = auxiliaryfunctions.read_config(config_path) track_method = auxfun_multianimal.get_track_method(cfg, track_method=track_method)
animal_names = cfg["individuals"] if animal_names is None: animal_names = cfg["individuals"] elif n_tracks is not None and n_tracks != len(animal_names): raise ValueError( "When setting both `n_tracks` and `animal_names`, `n_tracks` must be equal " f"to len(animal_names)`. Found `n_tracks`={n_tracks} and `animal_names`=" f"{animal_names} of length {len(animal_names)}.`")
if n_tracks is None: n_tracks = len(animal_names)
Expand Down