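OK, I just realized what was wrong. I'm just so dumb. During visualisation, I displayed the images in a for loop like this. Indexing vis_pcb_ds[0] inside the loop calls the augmentation function once per iteration, so twice in total, and each image comes from a separate augmentation pass. That's why the images seem differently augmented: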
for i in range(2):
    ax[i].imshow(vis_pcb_ds[0][i])
    ax[i].axis('off')
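
A minimal sketch of the fix, assuming the dataset applies a fresh random augmentation on every `__getitem__` call and returns a pair of images. `ToyAugmentedDataset` is a hypothetical stand-in for the real dataset (it is not the actual code behind vis_pcb_ds). The idea is simply to index the dataset once and reuse the result:

```python
import random


class ToyAugmentedDataset:
    """Hypothetical stand-in: each __getitem__ call draws a new random augmentation."""

    def __init__(self):
        self.calls = 0  # counts how often the augmentation runs

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        self.calls += 1
        shift = random.random()  # one random augmentation parameter per call
        # Both returned "images" share the same augmentation parameter.
        return (("first", shift), ("second", shift))


ds = ToyAugmentedDataset()

# Buggy pattern: indexing inside the loop runs the augmentation on every iteration.
buggy = [ds[0][i] for i in range(2)]
buggy_calls = ds.calls

# Fix: index once, then loop over the returned pair.
ds.calls = 0
sample = ds[0]
fixed = [sample[i] for i in range(2)]
fixed_calls = ds.calls

print(buggy_calls, fixed_calls)
```

Applied to the loop above, that would mean fetching `sample = vis_pcb_ds[0]` before the loop and calling `ax[i].imshow(sample[i])` inside it, so both images come from the same augmentation call.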