fix: Model builder unable to (5667) by aviruthen · Pull Request #5729 · aws/sagemaker-python-sdk
🤖 AI Code Review
This PR adds a utility function to resolve missing BaseModel fields (hub_content_version, recipe_name) when they are Unassigned. While the approach is reasonable, there are several issues: no blank-line separation between the new function and an existing constant, a broad except Exception catch that violates SDK error handling conventions, missing type annotations, and no call to the function in the build() flow, contrary to the PR description. The tests also have issues with mock patching targets.
        return base_model
    _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
    _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID."
Missing blank lines before existing code. The new function's closing return base_model is immediately followed by _DJL_MODEL_BUILDER_ENTRY_POINT with no blank line in between. If the constant is flush left, Python still parses the module, but the layout violates PEP 8 (two blank lines after a top-level function) and makes the constant read as part of the function. If the constant picked up the function's indentation, it becomes unreachable code after the return and the module-level name is never defined.
Add two blank lines:
        return base_model


    _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
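How Python treats the constant depends on its indentation, which the rendered diff does not show. A minimal check with the standard ast module illustrates the difference. The two snippets and the function name f are made up for illustration and are not the PR's code:

```python
import ast

# Two hypothetical layouts of the same three lines (not the PR's actual code).
FLUSH_LEFT = 'def f(x):\n    return x\nCONST = "inference.py"\n'
INDENTED = 'def f(x):\n    return x\n    CONST = "inference.py"\n'


def top_level_nodes(source: str) -> list:
    """Return the node type names found at module level."""
    return [type(node).__name__ for node in ast.parse(source).body]


flush_left_nodes = top_level_nodes(FLUSH_LEFT)  # assignment stays at module level
indented_nodes = top_level_nodes(INDENTED)      # assignment ends up inside f()
```

In the first layout the assignment is still a module-level statement. In the second it only exists inside the function body, after the return, so it never runs.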
    from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE

    SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
Missing type annotations. Per SDK conventions (PEP 484), all new public functions must have type annotations for parameters and return types. Please add them:
    def resolve_base_model_fields(base_model: BaseModel | None) -> BaseModel | None:
(Use the appropriate BaseModel type from sagemaker-core, not Pydantic's BaseModel.)
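A minimal sketch of an annotated signature. The BaseModel dataclass below is a hypothetical stand-in so the snippet runs on its own; the real type should come from sagemaker-core. Optional[...] is used instead of the X | None spelling on the assumption that the code may need to run on Python versions older than 3.10:

```python
from dataclasses import dataclass
from typing import Optional, get_type_hints


@dataclass
class BaseModel:
    """Hypothetical stand-in for the sagemaker-core BaseModel shape."""

    hub_content_name: Optional[str] = None
    hub_content_version: Optional[str] = None
    recipe_name: Optional[str] = None


def resolve_base_model_fields(base_model: Optional[BaseModel]) -> Optional[BaseModel]:
    """Return base_model with missing fields resolved (resolution logic elided)."""
    if base_model is None:
        return None
    return base_model


# The annotations are now visible to type checkers and to runtime introspection.
hints = get_type_hints(resolve_base_model_fields)
```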
    except Exception as e:
        logger.warning(
            "Failed to resolve hub_content_version for hub_content_name='%s' "
            "from SageMakerPublicHub. You may need to set it manually. Error: %s",
Broad except Exception violates SDK error handling conventions. The SDK coding standards say to catch specific exceptions rather than Exception. At minimum, catch botocore's ClientError and any known sagemaker-core exceptions:
    from botocore.exceptions import ClientError

    try:
        ...
    except (ClientError, ImportError) as e:
        logger.warning(...)
If you must catch broadly for resilience, at least log the exception type.
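If a broad catch is kept for resilience, the exception type can be logged alongside the message. A minimal sketch using only the standard library; resolve_version and the fetch_version callable are hypothetical names standing in for the real lookup:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_version(fetch_version, hub_content_name):
    """Return the resolved version, or None if the lookup fails."""
    try:
        return fetch_version(hub_content_name)
    except Exception as e:  # deliberately broad; the type is recorded below
        logger.warning(
            "Failed to resolve hub_content_version for hub_content_name='%s' (%s): %s",
            hub_content_name,
            type(e).__name__,
            e,
        )
        return None


def _failing_lookup(name):
    """Hypothetical lookup that always fails, to exercise the warning path."""
    raise KeyError(name)


failed = resolve_version(_failing_lookup, "some-model")
succeeded = resolve_version(lambda name: "2.5.0", "some-model")
```

With the type name in the log line, an unexpected failure (for example a KeyError rather than a service error) is visible without a stack trace.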
    from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE

    SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
Function is defined but never integrated into the build() flow. The PR description states this resolution logic should be "called during the build() flow in model_builder.py before the values are consumed," but model_builder.py is not modified in this PR. Without the call site, this function is dead code and the bug described in issue 5667 is not actually fixed. Please add the call to resolve_base_model_fields() in the appropriate location in model_builder.py.
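As a rough illustration of the ordering that matters, resolution has to run before anything in build() reads the fields. Everything in this sketch is hypothetical (SketchBuilder, FakeBaseModel, the _Unassigned sentinel, and the injected lookup_version callable); the real call site in model_builder.py is not shown in this PR:

```python
class _Unassigned:
    """Hypothetical stand-in for the SDK's Unassigned sentinel."""


class FakeBaseModel:
    """Hypothetical stand-in for the BaseModel passed to the builder."""

    def __init__(self, hub_content_name, hub_content_version=None):
        self.hub_content_name = hub_content_name
        self.hub_content_version = (
            _Unassigned() if hub_content_version is None else hub_content_version
        )


def resolve_base_model_fields(base_model, lookup_version):
    """Fill hub_content_version when it is still the sentinel."""
    if base_model is not None and isinstance(base_model.hub_content_version, _Unassigned):
        base_model.hub_content_version = lookup_version(base_model.hub_content_name)
    return base_model


class SketchBuilder:
    """Hypothetical builder showing where the call would sit."""

    def __init__(self, base_model, lookup_version):
        self.base_model = base_model
        self._lookup_version = lookup_version

    def build(self):
        # Resolve first, so nothing downstream ever sees the sentinel.
        self.base_model = resolve_base_model_fields(self.base_model, self._lookup_version)
        return f"{self.base_model.hub_content_name}@{self.base_model.hub_content_version}"


built = SketchBuilder(FakeBaseModel("some-model"), lambda name: "2.5.0").build()
```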
        hub_content_name,
    )
    hc = HubContent.get(
        hub_content_type="Model",
Lazy import inside the function body. The from sagemaker.core.resources import HubContent import is done inside the try block. While conditional imports are acceptable for optional dependencies, sagemaker.core.resources is a core dependency (not optional). Consider moving this import to the top of the file with the other sagemaker-core imports, or at least document why it's deferred (e.g., circular import avoidance).
    class FakeBaseModel:
        """Fake BaseModel for testing."""
Test uses a plain class instead of Pydantic BaseModel. The FakeBaseModel is a plain Python class, but in production the BaseModel object comes from sagemaker-core and is likely a Pydantic model. Consider using a MagicMock(spec=...) or a Pydantic model to more accurately represent the real object, especially since Pydantic models have different attribute-setting behavior.
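One way to get a stricter fake is MagicMock with spec_set, which rejects attributes that the spec class does not define. In this sketch BaseModelShape is a hypothetical placeholder; in the real tests the spec would be the sagemaker-core class:

```python
from unittest.mock import MagicMock


class BaseModelShape:
    """Hypothetical placeholder listing the attributes the real class exposes."""

    hub_content_name = None
    hub_content_version = None
    recipe_name = None


def make_fake_base_model(**fields):
    """Build a mock that only accepts attributes defined on BaseModelShape."""
    fake = MagicMock(spec_set=BaseModelShape)
    for name, value in fields.items():
        setattr(fake, name, value)  # raises AttributeError for unknown names
    return fake


base_model = make_fake_base_model(
    hub_content_name="huggingface-reasoning-qwen3-32b",
    hub_content_version="1.0.0",
)

# A misspelled field is rejected instead of being silently created.
try:
    make_fake_base_model(hub_content_verison="1.0.0")
    typo_rejected = False
except AttributeError:
    typo_rejected = True
```

This catches misspelled field names in tests, although it does not reproduce Pydantic validation itself.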
    fake_hc = FakeHubContent(hub_content_version="2.5.0")
    mock_hub_content_cls.get.return_value = fake_hc

    base_model = FakeBaseModel(
Mock patch target mismatch. You're patching sagemaker.core.resources.HubContent with a @patch decorator, and then also patching sagemaker.serve.model_builder_utils.HubContent inside the test body with a with statement. Two targets for one dependency is confusing, and only one of them can be the name the function actually reads. Since the function runs from sagemaker.core.resources import HubContent inside its body, the name is looked up on sagemaker.core.resources each time the function is called, so that is the target to patch. A patch on sagemaker.serve.model_builder_utils.HubContent only takes effect if that module binds HubContent at import time. Restructure the tests to use a single, consistent patch target.
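The lookup rule can be checked in isolation. This sketch registers a throwaway module named fake_resources (a made-up name, standing in for sagemaker.core.resources) and shows that a function-local import picks up a patch applied to the source module:

```python
import sys
import types
from unittest.mock import patch

# Throwaway module standing in for the real resources module (hypothetical name).
fake_resources = types.ModuleType("fake_resources")


class HubContent:
    @staticmethod
    def get(**kwargs):
        raise RuntimeError("real API call; must not happen in a unit test")


fake_resources.HubContent = HubContent
sys.modules["fake_resources"] = fake_resources


def resolve_version(hub_content_name):
    # Function-local import: the name is read from fake_resources at call time.
    from fake_resources import HubContent

    return HubContent.get(hub_content_name=hub_content_name).hub_content_version


# Patching the source module is what the local import sees.
with patch("fake_resources.HubContent") as mock_cls:
    mock_cls.get.return_value.hub_content_version = "2.5.0"
    result = resolve_version("some-model")
    mock_cls.get.assert_called_once_with(hub_content_name="some-model")
```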
        assert result.hub_content_version == "1.0.0"
        assert result.recipe_name == "my-recipe"

    @patch("sagemaker.core.resources.HubContent")
Test test_resolve_with_all_fields_present_no_api_call doesn't assert that the mock was NOT called. The with patch(...) block creates the mock, but the test never calls mock_hc.get.assert_not_called(), so it passes even if an API call is made:
    with patch("sagemaker.serve.model_builder_utils.HubContent", autospec=True) as mock_hc:
        result = resolve_base_model_fields(base_model)
        mock_hc.get.assert_not_called()
Also note: since the function imports HubContent locally inside the if version_missing: block, a patch on sagemaker.serve.model_builder_utils.HubContent may not intercept the call at all, in which case assert_not_called() would pass without proving anything.
    """Test that missing recipe_name logs a warning but does not crash."""
    base_model = FakeBaseModel(
        hub_content_name="huggingface-reasoning-qwen3-32b",
        hub_content_version="1.0.0",
Test test_resolve_missing_recipe_name_logs_warning doesn't verify that the warning was actually logged. The docstring says it "logs a warning", but there's no assertion on the logger. Consider adding:
    with patch("sagemaker.serve.model_builder_utils.logger") as mock_logger:
        result = resolve_base_model_fields(base_model)
        mock_logger.warning.assert_called_once()
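An alternative that avoids patching the logger object is unittest's assertLogs, which captures real log records. This is a sketch only: the logger name sketch.model_builder_utils and the simplified resolve function are assumptions, not the PR's code:

```python
import logging
import unittest

logger = logging.getLogger("sketch.model_builder_utils")  # hypothetical logger name


def resolve(base_model):
    """Simplified stand-in: warn when recipe_name is missing, never raise."""
    if base_model.get("recipe_name") is None:
        logger.warning(
            "recipe_name is not set for hub_content_name='%s'",
            base_model["hub_content_name"],
        )
    return base_model


class ResolveLoggingTest(unittest.TestCase):
    def test_missing_recipe_name_logs_warning(self):
        base_model = {"hub_content_name": "some-model", "hub_content_version": "1.0.0"}
        # assertLogs fails the test if no WARNING (or higher) record is emitted.
        with self.assertLogs("sketch.model_builder_utils", level="WARNING") as captured:
            result = resolve(base_model)
        self.assertEqual(len(captured.records), 1)
        self.assertIn("recipe_name", captured.output[0])
        self.assertIs(result, base_model)
```

With pytest, the caplog fixture serves the same purpose.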