-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_samples.py
More file actions
135 lines (121 loc) · 7.11 KB
/
generate_samples.py
File metadata and controls
135 lines (121 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import matplotlib.pyplot as plt
import numpy as np
from models.generate import generate_ddpm,generate_ddim,generate_shift
from models.unet import TrajectoryDenoiser_CondEmbed,TrajectoryDenoiser_CondMerge,TrajectoryDenoiser_Shift
from models.diffusion.ddpm import DDPM
from models.diffusion.resshift import ResShift
from utils.datasets_utils import get_dataset,traj_dataset,setup_loo_experiment,collate_fn_padd
from utils.config import load_config
import logging,random
import cv2,os
# Ugly but avoid conflict between opencv and matplotlib: cv2 ships its own Qt
# plugins and exports QT_* environment variables that break matplotlib's Qt
# backend, so remove any cv2-related QT_* entries.
# Iterate over a snapshot (list(...)): deleting from os.environ while
# iterating its live items() view raises "RuntimeError: dictionary changed
# size during iteration" the moment a matching key exists.
for k, v in list(os.environ.items()):
    if k.startswith("QT_") and "cv2" in v:
        del os.environ[k]
from utils.constants import (
FRAMES_IDS, KEY_IDX, OBS_NEIGHBORS, OBS_TRAJ, OBS_TRAJ_VEL, OBS_TRAJ_ACC, OBS_TRAJ_THETA, PRED_TRAJ, PRED_TRAJ_VEL, PRED_TRAJ_ACC,REFERENCE_IMG,PED_IDS,
TRAIN_DATA_STR, TEST_DATA_STR, VAL_DATA_STR, MUN_POS_CSV, DATASETS_DIR, SUBDATASETS_NAMES
)
# Gets a testing batch of trajectories starting at the same frame (for visualization)
def get_testing_batch(test_data, config):
    """Pick a random test trajectory and build a batch of all trajectories
    sharing its last observed frame, plus that video frame as a background.

    Args:
        test_data: traj_dataset exposing a Frame_Ids array (n_trajs x n_frames).
        config: dataset config dict with "id_dataset" and "id_test" keys.

    Returns:
        (frame_id, traj_dataset of the co-occurring trajectories, background image)

    Raises:
        IOError: if the video cannot be opened or read up to frame_id.
    """
    # A random trajectory id
    randomtrajId = np.random.randint(len(test_data), size=1)[0]
    # Last observed frame id for that random trajectory (index 7 is assumed to
    # be the last observation step — TODO confirm against the dataset layout)
    frame_id = test_data.Frame_Ids[randomtrajId][7]
    # Indices of every trajectory whose last observed frame is the same one
    idx = np.where((test_data.Frame_Ids[:, 7] == frame_id))[0]
    ds_path = DATASETS_DIR[config["id_dataset"]]
    ds_names = SUBDATASETS_NAMES[config["id_dataset"]][config["id_test"]]
    # Get the video corresponding to the testing sub-dataset
    video_path = ds_path + ds_names + '/video.avi'
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Original code would hit a NameError on test_bckgd here; fail loudly.
        raise IOError("Cannot open video: " + video_path)
    try:
        # Read sequentially until the target frame; check ret so a truncated
        # video cannot silently yield a None/stale background.
        frame = 0
        while True:
            ret, test_bckgd = cap.read()
            if not ret:
                raise IOError("Could not read frame {} from {}".format(frame_id, video_path))
            if frame == frame_id:
                break
            frame = frame + 1
    finally:
        # Always release the capture handle
        cap.release()
    # Form the batch
    return frame_id, traj_dataset(*(test_data[idx])), test_bckgd
logging.basicConfig(format='%(levelname)s: %(message)s', level=20)
# Device to use later on
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name() raises on CPU-only machines even though DEVICE
# correctly falls back to cpu, so only query it when CUDA is available.
if DEVICE.type == "cuda":
    logging.info("Using device: "+torch.cuda.get_device_name(DEVICE))
else:
    logging.info("Using device: cpu")
# Load configuration file — pick exactly one of the variants below
#config = load_config("ethucy-conditional-past-social.yaml")
config = load_config("ethucy-conditional-past.yaml")
#config = load_config("ethucy-conditional-past-shift.yaml")
# Load configuation file (unconditional model)
#config = load_config("ethucy-unconditional.yaml")
# Get the data: leave-one-out split over the sub-datasets
# batched_train_data,batched_val_data,batched_test_data,homography,reference_image = get_dataset(config["dataset"])
training_data, validation_data, testing_data, test_homography = setup_loo_experiment(config["dataset"])
test_data = traj_dataset(testing_data[OBS_TRAJ_VEL], testing_data[PRED_TRAJ_VEL], testing_data[OBS_TRAJ], testing_data[PRED_TRAJ], Frame_Ids=testing_data[FRAMES_IDS])
# One batch holding every trajectory that ends at the same (random) frame
frame_id, batch, test_bckgd = get_testing_batch(test_data, config["dataset"])
batched_test_data = torch.utils.data.DataLoader(batch, batch_size=len(batch), collate_fn=collate_fn_padd)
# Instantiate the denoising model matching the configured sampler
if config["diffusion"]["sampler"] == "shift":
    denoiser = TrajectoryDenoiser_Shift(config["model"])
else:
    denoiser = TrajectoryDenoiser_CondEmbed(config["model"])
save_path = config["model"]["save_dir"]+(config["model"]["model_name"].format(config["model"]["condition"], config["train"]["epochs"]))
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
denoiser.load_state_dict(torch.load(save_path, map_location=DEVICE)['model'])
denoiser.to(DEVICE)
# Instantiate the diffusion model (DDPM doubles for both ddpm and ddim sampling)
timesteps = config["diffusion"]["timesteps"]
if config["diffusion"]["sampler"] != "shift":
    diffusionmodel = DDPM(timesteps=timesteps)
else:
    diffusionmodel = ResShift(timesteps=timesteps, kappa=config["diffusion"]["kappa"])
diffusionmodel.to(DEVICE)
# Number of trajectories processed/visualized at a time
step = config["diffusion"]["trajs_at_a_time"]
# Get a batch (the loader holds a single batch: batch_size == len(batch))
for batch in batched_test_data:
    # Unpacking order is assumed to match collate_fn_padd's output —
    # TODO confirm against utils.datasets_utils
    past_velocities, future_velocities, past_positions, future_positions, neighbors, __, neighbors_mask = batch
    # We process step trajectories at a time, starting at id_trajectory
    id_trajectory = 0
    past_positions = past_positions[id_trajectory:id_trajectory+step,:,:]
    future_positions = future_positions[id_trajectory:id_trajectory+step,:,:]
    if neighbors is not None:
        neighbors = neighbors[id_trajectory:id_trajectory+step,:,:,:]
        neighbors_mask = neighbors_mask[id_trajectory:id_trajectory+step,:,:]
        neighbors = neighbors.to(DEVICE)
        neighbors_mask = neighbors_mask.to(DEVICE)
    past_velocities = past_velocities[id_trajectory:id_trajectory+step,:,:]
    past_velocities = past_velocities.to(DEVICE)
    future_velocities = future_velocities[id_trajectory:id_trajectory+step,:,:]
    future_velocities = future_velocities.to(DEVICE)
    # Run the reverse (sampling) process with the configured sampler
    if config["diffusion"]["sampler"] == "ddpm":
        x = generate_ddpm(denoiser, diffusionmodel, past_velocities, neighbors,neighbors_mask,config, DEVICE).cpu() # Sample process
    elif config["diffusion"]["sampler"] == "ddim":
        # DDIM samples along a strided subset of the timesteps
        taus = np.arange(0,timesteps,config["diffusion"]["ddim_divider"])
        x = generate_ddim(denoiser,taus,diffusionmodel,past_velocities,neighbors,neighbors_mask,config,DEVICE).cpu()
        print(x.shape)
    elif config["diffusion"]["sampler"] == "shift":
        # ResShift starts from a rough guess: the last observed velocity
        # broadcast over the whole prediction horizon
        future_rough = torch.zeros_like(future_velocities)
        future_rough[:,:,:]= past_velocities[:,-2:-1,:]
        x = generate_shift(denoiser,backward_sampler=diffusionmodel,past=past_velocities,rough=future_rough,config=config,device=DEVICE).cpu()
    # Index (within the slice) of the trajectory to visualize
    idx = 0
    # Plot and show samples from the first element of the batch
    _, ax = plt.subplots(1, 1, figsize=(10, 9), facecolor='white')
    # Display the results column by column: one green curve per sample
    for j in range(config["diffusion"]["nsamples"]):
        # Integrate predicted velocities into positions; 0.4 is presumably the
        # sampling period in seconds — TODO confirm against the dataset rate
        traj = 0.4*np.cumsum(x[idx*config["diffusion"]["nsamples"]+j,:,:].permute(1,0),axis=0)
        # Shift relative displacements to start at the last observed position
        traj = traj+past_positions[idx,-1,:]
        # Segment linking the last observed point to the first predicted one
        ax.plot([past_positions[idx,-1,0],traj[0,0]],[past_positions[idx,-1,1],traj[0,1]],'g',alpha=0.4)
        ax.plot(traj[:,0],traj[:,1],'g',alpha=0.4)
    # Past and ground truth
    ax.plot(past_positions[idx,:,0],past_positions[idx,:,1],'b')
    ax.plot(future_positions[idx,:,0],future_positions[idx,:,1],'r')
    # Plot neighbors (shifted to the visualized trajectory's last position)
    if neighbors is not None:
        # NOTE(review): this rebinds `neighbors`; harmless while the loader
        # yields a single batch, but would break a second loop iteration
        neighbors = neighbors[idx].cpu()
        neighbors[:,:,:2] = neighbors[:,:,:2]+past_positions[idx,-1,:]
        for j in range(neighbors.shape[0]):
            for k in range(neighbors.shape[1]):
                if neighbors_mask[idx,j,k]:
                    # Later timesteps drawn more opaque (alpha grows with k)
                    ax.plot(neighbors[j,k,0],neighbors[j,k,1],'ro',alpha=0.1*(k+1))
    ax.grid(False)
    ax.set_aspect('equal', 'box')
    plt.suptitle("Forward Diffusion Process, Conditional Model", y=1.0)
    plt.axis("off")
    plt.show()