79317709

Date: 2024-12-30 13:20:17

To solve this and simplify the graph, I disabled TorchDynamo by explicitly setting `dynamo=False` when exporting the model. Here is the updated export function:

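```python
import torch

def export_ONNX(model):
    tensor_x = torch.randn(1, 1, 32, 16, 32)  # example 5-D input
    # dynamo=False selects the TorchScript-based exporter instead of the Dynamo-based one
    torch.onnx.export(model, (tensor_x,), "my_model.onnx", input_names=["input"], dynamo=False)
```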

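From my understanding, depending on the PyTorch version, `torch.onnx.export()` may use the TorchDynamo-based exporter by default to capture the model. That path introduces additional nodes into the ONNX graph to handle dynamic aspects of the computation, so passing `dynamo` explicitly makes the choice of exporter unambiguous.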

With `dynamo=False`, the exported ONNX graph aligns more closely with the original PyTorch operations and contains only the essential nodes: MaxPool, Reshape, Gemm, and Softmax.

[Figure: exported ONNX graph showing the MaxPool, Reshape, Gemm, and Softmax nodes]

This solved the issue for me. That said, I still wonder why the Dynamo-based exporter does not produce the same graph, since I would expect it to.

Posted by: Pedro Antunes