Yes—do it by exposing intermediates as graph outputs and then asking ORT for those outputs. Two solid ways:
A) Extract a subgraph ending at a tensor you care about
import onnx
from onnx.utils import extract_model

in_model = "model.onnx"
out_model = "dbg_layer6.onnx"

# Load once just to enumerate the graph's declared inputs.
full_model = onnx.load(in_model)

# Name of the intermediate tensor to surface (e.g., the output of layer 6).
# Find the exact name with Netron or by printing node output names.
target = "bert.encoder.layer.6.output.LayerNorm_Output_0"

# Keep the original inputs, but make `target` the (only) graph output.
extract_model(
    in_model,
    out_model,
    input_names=[inp.name for inp in full_model.graph.input],
    output_names=[target],
)
Then run:
import onnxruntime as ort
import numpy as np

# Disable graph optimizations so the extracted model runs node-for-node
# like the original graph (apples-to-apples numerics).
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession(
    "dbg_layer6.onnx",
    sess_options=so,
    providers=["CPUExecutionProvider"],
)
# `ids` / `mask` are the tokenized inputs prepared earlier.
onnx_act = sess.run(None, {"input_ids": ids, "attention_mask": mask})[0]
B) Promote multiple intermediates to outputs in-place
import onnx
from onnx import helper, shape_inference

model = onnx.load("model.onnx")
# Shape inference fills in value_info entries for intermediate tensors,
# giving us dtype/shape metadata to reuse when promoting them to outputs.
model = shape_inference.infer_shapes(model)

# Intermediate tensor names to expose as graph outputs.
want = [
    "bert.embeddings.LayerNorm_Output_0",
    "bert.encoder.layer.0.output.LayerNorm_Output_0",
    "bert.encoder.layer.1.output.LayerNorm_Output_0",
]

existing = {out.name for out in model.graph.output}
value_infos = {
    vi.name: vi
    for vi in list(model.graph.value_info) + list(model.graph.input)
}

for name in want:
    if name in existing:
        continue
    # Fall back to an untyped FLOAT value_info when shape inference
    # produced nothing for this tensor.
    vi = value_infos.get(
        name,
        helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, None),
    )
    model.graph.output.extend([vi])

onnx.save(model, "model_with_intermediates.onnx")
Then:
sess = ort.InferenceSession(
    "model_with_intermediates.onnx",
    providers=["CPUExecutionProvider"],
)
# Ask ORT for the promoted internal tensors by name (order matches `want`).
outs = sess.run(want, feeds)
C) Compare with PyTorch forward hooks
import torch, numpy as np
# Captured activations, keyed by tensor name.
acts = {}


def hook(name):
    """Return a forward hook that stores the module's output under `name`."""

    def _h(module, inputs, output):
        acts[name] = output.detach().cpu().numpy()

    return _h
# Register a hook on the PyTorch layer matching the promoted tensor:
# want[2] == "bert.encoder.layer.1.output.LayerNorm_Output_0" -> layer 1.
model.bert.encoder.layer[1].output.LayerNorm.register_forward_hook(hook(want[2]))
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    torch_out = model(**pt_inputs)
# Compare against the ORT value for the SAME tensor: outs[2] is the
# onnxruntime result for want[2]. (Bug fix: the original compared against
# `onnx_act`, which was the layer-6 tensor from snippet A, not layer 1.)
np.testing.assert_allclose(acts[want[2]], outs[2], rtol=1e-4, atol=1e-4)