When I wrote this post, I stripped some logging and argument-parsing code out of the example, and that code turned out to be what was causing the serialization issue. The problematic part was a line in the train function that called the argument parser:
```python
def train_unet_model(config):
    ...
    args = get_args()  # the parser is called inside the trainable
    model = UNet(n_channels=3,
                 n_classes=args.classes,
                 bilinear=args.bilinear,
                 base_channels=config["BASE_CHANNELS"],
                 kernel_size=config["SAMPLING_KERNEL_SIZE"],
                 use_bias=config["USE_BIAS"],
                 base_mid_channels=config["BASE_MID_CHANNELS"])
    ...
```
The argument getter function looked like this:
```python
import argparse


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    ...
    return parser.parse_args()
```
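One reason calling this inside the trainable is fragile: with no arguments, `parse_args()` reads `sys.argv[1:]` of whichever process calls it, and argparse raises `SystemExit(2)` when it meets arguments it does not recognise. The sketch below shows that behaviour with an explicit `argv`; it does not reproduce the serialization error itself, and `--some-worker-flag` is a hypothetical stand-in for whatever a worker process might have on its command line.

```python
import argparse


def get_args(argv=None):
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    # argv=None means "parse sys.argv[1:] of the calling process".
    return parser.parse_args(argv)


def try_parse(argv):
    # Return the parsed namespace, or the exit code argparse raises
    # when it meets arguments it does not know.
    try:
        return get_args(argv)
    except SystemExit as exc:
        return exc.code


print(try_parse(["--epochs", "10"]).epochs)      # → 10
print(try_parse(["--some-worker-flag", "abc"]))  # → 2 (argparse also prints an error to stderr)
```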
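If you want to keep the parser, call it once outside the trainable (for example, in the driver script before starting the tuning run) and pass the extracted values into the trainable as plain variables, so the trainable itself never touches argparse.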
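A minimal sketch of that approach, assuming the trainable receives a `config` dict as in the code above: parse once in the driver, copy the needed values into `config` as constants, and read them back inside the trainable. The `--classes` and `--bilinear` flags are assumptions reconstructed from `args.classes` and `args.bilinear`; `build_config` and the `N_CLASSES`/`BILINEAR`/`EPOCHS` keys are hypothetical names, and the stand-in trainable returns the keyword arguments instead of building the real `UNet`.

```python
import argparse


def get_args(argv=None):
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    # Assumed flags, inferred from args.classes / args.bilinear in the trainable.
    parser.add_argument('--classes', type=int, default=2, help='Number of classes')
    parser.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
    return parser.parse_args(argv)


def build_config(args, search_space):
    # Hypothetical helper: copy parsed values into the config as plain
    # constants so the trainable never calls the parser.
    config = dict(search_space)  # copy, so the caller's dict is untouched
    config["N_CLASSES"] = args.classes
    config["BILINEAR"] = args.bilinear
    config["EPOCHS"] = args.epochs
    return config


def train_unet_model(config):
    # Stand-in trainable: reads everything from `config` only.
    # The real version would pass these values to UNet(...).
    return dict(
        n_channels=3,
        n_classes=config["N_CLASSES"],
        bilinear=config["BILINEAR"],
        base_channels=config["BASE_CHANNELS"],
    )


# Driver side: parse once (explicit argv here for illustration; in a real
# script call get_args() with no arguments). BASE_CHANNELS would normally
# be a search-space entry rather than a constant.
args = get_args(["--classes", "3", "--bilinear"])
config = build_config(args, {"BASE_CHANNELS": 64})
print(train_unet_model(config))
# → {'n_channels': 3, 'n_classes': 3, 'bilinear': True, 'base_channels': 64}
```

Putting the values in `config` is only one option; binding them as extra arguments to the trainable can also work, depending on what your tuner accepts, as long as only plain data such as ints and bools crosses into it.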