To solve this and simplify the graph, I disabled the TorchDynamo-based export path by explicitly setting dynamo=False
when exporting the model. Here's the updated export function:
import torch

def export_ONNX(model):
    tensor_x = torch.randn(1, 1, 32, 16, 32)  # example input used to trace the model
    torch.onnx.export(model, (tensor_x,), "my_model.onnx", input_names=["input"], dynamo=False)
From my understanding, torch.onnx.export() uses TorchDynamo by default (in recent PyTorch releases) to capture the model. This capture path introduces additional nodes in the ONNX graph to handle dynamic aspects of the computation.
By setting dynamo=False, the export falls back to the older TorchScript-based exporter, and the resulting ONNX graph aligns more closely with the original PyTorch operations, containing only the essential nodes such as MaxPool, Reshape, Gemm, and Softmax.
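If you want to check which nodes actually ended up in an exported file, a rough sketch like the one below can help. It assumes the onnx package is installed. The helper names count_op_types and summarize_onnx_file are made up for illustration and are not part of PyTorch or ONNX.

```python
from collections import Counter


def count_op_types(nodes):
    """Count how often each operator type appears in a sequence of ONNX nodes.

    `nodes` is any iterable of objects with an `op_type` attribute,
    for example `model.graph.node` from a loaded ONNX model.
    """
    return Counter(node.op_type for node in nodes)


def summarize_onnx_file(path):
    """Load an ONNX file and return a Counter of its operator types.

    Hypothetical helper; assumes the `onnx` package is available.
    """
    import onnx  # imported lazily so count_op_types works without onnx installed

    model = onnx.load(path)
    return count_op_types(model.graph.node)
```

Calling summarize_onnx_file("my_model.onnx") once after exporting with dynamo=True and once with dynamo=False should make the difference between the two graphs easy to see, since any extra operator types show up as additional keys in the result.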
This solved the issue for me, although I still wonder why the Dynamo-based export does not generate the same graph, as I believe it should.