I'm trying to deploy a custom PyTorch model to a SageMaker Multi-Model Endpoint (MME). My model is saved as a state_dict using torch.save(), so it requires a custom inference.py script to load the model architecture before loading the weights.
The problem is that the endpoint fails to load any model. The CloudWatch logs show that SageMaker is ignoring my custom inference.py and falling back to its default handler. This default handler expects a TorchScript model, which mine is not, leading to a ModelLoadError ("Please ensure model is saved using torchscript").
This happens even when I replace my inference.py with a simple "hello world" script. This makes me believe the issue is in the deployment configuration, not the inference code itself.
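To make the mismatch concrete: my artifacts are plain state_dicts, while the default handler apparently expects a self-contained TorchScript archive. Roughly (illustrative snippet, not my actual training code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for my real architecture

# What I do today: save only the weights. Loading this later requires the
# Python class definition, which is why I need a custom inference.py.
torch.save(model.state_dict(), "model.pt")

# What the default handler seems to expect: a TorchScript archive that can be
# loaded without any Python class definition.
torch.jit.save(torch.jit.script(model), "model_scripted.pt")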
My Setup:
Models: My model artifacts (e.g., model_A.tar.gz, model_B.tar.gz) are located in an S3 prefix. Each .tar.gz contains only the model weights (model.pt) and a few data artifacts (.json, .csv); there is no Python code inside (see the packaging sketch just after the file structure below).
Deployment: I am using the SageMaker Python SDK.
File Structure:
/project_root
├── deploy.py
└── code/
    ├── __init__.py
    ├── inference.py
    ├── model.py
    └── requirements.txt
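For completeness, each model archive is packaged roughly like this (file names are illustrative):

import tarfile

# One archive per model: weights plus data artifacts, no Python code inside.
with tarfile.open("model_A.tar.gz", "w:gz") as tar:
    tar.add("model.pt")
    tar.add("config.json")   # illustrative data artifact
    tar.add("lookup.csv")    # illustrative data artifact

# Uploaded afterwards, e.g.:
#   aws s3 cp model_A.tar.gz s3://my-bucket/my-mme-models/model_A.tar.gz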
Deployment Script (deploy.py):
import sagemaker
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.multidatamodel import MultiDataModel
import time


def main():
    # Generic configuration
    s3_model_repository = "s3://my-bucket/my-mme-models/"
    endpoint_name = f"my-mme-endpoint-{time.strftime('%Y%m%d-%H%M%S')}"
    role_arn = "arn:aws:iam::123456789012:role/MySageMakerRole"
    instance_type = 'ml.g4dn.xlarge'
    sagemaker_session = sagemaker.Session()

    # 1. Define the container and the code to be used for inference
    pytorch_model_container = PyTorchModel(
        entry_point='inference.py',
        source_dir='./code',
        role=role_arn,
        sagemaker_session=sagemaker_session,
        model_data=None,  # This is critical for MME
        framework_version='2.0',
        py_version='py310',
    )

    # 2. Create the MultiDataModel object pointing to the S3 model repository
    multi_data_model = MultiDataModel(
        name=f"my-mme-container-def-{time.strftime('%Y%m%d-%H%M%S')}",
        model_data_prefix=s3_model_repository,
        model=pytorch_model_container,
    )

    # 3. Deploy the endpoint
    try:
        multi_data_model.deploy(
            initial_instance_count=1,
            instance_type=instance_type,
            endpoint_name=endpoint_name,
        )
        print(f"Deployment successful. Endpoint Name: {endpoint_name}")
    except Exception as e:
        print(f"Deployment failed: {e}")


if __name__ == "__main__":
    main()
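For reference, this is roughly how I intend to invoke individual models once the endpoint is up; MME routes each request to the archive named in TargetModel (payload and names are illustrative):

import boto3

runtime = boto3.client("sagemaker-runtime")
endpoint_name = "my-mme-endpoint-20240101-000000"  # whatever deploy.py printed

response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    TargetModel="model_A.tar.gz",           # relative to s3://my-bucket/my-mme-models/
    ContentType="application/json",
    Body=b'{"inputs": [[0.1, 0.2, 0.3]]}',  # illustrative payload
)
print(response["Body"].read())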
Inference Script (code/inference.py):
# code/inference.py
import logging
import traceback

logging.basicConfig(level=logging.INFO, format='%(asctime)s - [CUSTOM_HANDLER] - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    logger.info("Attempting to import MyCustomModel from model.py...")
    from model import MyCustomModel
    logger.info("Successfully imported MyCustomModel.")
except Exception as e:
    logger.critical(f"CRITICAL IMPORT FAILED: {e}")
    raise


def initialize(context):
    """
    Loads the model artifacts and instantiates the model.
    """
    try:
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        logger.info(f"--- [initialize] Initializing model from '{model_dir}' ---")
        # ... logic to load artifacts and model ...
        # model = MyCustomModel(...)
        # model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt')))
        logger.info("--- [initialize] Model loaded successfully. ---")
    except Exception as e:
        logger.error("!!!!!! CRITICAL ERROR DURING initialize() !!!!!!")
        logger.error(f"Stack Trace:\n{traceback.format_exc()}")
        raise
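For reference, code/model.py only defines the architecture; it has roughly this shape (the layers shown are illustrative, not my real network):

# code/model.py (simplified sketch)
import torch.nn as nn

class MyCustomModel(nn.Module):
    def __init__(self, input_dim=16, hidden_dim=32, output_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)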
CloudWatch Logs:
Regardless of the configuration, I never see my [CUSTOM_HANDLER] log messages. The worker process dies almost instantly, and the logs always show the same error from the SageMaker default handler:
... [INFO] W-9000-...-stdout MODEL_LOG - Torch worker started.
... [INFO] W-9000-...-stdout MODEL_LOG - Backend worker process died.
... [INFO] W-9000-...-stdout MODEL_LOG - Traceback (most recent call last):
... [INFO] W-9000-...-stdout MODEL_LOG - File "/opt/conda/lib/python3.10/site-packages/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py", line 80, in default_model_fn
... [INFO] W-9000-...-stdout MODEL_LOG - raise ModelLoadError(
sagemaker_pytorch_serving_container.default_pytorch_inference_handler.ModelLoadError: Failed to load /opt/ml/models/.../model.pt. Please ensure model is saved using torchscript.
Question:
What could be causing SageMaker to completely ignore my source_dir and entry_point configuration and fall back to its default handler? I have followed the official examples for PyTorch MME, but my custom inference code is never executed.
What I tried:
I am attempting to deploy a PyTorch Multi-Model Endpoint using the standard PyTorchModel class with a custom entry_point and a source_dir containing my inference code.
Initial Attempt: I configured PyTorchModel with entry_point='inference.py', source_dir='./code', and model_data=None, which is the standard pattern for MME.
Isolation Test: To rule out errors within my script, I replaced my full inference.py with a minimal "hello world" script that only contained logging messages and no complex imports or logic.
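The isolation-test script was roughly this, with no model code at all:

# code/inference.py -- minimal isolation test
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - [CUSTOM_HANDLER] - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

logger.info("hello world: custom inference.py was imported")

def initialize(context):
    logger.info("hello world: initialize() was called")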
Configuration Variants: I have also tried using the lower-level sagemaker.model.Model class with explicit environment variables like SAGEMAKER_PROGRAM and MMS_DEFAULT_HANDLER to force the container to use my script.
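That lower-level variant looked roughly like this (region, role, and values are approximate reconstructions, not verbatim):

import sagemaker
from sagemaker import image_uris
from sagemaker.model import Model

image_uri = image_uris.retrieve(
    framework="pytorch",
    region="us-east-1",             # illustrative region
    version="2.0",
    py_version="py310",
    instance_type="ml.g4dn.xlarge",
    image_scope="inference",
)

base_model = Model(
    image_uri=image_uri,
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    env={
        "SAGEMAKER_PROGRAM": "inference.py",
        # plus the MMS_DEFAULT_HANDLER variable mentioned above (value omitted here)
    },
)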
What actually resulted:
In every attempt, my custom script is completely ignored. I never see any of my custom log messages. The container worker dies almost instantly, and the CloudWatch logs always show the same error from SageMaker's default handler: sagemaker_pytorch_serving_container.default_pytorch_inference_handler.ModelLoadError: ... Please ensure model is saved using torchscript.
This strongly suggests that SageMaker is never executing my code and is falling back to its default model-loading mechanism, which is incompatible with my torch.save() state_dict artifacts.