Improved loading of snapshot weights with` torch.load(..., weights_only=True)` by MMathisLab · Pull Request #2823 · DeepLabCut/DeepLabCut
This pull request offers Improved handling of loading snapshot weights with torch.load(..., weights_only=True). PyTorch snapshots saved with older release candidates could contain some numpy floats, which failed to load with weights_only=True, which can make it annoying to use them as the pytorch_config.yaml needed to be modified with load_weights_only: true keys for both detectors and pose models. In this pull request: the following improvements are made:
Fix the issue with numpy>=1.25
For users with numpy>=1.25 installed, the issue is fixed as the float64 class causing issues can be added to the safe_globals, as done in _add_numpy_to_torch_safe_globals. Current snapshots will be loaded without error, as they are with weights_only=False.
torch.serialization.add_safe_globals([np.dtype, Float64DType, scalar])
This doesn't work in numpy<1.25, as the Float64Dtype did not yet exist (it was a dtype that could only be used internally), and there is no easy way to add np.dtype[np.float64] to the safe globals.
A global variable is set to handle the default weights_only value
The global variable sets the default value given to load_weights_only, when none is specified in the pytorch_config.yaml. This value is controlled through the get_load_weights_only and set_load_weights_only methods, which can be imported through deeplabcut.pose_estimation_pytorch:
>>> from deeplabcut.pose_estimation_pytorch import get_load_weights_only, set_load_weights_only >>> print(get_load_weights_only()) True >>> set_load_weights_only(False) >>> print(get_load_weights_only()) False
When calling torch.load without load_weights_only being specified, get_load_weights_only() is used to get the default value. So when loading snapshots that are known to be safe, set_load_weights_only(False) can be called at the start of a script so that all snapshots are loaded with weights_only=False.
So for users using numpy<1.25 with snapshots that have issues, they can load the snapshots without having to modify the pytorch_config.yaml by just calling from deeplabcut.pose_estimation_pytorch import set_load_weights_only and set_load_weights_only(False) before loading their snapshots.
The initial default load_weights_only value can also be set with an TORCH_LOAD_WEIGHTS_ONLY environment variable, which makes it easier to set this value when working with the GUI. The DeepLabCut GUI can just be launched with TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut, which will set the default load_weights_only value to False.
Function to fix snapshots containing numpy float values
A new fix_snapshot_metadata method is added to replace numpy floats with python floats in existing snapshots.
Improved error message when a snapshot fails to load
The error message when a snapshot fails to load is improved and more descriptive. It now says:
ERROR:root:
Failed to load the snapshot: snapshot-best-200.pt.
If you trust the snapshot that you're trying to load, you can try
calling `Runner.load_snapshot` with `weights_only=False`. See the
error message below for more information and warnings.
You can set the `weights_only` parameter in the model configuration (
the content of the pytorch_config.yaml), as:
'''
runner:
load_weights_only: False
'''
If it's the detector snapshot that's failing to load, place the
`load_weights_only` key under the detector runner:
'''
detector:
runner:
load_weights_only: False
'''
You can also set the default `load_weights_only` that will be used when
the `load_weights_only` variable is not set in the `pytorch_config.yaml`
using `deeplabcut.pose_estimation_pytorch.set_load_weights_only(value)`:
'''
from deeplabcut.pose_estimation_pytorch import set_load_weights_only
set_load_weights_only(True)
'''
You can also set the value for `load_weights_only` with a
`TORCH_LOAD_WEIGHTS_ONLY` environment variable. If you call
`TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut`, it will launch the
DeepLabCut GUI with the default `load_weights_only` value to False.
If you set this value to `False`, make sure you only load snapshots that
you trust.