With some hints from a colleague I at least got what I wanted. In the end I made two changes to the code above:
First, I simplified Up.forward() so that it would be translated to ONNX properly, I believe:
def forward(self, x1, x2):
    x1 = self.up(x1)
    # input is Channel x Height x Width
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    padding_left = diffX // 2
    padding_right = diffX - padding_left
    padding_top = diffY // 2
    padding_bottom = diffY - padding_top
    x1 = F.pad(x1, [padding_left, padding_right, padding_top, padding_bottom])
    x = torch.cat([x2, x1], dim=1)
    return self.conv(x)
Second, I changed the dummy input used for the export:
dummy_input = torch.randn(1, 1, 768, 768) # H[400 to 900] x W (512, 768, 1024, 1536)
I did this because I realised the model was trained with that input shape.
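For reference, the export call itself then looks roughly like this. This is only a sketch, assuming the classic TorchScript-based torch.onnx.export; model stands for the trained UNet instance, and the output file name and opset version are placeholders:

import torch

model.eval()  # model is assumed to be the trained UNet instance
dummy_input = torch.randn(1, 1, 768, 768)  # same N x C x H x W shape as in training

torch.onnx.export(
    model,
    dummy_input,
    "unet.onnx",              # placeholder output path
    input_names=["input"],
    output_names=["output"],
    opset_version=17,         # any reasonably recent opset should do
)

If a single ONNX file has to accept the varying heights and widths mentioned in the comment, torch.onnx.export also takes a dynamic_axes argument that can mark the height and width dimensions of the input as dynamic.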
torch.dynamo did not work for me.
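In case it helps anyone retrying that route: I take "torch.dynamo" here to mean the dynamo-based ONNX exporter. A sketch of how that path is invoked, depending on the PyTorch version (again assuming model and dummy_input from above), would be:

# older API, PyTorch 2.1 and later (deprecated in newer releases):
onnx_program = torch.onnx.dynamo_export(model, dummy_input)
onnx_program.save("unet_dynamo.onnx")

# newer API, PyTorch 2.5 and later:
torch.onnx.export(model, (dummy_input,), "unet_dynamo.onnx", dynamo=True)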