Well, I (with help from DeepSeek) found a solution to this problem; I hope it will be helpful to someone else.
def stable_implementation(A, B):
    """Return the log-space matrix product of the exp-sums of ``A`` and ``B``.

    With ``S_A = exp(A).sum(dim=1)`` and ``S_B = exp(B).sum(dim=1)``, the
    result equals ``log(S_A @ S_B)`` (batched matrix product), but it is
    computed entirely with ``torch.logsumexp`` so the inputs are never
    exponentiated directly.

    Args:
        A: Tensor of log-domain values. Assumed to be 4-D, e.g.
            ``(bs, k, m, m)``, so that reducing dim 1 leaves ``(bs, m, m)``
            as noted by the original author -- confirm against the caller.
        B: Tensor of log-domain values with the same layout assumption;
            its reduced form must broadcast against that of ``A``.

    Returns:
        Tensor ``out`` with ``out[b, i, k] = logsumexp_j(log_S_A[b, i, j] +
        log_S_B[b, j, k])``; shape ``(bs, m, m)`` under the assumption above.
    """
    # Collapse dim 1 of each input in log space.
    log_S_A = torch.logsumexp(A, dim=1)  # assumed shape: (bs, m, m)
    log_S_B = torch.logsumexp(B, dim=1)  # assumed shape: (bs, m, m)

    # Broadcast so that combined[b, i, j, k] = log_S_A[b, i, j] + log_S_B[b, j, k];
    # adding logs corresponds to multiplying the underlying values.
    combined = log_S_A.unsqueeze(3) + log_S_B.unsqueeze(1)  # (bs, m, m, m)

    # Reduce over the shared inner index j (dim 2), i.e. the summation of a
    # matrix product carried out in log space.
    out = torch.logsumexp(combined, dim=2)  # (bs, m, m)
    return out