
I have a model trained in TF2 that I am trying to deploy with SageMaker's TensorFlowModel. At inference time I also need to load some other utilities alongside the model, such as a JSON file that maps the inputs and a (transformers) tokenizer.

With HuggingFace or SKLearn models I can define a model_fn that loads the model together with these other utilities. How do I do the same for a TF2 model?
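
For example, with the SKLearn container I can do something like this in the entry-point script (simplified sketch; model.joblib and label_map.json are just placeholder names for my artifacts):

    import os
    import json
    import joblib

    def model_fn(model_dir):
        # Load the estimator plus the extra utilities shipped inside model.tar.gz
        model = joblib.load(os.path.join(model_dir, "model.joblib"))
        with open(os.path.join(model_dir, "label_map.json")) as f:
            label_map = json.load(f)
        return {"model": model, "label_map": label_map}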

I have gone through the documentation but can't find a way to do this. I tried adding a model_fn to my custom inference script, but it never seems to be invoked.
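
For context, this is roughly how I create and deploy the endpoint (simplified; the S3 path, framework version and instance type are placeholders):

    import sagemaker
    from sagemaker.tensorflow import TensorFlowModel

    role = sagemaker.get_execution_role()

    model = TensorFlowModel(
        model_data="s3://my-bucket/model.tar.gz",  # placeholder S3 path
        role=role,
        framework_version="2.11",                  # placeholder version
        entry_point="inference.py",                # the script below
        source_dir="code",
    )
    predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")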

Here is the inference script I am using:

import tensorflow as tf
import numpy as np
import json

# Load the model
def model_fn(model_dir):
    """
    Load and return a TensorFlow SavedModel.
    Args:
        model_dir (str): The directory where the model is saved.
    Returns:
        A TensorFlow SavedModel.
    """
    # Load the SavedModel from model_dir as a Keras model so .predict() is available
    model = tf.keras.models.load_model(model_dir)
    return model

# Parse input data
def input_fn(input_data, content_type):
    """
    Deserialize the input data for the model.
    Args:
        input_data: The input data as received from the client.
        content_type: The MIME type of the input data.
    Returns:
        A numpy array to be passed to the model for inference.
    """
    if content_type == "application/json":
        input_dict = json.loads(input_data)
        return np.array(input_dict["instances"])
    elif content_type == "text/csv":
        return np.loadtxt(input_data.splitlines(), delimiter=",")
    else:
        raise ValueError(f"Unsupported content type: {content_type}")

# Handle preprocessing of inputs
def input_handler(data, context):
    """
    Process the request payload into the format required for `input_fn`.
    Args:
        data: The payload of the request.
        context: The context of the request, including content type.
    Returns:
        Preprocessed data for the model.
    """
    content_type = context.request_content_type
    return input_fn(data.read().decode("utf-8"), content_type)

# Perform inference
def predict_fn(input_data, model):
    """
    Perform inference on the input data using the loaded model.
    Args:
        input_data: The preprocessed input data.
        model: The TensorFlow model.
    Returns:
        The raw predictions from the model.
    """
    predictions = model.predict(input_data)
    return predictions

# Serialize the output
def output_fn(prediction, accept):
    """
    Serialize the prediction output.
    Args:
        prediction: The raw prediction output from the model.
        accept: The requested MIME type of the response.
    Returns:
        The serialized response.
    """
    if accept == "application/json":
        response = {"predictions": prediction.tolist()}
        return json.dumps(response), "application/json"
    elif accept == "text/csv":
        response = "\n".join([",".join(map(str, row)) for row in prediction])
        return response, "text/csv"
    else:
        raise ValueError(f"Unsupported accept type: {accept}")

# Handle postprocessing of outputs
def output_handler(prediction, context):
    """
    Process the model's raw predictions into the response format.
    Args:
        prediction: The raw predictions from the model.
        context: The context of the response, including accept type.
    Returns:
        The formatted response.
    """
    accept = context.accept_header
    response, content_type = output_fn(prediction, accept)
    return response, content_type


  • Could you please share the steps to create and deploy the model on the endpoint? Commented Dec 29, 2024 at 14:36
  • I believe this is not possible. What I have done for this is to create my own inference script from scratch, which uses Flask and defines the required endpoints. I was also having trouble downloading extra files using boto3, due to an issue with the TensorFlowModel image. Commented Dec 30, 2024 at 14:45
