
I want to debug layer-by-layer to see where the ONNX model starts deviating from the PyTorch model outputs.

I can extract intermediate outputs in PyTorch using forward hooks, like:

def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

activations = {}
model.bert.encoder.layer[0].output.register_forward_hook(get_activation("layer_0"))

But I’m not sure how to extract comparable intermediate results from the ONNX model, since ONNX Runtime doesn’t have a straightforward hook system.

Is there a reliable way to trace or dump intermediate layer outputs from an ONNX model (for example, via onnxruntime.InferenceSession.run() with internal node names)?

Ideally, I’d like to compare PyTorch vs ONNX activations numerically to pinpoint the exact mismatch layer.

Any best practices or tools (like onnxruntime.tools.symbolic_shape_infer or onnxsim) for this kind of debugging?

1 Answer

Yes—do it by exposing intermediates as graph outputs and then asking ORT for those outputs. Two solid ways:

A) Extract a subgraph ending at a tensor you care about

import onnx
from onnx.utils import extract_model

in_model = "model.onnx"
out_model = "dbg_layer6.onnx"

m = onnx.load(in_model)
# Find the name of the tensor you want (e.g., output of layer 6); use Netron or print node names.
target = "bert.encoder.layer.6.output.LayerNorm_Output_0"

# Keep original inputs, but make `target` the graph output
extract_model(
    in_model, out_model,
    input_names=[i.name for i in m.graph.input],
    output_names=[target]
)

Then run:

import onnxruntime as ort, numpy as np
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL  # apples-to-apples
sess = ort.InferenceSession("dbg_layer6.onnx", sess_options=so, providers=["CPUExecutionProvider"])
onnx_act = sess.run(None, {"input_ids": ids, "attention_mask": mask})[0]
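
If you don't know what the internal tensor names look like, Netron is the easiest way to browse them; as a quick sketch, you can also enumerate node outputs programmatically (the "LayerNorm" filter string here is just an assumption about how your exporter named things):

import onnx

m = onnx.load("model.onnx")
# Every node output is a candidate intermediate tensor name.
for node in m.graph.node:
    for out_name in node.output:
        if "LayerNorm" in out_name:  # hypothetical filter; adjust to your export
            print(node.op_type, out_name)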

B) Promote multiple intermediates to outputs in-place

import onnx
from onnx import helper, shape_inference

model = onnx.load("model.onnx")
model = shape_inference.infer_shapes(model)

want = [
    "bert.embeddings.LayerNorm_Output_0",
    "bert.encoder.layer.0.output.LayerNorm_Output_0",
    "bert.encoder.layer.1.output.LayerNorm_Output_0",
]
existing = {o.name for o in model.graph.output}
vi_map = {vi.name: vi for vi in list(model.graph.value_info)+list(model.graph.input)}

for name in want:
    if name not in existing:
        vi = vi_map.get(name, helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, None))
        model.graph.output.extend([vi])

onnx.save(model, "model_with_intermediates.onnx")
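
Optionally, run a quick sanity check on the modified graph before loading it in ORT; onnx.checker raises if the graph is structurally invalid:

import onnx
from onnx import checker

m = onnx.load("model_with_intermediates.onnx")
checker.check_model(m)  # raises ValidationError on a malformed graph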

Then:

sess = ort.InferenceSession("model_with_intermediates.onnx", providers=["CPUExecutionProvider"])
outs = sess.run(want, feeds)   # request those internal tensors by name
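
Here `feeds` is the usual ORT input dict. A minimal sketch, assuming a Hugging Face tokenizer and the common BERT export input names (input_ids, attention_mask, token_type_ids; adjust to whatever your export actually declares):

import numpy as np
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumption: your checkpoint
enc = tok("hello world", return_tensors="np")
feeds = {
    "input_ids": enc["input_ids"].astype(np.int64),
    "attention_mask": enc["attention_mask"].astype(np.int64),
    "token_type_ids": enc["token_type_ids"].astype(np.int64),
}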

Compare with PyTorch hooks

import torch, numpy as np

acts = {}
def hook(name):
    def _h(m, i, o): acts[name] = o.detach().cpu().numpy()
    return _h

# register on the matching PyTorch module (layer 1's output LayerNorm matches want[2])
model.bert.encoder.layer[1].output.LayerNorm.register_forward_hook(hook(want[2]))
model.eval()
with torch.no_grad():
    torch_out = model(**pt_inputs)

# outs from sess.run(want, feeds) is returned in the same order as `want`
np.testing.assert_allclose(acts[want[2]], outs[2], rtol=1e-4, atol=1e-4)
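
To pinpoint the first diverging layer automatically, here is a sketch that hooks every encoder layer's output LayerNorm, requests the matching ONNX tensors in one run, and reports the first layer whose max abs diff exceeds a threshold. It assumes the tensor-name pattern above, a Hugging Face-style config, and that all of these names were promoted to outputs in model_with_intermediates.onnx:

import numpy as np

num_layers = model.config.num_hidden_layers  # assumption: HF-style config
onnx_names = [f"bert.encoder.layer.{i}.output.LayerNorm_Output_0" for i in range(num_layers)]  # assumed pattern

acts = {}
handles = [
    model.bert.encoder.layer[i].output.LayerNorm.register_forward_hook(hook(onnx_names[i]))
    for i in range(num_layers)
]
with torch.no_grad():
    model(**pt_inputs)
for h in handles:
    h.remove()

# session over model_with_intermediates.onnx, with all onnx_names exposed as outputs
onnx_outs = sess.run(onnx_names, feeds)
for name, onnx_val in zip(onnx_names, onnx_outs):
    diff = np.abs(acts[name] - onnx_val).max()
    print(f"{name}: max abs diff = {diff:.2e}")
    if diff > 1e-3:
        print("first significant divergence here")
        break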

1 Comment

Nice trick, I've been using ONNX for a long time but never knew about this.
