diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0f0345e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +./output +./output/* +./data/* +./data \ No newline at end of file diff --git a/plot_deform.py b/plot_deform.py index 08e6c5a..0eb79dd 100644 --- a/plot_deform.py +++ b/plot_deform.py @@ -8,7 +8,7 @@ from absl import app from absl import flags -from matplotlib import animation +from matplotlib import animation, cm import matplotlib.pyplot as plt import math @@ -18,26 +18,29 @@ import torch -root_dir = pathlib.Path(__file__).parent.resolve() -output_dir = os.path.join(root_dir, 'output', 'deforming_plate') -all_subdirs = [os.path.join(output_dir, d) for d in os.listdir(output_dir) if - os.path.isdir(os.path.join(output_dir, d))] -latest_subdir = max(all_subdirs, key=os.path.getmtime) -rollout_path = os.path.join(latest_subdir, 'rollout', 'rollout.pkl') + +os.environ['KMP_DUPLICATE_LIB_OK']='TRUE' +FLAGS = flags.FLAGS + +flags.DEFINE_string('path_prefix', 'output/deforming_plate', 'root dir to the output files relative to this script.') +flags.DEFINE_string('rollout_path', None, 'specific rollout path to plot. Will plot all if set to None.') def main(unused_argv): - path_prefix = 'E:\\meshgraphnets\\output\\deforming_plate\\' - path_suffix = 'rollout\\rollout.pkl' - rollout_paths = ['Sat-Feb-19-15-44-13-2022'] - # path_prefix = '/home/kit/anthropomatik/sn2444/meshgraphnets/output/deforming_plate/' - # rollout_paths = ['Mon-Jan-31-05-04-38-2022/2', 'Mon-Jan-31-05-10-30-2022/2', 'Mon-Jan-31-05-20-38-2022/2', 'Mon-Jan-31-05-35-42-2022/2', 'Mon-Jan-31-05-39-05-2022/2', 'Mon-Jan-31-08-28-21-2022/2'] + path_prefix = FLAGS.path_prefix + path_suffix = 'rollout.pkl' + + if not FLAGS.rollout_path: + rollout_paths = [d.name for d in os.scandir(path_prefix) if d.is_dir()] + else: + rollout_paths = [FLAGS.rollout_path] + for rollout_path in rollout_paths: run_path = os.path.join(path_prefix, rollout_path) all_subdirs = [os.path.join(run_path, d) for d in os.listdir(run_path) if os.path.isdir(os.path.join(run_path, d))] save_path = max(all_subdirs, key=os.path.getmtime) - data_path = os.path.join(path_prefix, save_path, path_suffix) + data_path = os.path.join(save_path, path_suffix) print("Ploting run", save_path) with open(data_path, 'rb') as fp: rollout_data = pickle.load(fp) @@ -83,17 +86,17 @@ def animate(num): original_pos = torch.squeeze(rollout_data[traj]['gt_pos'], dim=0)[step].to('cpu') faces = torch.squeeze(rollout_data[traj]['faces'], dim=0)[step].to('cpu') - - ax_origin.plot_trisurf(original_pos[:, 0], original_pos[:, 1], faces, original_pos[:, 2], shade=True, alpha=0.3) - ax_pred.plot_trisurf(pos[:, 0], pos[:, 1], faces, pos[:, 2], shade=True, alpha=0.3) + ax_origin.plot_trisurf(original_pos[:, 0], original_pos[:, 1], faces, original_pos[:, 2], shade=False, alpha=0.3) + ax_pred.plot_trisurf(pos[:, 0], pos[:, 1], faces, pos[:, 2], shade=False, alpha=0.3) ax_origin.set_title('ORIGIN Trajectory %d Step %d' % (traj, step)) ax_pred.set_title('PRED Trajectory %d Step %d' % (traj, step)) return fig, anima = animation.FuncAnimation(fig, animate, frames=math.floor(num_frames * 10), interval=100) - # writervideo = animation.FFMpegWriter(fps=30) - # anima.save(os.path.join(save_path, 'ani.mp4'), writer=writervideo) + writervideo = animation.PillowWriter(fps=30) + anima.save(os.path.join(save_path, 'ani.gif'), writer=writervideo) + plt.show(block=True) diff --git a/requirements.txt b/requirements.txt index 97d419d..d56a441 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ +numpy==1.21.2 matplotlib==3.4.3 +torch==1.9.0 torch_geometric==2.0.1 torch_cluster==1.5.9 -torch==1.9.0 -numpy==1.21.2 absl_py==0.13.0 tfrecord==1.14.1 torch_scatter==2.0.8 diff --git a/run_model.py b/run_model.py index c7c50f3..aafef21 100644 --- a/run_model.py +++ b/run_model.py @@ -251,7 +251,8 @@ def process_trajectory(trajectory_data, params, model_type, dataset_dir, add_tar trajectory = {} # decode bytes into corresponding dtypes for key, value in trajectory_data.items(): - raw_data = value.numpy().tobytes() + # raw_data = value.numpy().tobytes() + raw_data = value[0][0] mature_data = np.frombuffer(raw_data, dtype=getattr(np, dtypes[key])) mature_data = torch.from_numpy(mature_data).to(device) reshaped_data = torch.reshape(mature_data, shapes[key])