Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions Run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def str2bool(v):

parser.add_argument("--hide-debug", action="store_true",
help="Supress the printing of each 3d location")
parser.add_argument("--run-on-cpu", default=False, action="store_true",
help="Runs the code on CPU")



def plot_regressed_3d_bbox(img, cam_to_img, box_2d, dimensions, alpha, theta_ray, img_2d=None):
Expand Down Expand Up @@ -85,8 +88,12 @@ def main():
print('Using previous model %s'%model_lst[-1])
my_vgg = vgg.vgg19_bn(pretrained=True)
# TODO: load bins from file or something
model = Model.Model(features=my_vgg.features, bins=2).cuda()
checkpoint = torch.load(weights_path + '/%s'%model_lst[-1])
if FLAGS.run_on_cpu == True :
model = Model.Model(features=my_vgg.features, bins=2)#.cuda()
checkpoint = torch.load(weights_path + '/%s'%model_lst[-1], map_location=torch.device('cpu'))
else:
model = Model.Model(features=my_vgg.features, bins=2).cuda
checkpoint = torch.load(weights_path + '/%s'%model_lst[-1])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

Expand Down Expand Up @@ -154,7 +161,11 @@ def main():
box_2d = detection.box_2d
detected_class = detection.detected_class

input_tensor = torch.zeros([1,3,224,224]).cuda()
if FLAGS.run_on_cpu == True :

input_tensor = torch.zeros([1,3,224,224]) #.cuda()
else:
input_tensor = torch.zeros([1,3,224,224]).cuda()
input_tensor[0,:,:,:] = input_img

[orient, conf, dim] = model(input_tensor)
Expand Down