Commit 25db1cc5 authored by Clément Pinard's avatar Clément Pinard
Browse files

enhance timer for inference toolkit

parent 44ca0766
...@@ -12,6 +12,9 @@ class Timer: ...@@ -12,6 +12,9 @@ class Timer:
self._start_time = None self._start_time = None
self._elapsed_time = 0 self._elapsed_time = 0
def running(self):
return self._start_time is not None
def start(self): def start(self):
"""Start a new timer""" """Start a new timer"""
if self._start_time is not None: if self._start_time is not None:
...@@ -76,7 +79,7 @@ class inferenceSample(object): ...@@ -76,7 +79,7 @@ class inferenceSample(object):
self.poses = (np.linalg.inv(valid_poses_full[0]) @ valid_poses_full)[:, :3] self.poses = (np.linalg.inv(valid_poses_full[0]) @ valid_poses_full)[:, :3]
R = self.poses[:, :3, :3] R = self.poses[:, :3, :3]
self.rotation_angles = Rotation.from_matrix(R).magnitude() self.rotation_angles = Rotation.from_matrix(R).magnitude()
self.displacements = np.linalg.norm(self.poses[:, :, -1]) self.displacements = np.linalg.norm(self.poses[:, :, -1], axis=-1)
if (scene / "intrinsics.txt").isfile(): if (scene / "intrinsics.txt").isfile():
self.intrinsics = np.stack([np.genfromtxt(scene / "intrinsics.txt")]*max_shift) self.intrinsics = np.stack([np.genfromtxt(scene / "intrinsics.txt")]*max_shift)
...@@ -84,27 +87,43 @@ class inferenceSample(object): ...@@ -84,27 +87,43 @@ class inferenceSample(object):
intrinsics_files = [f.stripext() + "_intrinsics.txt" for f in self.valid_frames] intrinsics_files = [f.stripext() + "_intrinsics.txt" for f in self.valid_frames]
self.intrinsics = np.stack([np.genfromtxt(i) for i in intrinsics_files]) self.intrinsics = np.stack([np.genfromtxt(i) for i in intrinsics_files])
def timer_decorator(func, *args, **kwargs):
def wrapper(self, *args, **kwargs):
if self.timer.running():
self.timer.stop()
res = func(self, *args, **kwargs)
self.timer.start()
else:
res = func(self, *args, **kwargs)
return res
return wrapper
@timer_decorator
def get_frame(self, shift=0): def get_frame(self, shift=0):
self.timer.stop()
file = self.valid_frames[shift] file = self.valid_frames[shift]
img = imread(file) img = imread(file)
if self.frame_transform is not None: if self.frame_transform is not None:
img = self.frame_transform(img) img = self.frame_transform(img)
self.timer.start()
return img, self.intrinsics[shift], self.poses[shift] return img, self.intrinsics[shift], self.poses[shift]
@timer_decorator
def get_previous_frame(self, shift=1, displacement=None, max_rot=1): def get_previous_frame(self, shift=1, displacement=None, max_rot=1):
self.timer.stop()
if displacement is not None: if displacement is not None:
shift = max(1, (self.displacements - displacement).argmin()) shift = max(1, np.abs(self.displacements - displacement).argmin())
rot_valid = self.rotation_angles < max_rot rot_valid = self.rotation_angles < max_rot
assert sum(rot_valid[1:shift+1] > 0), "Rotation is alaways higher than {}".format(max_rot) assert sum(rot_valid[1:shift+1] > 0), "Rotation is always higher than {}".format(max_rot)
# Highest shift that has rotation below max_rot thresold # Highest shift that has rotation below max_rot thresold
final_shift = np.where(rot_valid[-1 - shift:])[0][-1] final_shift = np.where(rot_valid[-1 - shift:])[0][-1]
self.timer.start()
return self.get_frame(final_shift) return self.get_frame(final_shift)
@timer_decorator
def get_previous_frames(self, shifts=[1], displacements=None, max_rot=1):
if displacements is not None:
frames = zip(*[self.get_previous_frame(displacement=d, max_rot=max_rot) for d in displacements])
else:
frames = zip(*[self.get_previous_frame(shift=s, max_rot=max_rot) for s in shifts])
return frames
def inference_toolkit_example(): def inference_toolkit_example():
parser = ArgumentParser(description='Example usage of Inference toolkit', parser = ArgumentParser(description='Example usage of Inference toolkit',
...@@ -124,21 +143,22 @@ def inference_toolkit_example(): ...@@ -124,21 +143,22 @@ def inference_toolkit_example():
def my_model(frame, previous, pose): def my_model(frame, previous, pose):
# Mock up function that uses two frames and translation magnitude # Mock up function that uses two frames and translation magnitude
# return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1) # return frame[..., -1]
return np.exp(np.random.randn(frame.shape[0], frame.shape[1])) return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1)
# return np.exp(np.random.randn(frame.shape[0], frame.shape[1]))
engine = inferenceFramework(args.dataset_root, evaluation_list, lambda x: x.transpose(2, 0, 1).astype(np.float32)[None]/255) engine = inferenceFramework(args.dataset_root, evaluation_list, lambda x: x.transpose(2, 0, 1).astype(np.float32)[None]/255)
esimated_depth_maps = {} estimated_depth_maps = {}
mean_time = [] mean_time = []
for sample, image_path in zip(engine, tqdm(evaluation_list)): for sample, image_path in zip(engine, tqdm(evaluation_list)):
latest_frame, latest_intrinsics, _ = sample.get_frame() latest_frame, latest_intrinsics, _ = sample.get_frame()
previous_frame, previous_intrinsics, previous_pose = sample.get_previous_frame(displacement=0.3) previous_frame, previous_intrinsics, previous_pose = sample.get_previous_frame(displacement=0.3)
esimated_depth_maps[image_path] = (my_model(latest_frame, previous_frame, previous_pose)) estimated_depth_maps[image_path] = (my_model(latest_frame, previous_frame, previous_pose))
time_spent = engine.finish_frame(sample) time_spent = engine.finish_frame(sample)
mean_time.append(time_spent) mean_time.append(time_spent)
print("Mean time per sample : {:.2f}us".format(1e6 * sum(mean_time)/len(mean_time))) print("Mean time per sample : {:.2f}us".format(1e6 * sum(mean_time)/len(mean_time)))
np.savez(args.depth_output, **esimated_depth_maps) np.savez(args.depth_output, **estimated_depth_maps)
if __name__ == '__main__': if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment