Commit 25db1cc5 authored by Clément Pinard's avatar Clément Pinard
Browse files

enhance timer for inference toolkit

parent 44ca0766
......@@ -12,6 +12,9 @@ class Timer:
self._start_time = None
self._elapsed_time = 0
def running(self):
return self._start_time is not None
def start(self):
"""Start a new timer"""
if self._start_time is not None:
......@@ -76,7 +79,7 @@ class inferenceSample(object):
self.poses = (np.linalg.inv(valid_poses_full[0]) @ valid_poses_full)[:, :3]
R = self.poses[:, :3, :3]
self.rotation_angles = Rotation.from_matrix(R).magnitude()
self.displacements = np.linalg.norm(self.poses[:, :, -1])
self.displacements = np.linalg.norm(self.poses[:, :, -1], axis=-1)
if (scene / "intrinsics.txt").isfile():
self.intrinsics = np.stack([np.genfromtxt(scene / "intrinsics.txt")]*max_shift)
......@@ -84,27 +87,43 @@ class inferenceSample(object):
intrinsics_files = [f.stripext() + "_intrinsics.txt" for f in self.valid_frames]
self.intrinsics = np.stack([np.genfromtxt(i) for i in intrinsics_files])
def timer_decorator(func, *args, **kwargs):
def wrapper(self, *args, **kwargs):
if self.timer.running():
self.timer.stop()
res = func(self, *args, **kwargs)
self.timer.start()
else:
res = func(self, *args, **kwargs)
return res
return wrapper
@timer_decorator
def get_frame(self, shift=0):
self.timer.stop()
file = self.valid_frames[shift]
img = imread(file)
if self.frame_transform is not None:
img = self.frame_transform(img)
self.timer.start()
return img, self.intrinsics[shift], self.poses[shift]
@timer_decorator
def get_previous_frame(self, shift=1, displacement=None, max_rot=1):
self.timer.stop()
if displacement is not None:
shift = max(1, (self.displacements - displacement).argmin())
shift = max(1, np.abs(self.displacements - displacement).argmin())
rot_valid = self.rotation_angles < max_rot
assert sum(rot_valid[1:shift+1] > 0), "Rotation is alaways higher than {}".format(max_rot)
assert sum(rot_valid[1:shift+1] > 0), "Rotation is always higher than {}".format(max_rot)
# Highest shift that has rotation below max_rot thresold
final_shift = np.where(rot_valid[-1 - shift:])[0][-1]
self.timer.start()
return self.get_frame(final_shift)
@timer_decorator
def get_previous_frames(self, shifts=[1], displacements=None, max_rot=1):
if displacements is not None:
frames = zip(*[self.get_previous_frame(displacement=d, max_rot=max_rot) for d in displacements])
else:
frames = zip(*[self.get_previous_frame(shift=s, max_rot=max_rot) for s in shifts])
return frames
def inference_toolkit_example():
parser = ArgumentParser(description='Example usage of Inference toolkit',
......@@ -124,21 +143,22 @@ def inference_toolkit_example():
def my_model(frame, previous, pose):
# Mock up function that uses two frames and translation magnitude
# return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1)
return np.exp(np.random.randn(frame.shape[0], frame.shape[1]))
# return frame[..., -1]
return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1)
# return np.exp(np.random.randn(frame.shape[0], frame.shape[1]))
engine = inferenceFramework(args.dataset_root, evaluation_list, lambda x: x.transpose(2, 0, 1).astype(np.float32)[None]/255)
esimated_depth_maps = {}
estimated_depth_maps = {}
mean_time = []
for sample, image_path in zip(engine, tqdm(evaluation_list)):
latest_frame, latest_intrinsics, _ = sample.get_frame()
previous_frame, previous_intrinsics, previous_pose = sample.get_previous_frame(displacement=0.3)
esimated_depth_maps[image_path] = (my_model(latest_frame, previous_frame, previous_pose))
estimated_depth_maps[image_path] = (my_model(latest_frame, previous_frame, previous_pose))
time_spent = engine.finish_frame(sample)
mean_time.append(time_spent)
print("Mean time per sample : {:.2f}us".format(1e6 * sum(mean_time)/len(mean_time)))
np.savez(args.depth_output, **esimated_depth_maps)
np.savez(args.depth_output, **estimated_depth_maps)
if __name__ == '__main__':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment