diff --git a/Run.py b/Run.py index 793eafb..607c2c9 100644 --- a/Run.py +++ b/Run.py @@ -55,6 +55,9 @@ def str2bool(v): parser.add_argument("--hide-debug", action="store_true", help="Supress the printing of each 3d location") +parser.add_argument("--run-on-cpu", default=False, action="store_true", + help="Runs the code on CPU") + def plot_regressed_3d_bbox(img, cam_to_img, box_2d, dimensions, alpha, theta_ray, img_2d=None): @@ -85,8 +88,12 @@ def main(): print('Using previous model %s'%model_lst[-1]) my_vgg = vgg.vgg19_bn(pretrained=True) # TODO: load bins from file or something - model = Model.Model(features=my_vgg.features, bins=2).cuda() - checkpoint = torch.load(weights_path + '/%s'%model_lst[-1]) + if FLAGS.run_on_cpu == True : + model = Model.Model(features=my_vgg.features, bins=2)#.cuda() + checkpoint = torch.load(weights_path + '/%s'%model_lst[-1], map_location=torch.device('cpu')) + else: + model = Model.Model(features=my_vgg.features, bins=2).cuda + checkpoint = torch.load(weights_path + '/%s'%model_lst[-1]) model.load_state_dict(checkpoint['model_state_dict']) model.eval() @@ -154,7 +161,11 @@ def main(): box_2d = detection.box_2d detected_class = detection.detected_class - input_tensor = torch.zeros([1,3,224,224]).cuda() + if FLAGS.run_on_cpu == True : + + input_tensor = torch.zeros([1,3,224,224]) #.cuda() + else: + input_tensor = torch.zeros([1,3,224,224]).cuda() input_tensor[0,:,:,:] = input_img [orient, conf, dim] = model(input_tensor)