According to the link in the error message (https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset), you should instead change the implementation of the `__iter__` method so that it behaves differently depending on which worker calls it, or pass a `worker_init_fn` to the `DataLoader` (see the two code examples in the docs).
> Should I modify it after the fact in each worker so that each worker gets a dataset that's 4 times smaller?
Yes — as I understand it, sharding inside `__iter__` (or via `worker_init_fn`) makes each worker fetch `1 / num_workers` of the dataset, so the workers together cover each sample exactly once.
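As a rough sketch of the `__iter__` approach: inside a worker process, `torch.utils.data.get_worker_info()` returns that worker's `id` and `num_workers`, which you can use to shard the stream. The docs' example splits the range into contiguous chunks; the sketch below uses a strided split instead (worker `i` yields every `num_workers`-th item starting at offset `i`). The class name and the integer-range payload are just illustrative.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedIterableDataset(IterableDataset):
    """Illustrative dataset yielding integers in [start, end),
    sharded across DataLoader workers."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: this is the only consumer, yield everything.
            offset, step = 0, 1
        else:
            # Worker i yields items i, i + num_workers, i + 2 * num_workers, ...
            offset, step = info.id, info.num_workers
        return iter(range(self.start + offset, self.end, step))


if __name__ == "__main__":
    loader = DataLoader(ShardedIterableDataset(0, 8), num_workers=2)
    # Each sample appears exactly once across the two workers
    # (though the interleaving order is not guaranteed).
    print(sorted(int(x) for x in loader))
```

Without the `get_worker_info()` branch, every worker would iterate the full range and you would see each sample `num_workers` times, which is exactly the duplication the warning is about.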