diff --git a/mnist/main.py b/mnist/main.py index 0671b8e336..dee5a384cb 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -107,6 +107,7 @@ def main(): test_kwargs = {'batch_size': args.test_batch_size} if use_accel: accel_kwargs = {'num_workers': 1, + 'persistent_workers': True, 'pin_memory': True, 'shuffle': True} train_kwargs.update(accel_kwargs)