Improved loading of snapshot weights with` torch.load(..., weights_only=True)` by MMathisLab · Pull Request #2823 · DeepLabCut/DeepLabCut

This pull request offers Improved handling of loading snapshot weights with torch.load(..., weights_only=True). PyTorch snapshots saved with older release candidates could contain some numpy floats, which failed to load with weights_only=True, which can make it annoying to use them as the pytorch_config.yaml needed to be modified with load_weights_only: true keys for both detectors and pose models. In this pull request: the following improvements are made:

Fix the issue with numpy>=1.25

For users with numpy>=1.25 installed, the issue is fixed as the float64 class causing issues can be added to the safe_globals, as done in _add_numpy_to_torch_safe_globals. Current snapshots will be loaded without error, as they are with weights_only=False.

torch.serialization.add_safe_globals([np.dtype, Float64DType, scalar])

This doesn't work in numpy<1.25, as the Float64Dtype did not yet exist (it was a dtype that could only be used internally), and there is no easy way to add np.dtype[np.float64] to the safe globals.

A global variable is set to handle the default weights_only value

The global variable sets the default value given to load_weights_only, when none is specified in the pytorch_config.yaml. This value is controlled through the get_load_weights_only and set_load_weights_only methods, which can be imported through deeplabcut.pose_estimation_pytorch:

>>> from deeplabcut.pose_estimation_pytorch import get_load_weights_only, set_load_weights_only
>>> print(get_load_weights_only())
True
>>> set_load_weights_only(False)
>>> print(get_load_weights_only())
False

When calling torch.load without load_weights_only being specified, get_load_weights_only() is used to get the default value. So when loading snapshots that are known to be safe, set_load_weights_only(False) can be called at the start of a script so that all snapshots are loaded with weights_only=False.

So for users using numpy<1.25 with snapshots that have issues, they can load the snapshots without having to modify the pytorch_config.yaml by just calling from deeplabcut.pose_estimation_pytorch import set_load_weights_only and set_load_weights_only(False) before loading their snapshots.

The initial default load_weights_only value can also be set with an TORCH_LOAD_WEIGHTS_ONLY environment variable, which makes it easier to set this value when working with the GUI. The DeepLabCut GUI can just be launched with TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut, which will set the default load_weights_only value to False.

Function to fix snapshots containing numpy float values

A new fix_snapshot_metadata method is added to replace numpy floats with python floats in existing snapshots.

Improved error message when a snapshot fails to load

The error message when a snapshot fails to load is improved and more descriptive. It now says:

ERROR:root:
Failed to load the snapshot: snapshot-best-200.pt.

If you trust the snapshot that you're trying to load, you can try
calling `Runner.load_snapshot` with `weights_only=False`. See the 
error message below for more information and warnings.
You can set the `weights_only` parameter in the model configuration (
the content of the pytorch_config.yaml), as:

'''
runner:
  load_weights_only: False
'''

If it's the detector snapshot that's failing to load, place the
`load_weights_only` key under the detector runner:

'''
detector:
    runner:
      load_weights_only: False
'''

You can also set the default `load_weights_only` that will be used when
the `load_weights_only` variable is not set in the `pytorch_config.yaml`
using `deeplabcut.pose_estimation_pytorch.set_load_weights_only(value)`:

'''
from deeplabcut.pose_estimation_pytorch import set_load_weights_only
set_load_weights_only(True)
'''

You can also set the value for `load_weights_only` with a 
`TORCH_LOAD_WEIGHTS_ONLY` environment variable. If you call 
`TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut`, it will launch the
DeepLabCut GUI with the default `load_weights_only` value to False.
If you set this value to `False`, make sure you only load snapshots that
you trust.