feature: Torch dependency in sagemaker-core to be made optional (5457) by aviruthen · Pull Request #5714 · aws/sagemaker-python-sdk

Description

The torch dependency in sagemaker-core/pyproject.toml is declared as required ('torch>=1.9.0'), but torch is only used client-side in two classes, TorchTensorSerializer and TorchTensorDeserializer, both of which already import it lazily inside __init__. torchrun_driver.py also imports torch, but it runs inside the SageMaker training container, not on the client. All other files only reference 'pytorch' as a string.

The fix is to:

  • Move torch from the required dependencies to an optional extras group in pyproject.toml.
  • Ensure the torch imports in the serializer and deserializer use the DeferredError pattern. The deserializer already has a try/except but raises immediately; it should use DeferredError instead.
  • Update the serializer implementations.py and deserializer implementations.py so they do not eagerly import TorchTensorSerializer/TorchTensorDeserializer at module level.
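The DeferredError pattern mentioned in the description can be illustrated with a minimal, self-contained sketch. The DeferredError class below is a local stand-in written for illustration, and optional_import and TensorSerializerSketch are hypothetical names; none of this is taken from the PR diff, and the SDK's own classes may differ.

```python
import importlib


class DeferredError:
    """Stand-in placeholder that re-raises a stored exception on attribute access."""

    def __init__(self, exception):
        self.exc = exception

    def __getattr__(self, name):
        # Only reached for attributes that are not found normally,
        # i.e. anything other than `exc`.
        raise self.exc


def optional_import(module_name):
    """Return the module if it is installed, else a DeferredError wrapping the ImportError."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        return DeferredError(e)


class TensorSerializerSketch:
    """Hypothetical serializer: construction succeeds even without the backend."""

    def __init__(self, module_name="torch"):
        # Assumption: the optional backend is resolved here, never at module import time.
        self._backend = optional_import(module_name)

    def serialize(self, data):
        # If the backend is missing, this attribute access raises the original ImportError.
        tensor = self._backend.as_tensor(data)
        return tensor.numpy().tobytes()
```

With this shape, importing the module and constructing the serializer never require torch; the ImportError surfaces only when serialize is first called without torch installed.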

Related Issue

Related issue: 5457

Changes Made

  • sagemaker-core/pyproject.toml
  • sagemaker-core/src/sagemaker/core/serializers/base.py
  • sagemaker-core/src/sagemaker/core/deserializers/base.py
  • sagemaker-core/tests/unit/serializers/test_torch_optional.py
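
A unit test for an optional dependency usually needs to simulate the package being absent even when it is installed in the test environment. One way to do that, shown here as a sketch rather than the contents of test_torch_optional.py, relies on the fact that a None entry in sys.modules makes an import of that name raise ImportError. The helper names module_unavailable and import_fails are hypothetical.

```python
import contextlib
import importlib
import sys
from unittest import mock


@contextlib.contextmanager
def module_unavailable(name):
    """Simulate a missing package for the duration of the block."""
    # A None entry in sys.modules makes `import name` raise ImportError.
    # mock.patch.dict restores the original sys.modules contents on exit.
    with mock.patch.dict(sys.modules, {name: None}):
        yield


def import_fails(name):
    """Return True if importing `name` raises ImportError."""
    try:
        importlib.import_module(name)
    except ImportError:
        return True
    return False


def test_torch_reported_missing_when_blocked():
    # Hypothetical test: inside the block, torch looks uninstalled.
    with module_unavailable("torch"):
        assert import_fails("torch")
```

A real test would presumably import the serializer and deserializer modules inside the module_unavailable("torch") block and assert that the import succeeds, which is the behavior this PR aims for.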

AI-Generated PR

This PR was automatically generated by the PySDK Issue Agent.

  • Confidence score: 85%
  • Classification: type: feature request
  • SDK version target: V3

Merge Checklist

  • Changes are backward compatible
  • Commit message follows prefix: description format
  • Unit tests added/updated
  • Integration tests added (if applicable)
  • Documentation updated (if applicable)