New detector options and better detector integration by n-poulsen · Pull Request #2676 · DeepLabCut/DeepLabCut

Expand Up @@ -128,6 +128,18 @@ def _generate_layout_attributes(self, layout): lambda _: self.set_edit_table_visibility() )
# Detector selection for top-down models self.detector_label = QtWidgets.QLabel("Detector architecture") self.detector_choice = QtWidgets.QComboBox() self.detector_choice.setMinimumWidth(200) self.update_detectors(engine=self.root.engine) self.root.engine_change.connect( lambda engine: self.update_detectors(engine=engine) ) self.net_choice.currentTextChanged.connect( lambda new_net_choice: self.update_detectors(net_choice=new_net_choice) )
# Overwrite selection self.overwrite = QtWidgets.QCheckBox("Overwrite if exists") self.overwrite.setChecked(False) Expand All @@ -153,8 +165,11 @@ def _generate_layout_attributes(self, layout): layout.addWidget(augmentation_label, 1, 2) layout.addWidget(self.aug_choice, 1, 3)
layout.addWidget(self.overwrite, 2, 0) layout.addWidget(self.data_split_selection, 3, 0) layout.addWidget(self.detector_label, 2, 0) layout.addWidget(self.detector_choice, 2, 1)
layout.addWidget(self.overwrite, 3, 0) layout.addWidget(self.data_split_selection, 4, 0)
def log_net_choice(self, net): self.root.logger.info(f"Network architecture set to {net.upper()}") Expand Down Expand Up @@ -216,17 +231,22 @@ def create_training_dataset(self): else: try: engine = self.root.engine net_type = self.net_choice.currentText() detector_type = None if engine == Engine.TF: import tensorflow # try importing TF so they can't create shuffles for it if they # don't have it installed elif engine == Engine.PYTORCH and "top_down" in net_type: detector_type = self.detector_choice.currentText()
if self.data_split_selection.selected: deeplabcut.create_training_dataset_from_existing_split( self.root.config, from_shuffle=self.data_split_selection.from_shuffle, shuffles=[self.shuffle.value()], net_type=self.net_choice.currentText(), net_type=net_type, detector_type=detector_type, userfeedback=not overwrite, weight_init=weight_init, engine=engine, Expand All @@ -237,7 +257,8 @@ def create_training_dataset(self): self.root.config, shuffle, Shuffles=[self.shuffle.value()], net_type=self.net_choice.currentText(), net_type=net_type, detector_type=detector_type, userfeedback=not overwrite, weight_init=weight_init, engine=engine, Expand All @@ -247,7 +268,8 @@ def create_training_dataset(self): self.root.config, shuffle, Shuffles=[self.shuffle.value()], net_type=self.net_choice.currentText(), net_type=net_type, detector_type=detector_type, augmenter_type=self.aug_choice.currentText(), userfeedback=not overwrite, weight_init=weight_init, Expand Down Expand Up @@ -391,6 +413,42 @@ def update_nets(self, engine: Engine | None) -> None: if default_net in nets: self.net_choice.setCurrentIndex(nets.index(default_net))
@Slot(Engine) def update_detectors( self, engine: Engine | None = None, net_choice: str | None = None, ) -> None: if engine is None: engine = self.root.engine
if engine == Engine.TF: detectors = [] else: # FIXME: Circular imports make it impossible to import this at the top from deeplabcut.pose_estimation_pytorch import available_detectors detectors = available_detectors() det_filter = self.get_detector_filter() if det_filter is not None: detectors = [d for d in detectors if d in det_filter]
while self.detector_choice.count() > 0: self.detector_choice.removeItem(0)
self.detector_choice.addItems(detectors) if "ssdlite" in detectors: self.detector_choice.setCurrentIndex(detectors.index("ssdlite"))
if net_choice is None: net_choice = self.net_choice.currentText()
if "top_down" in net_choice: self.detector_label.show() self.detector_choice.show() else: self.detector_label.hide() self.detector_choice.hide()
@Slot(Engine) def update_aug_methods(self, engine: Engine) -> None: methods = compat.get_available_aug_methods(engine) Expand Down Expand Up @@ -422,6 +480,17 @@ def get_net_filter(self) -> list[str] | None: weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] return weight_init_cfg["model_filter"]
def get_detector_filter(self) -> list[str] | None: """Returns: the detectors that can be used based on weight initialization""" if self.root.engine != Engine.PYTORCH: return None
if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS: return None
weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] return weight_init_cfg["detector_filter"]
def get_default_net(self) -> str | None: """Returns: the net type that can be used based on weight initialization""" if self.root.engine != Engine.PYTORCH: Expand Down Expand Up @@ -624,13 +693,15 @@ def _create_confirmation_box(title, description): _WEIGHT_INIT_OPTIONS = { # FIXME - Generate dynamically "Transfer Learning - ImageNet": { "model_filter": None, "detector_filter": None, }, "Transfer Learning - SuperAnimal Quadruped": { "default_net": "top_down_hrnet_w32", "model_filter": [ "dekr_w32", "hrnet_w32", ], "detector_filter": ["fasterrcnn_resnet50_fpn_v2"], "super_animal": "superanimal_quadruped", }, "Transfer Learning - SuperAnimal TopViewMouse": { Expand All @@ -639,16 +710,19 @@ def _create_confirmation_box(title, description): "dekr_w32", "hrnet_w32", ], "detector_filter": ["fasterrcnn_resnet50_fpn_v2"], "super_animal": "superanimal_topviewmouse", }, "Fine-tuning - SuperAnimal Quadruped": { "default_net": "top_down_hrnet_w32", "model_filter": ["hrnet_w32"], # FIXME - Add ResNet "detector_filter": ["fasterrcnn_resnet50_fpn_v2"], "super_animal": "superanimal_quadruped", }, "Fine-tuning - SuperAnimal TopViewMouse": { "default_net": "top_down_hrnet_w32", "model_filter": ["hrnet_w32"], # FIXME - Add ResNet "detector_filter": ["fasterrcnn_resnet50_fpn_v2"], "super_animal": "superanimal_topviewmouse", }, }