Fix: correct the early-return error when save_epochs=1 and streamline the snapshot cleanup logic by xiu-cs · Pull Request #2793 · DeepLabCut/DeepLabCut

Expand Up @@ -113,6 +113,8 @@ def update(self, epoch: int, state_dict: dict, last: bool = False) -> None: ): current_best = self.best() self._best_metric = metrics[self._key]
# Save the new best model save_path = self.snapshot_path(epoch, best=True) parsed_state_dict = { k: v Expand All @@ -121,33 +123,31 @@ def update(self, epoch: int, state_dict: dict, last: bool = False) -> None: } torch.save(parsed_state_dict, save_path)
# Handle previous best model if current_best is not None: # rename if the current best should have been saved, otherwise delete if current_best.epochs % self.save_epochs == 0: new_name = self.snapshot_path(epoch=current_best.epochs) current_best.path.rename(new_name) else: current_best.path.unlink(missing_ok=False) return
if not (last or epoch % self.save_epochs == 0): return elif last or epoch % self.save_epochs == 0: # Save regular snapshot if needed save_path = self.snapshot_path(epoch=epoch) parsed_state_dict = { k: v for k, v in state_dict.items() if self.save_optimizer_state or k != "optimizer" } torch.save(parsed_state_dict, save_path)
# Clean up old snapshots if needed existing_snapshots = [s for s in self.snapshots() if not s.best] if len(existing_snapshots) >= self.max_snapshots: num_to_delete = 1 + len(existing_snapshots) - self.max_snapshots num_to_delete = len(existing_snapshots) - self.max_snapshots to_delete = existing_snapshots[:num_to_delete] for snapshot in to_delete: snapshot.path.unlink(missing_ok=False)
save_path = self.snapshot_path(epoch=epoch) parsed_state_dict = { k: v for k, v in state_dict.items() if self.save_optimizer_state or k != "optimizer" } torch.save(parsed_state_dict, save_path)
def best(self) -> Snapshot | None: """Returns: the path to the best snapshot, if it exists""" snapshots = self.snapshots() Expand Down