feature: Torch dependency in sagameker-core to be made optional (5457) by aviruthen · Pull Request #5714 · aws/sagemaker-python-sdk
Description
The torch dependency in sagemaker-core/pyproject.toml is declared as a required dependency ('torch>=1.9.0') but torch is only actually used in two client-side classes: TorchTensorSerializer and TorchTensorDeserializer, both of which already use lazy imports inside init. The torchrun_driver.py imports torch but runs inside the SageMaker training container (not client-side). All other files only reference 'pytorch' as a string. The fix is to: (1) move torch from required dependencies to an optional extras group in pyproject.toml, (2) ensure torch imports in serializer/deserializer use DeferredError pattern (deserializer already has try/except but raises immediately - should use DeferredError), and (3) update the serializer implementations.py and deserializer implementations.py to not eagerly import TorchTensorSerializer/TorchTensorDeserializer at module level.
Related Issue
Related issue: 5457
Changes Made
sagemaker-core/pyproject.tomlsagemaker-core/src/sagemaker/core/serializers/base.pysagemaker-core/src/sagemaker/core/deserializers/base.pysagemaker-core/tests/unit/serializers/test_torch_optional.py
AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
- Confidence score: 85%
- Classification: type: feature request
- SDK version target: V3
Merge Checklist
- Changes are backward compatible
- Commit message follows
prefix: descriptionformat - Unit tests added/updated
- Integration tests added (if applicable)
- Documentation updated (if applicable)