# inference_toolkit.py
import numpy as np
from path import Path
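# NOTE: Path comes from the path.py package (`pip install path`), not pathlib;
# it provides the .files(), .stripext() and .isfile() methods used below.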
from imageio import imread
import time
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from scipy.spatial.transform import Rotation
from tqdm import tqdm


class Timer:
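    """Cumulative stopwatch used to measure pure inference time;
    it can be paused and resumed so that data loading is not counted."""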
    def __init__(self):
        self._start_time = None
        self._elapsed_time = 0

    def running(self):
        return self._start_time is not None

    def start(self):
        """Start a new timer"""
        if self._start_time is not None:
            return

        self._start_time = time.perf_counter()

    def stop(self):
        """Stop the timer, and report the elapsed time"""
        if self._start_time is None:
            return

        self._elapsed_time += time.perf_counter() - self._start_time
        self._start_time = None

    def get_elapsed(self):
        return self._elapsed_time

    def reset(self):
        self.__init__()


class inferenceFramework(object):
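    """Iterates over the test files of a dataset, timing each inference and
    collecting the estimated depth maps so they can be saved with finalize()."""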
    def __init__(self, root, test_files, min_depth=1e-3, max_depth=80, max_shift=50, frame_transform=None):
        self.root = Path(root)
        self.test_files = test_files
        self.min_depth, self.max_depth = min_depth, max_depth
        self.max_shift = max_shift
        self.frame_transform = frame_transform
        self.inference_time = []
        self.estimated_depth_maps = {}

    def __getitem__(self, i):
        timer = Timer()
        self.i = i
        self.current_sample = inferenceSample(self.root, self.test_files[i], self.max_shift, timer, self.frame_transform)
        self.current_sample.timer.start()
        return self.current_sample

    def finish_frame(self, estimated_depth):
        self.current_sample.timer.stop()
        elapsed = self.current_sample.timer.get_elapsed()
        self.inference_time.append(elapsed)
        self.estimated_depth_maps[self.current_sample.file] = estimated_depth
        return elapsed

    def finalize(self, output_path=None):
        if output_path is not None:
            np.savez(output_path, **self.estimated_depth_maps)
        mean_inference_time = np.mean(self.inference_time)
        return mean_inference_time, self.estimated_depth_maps

    def __len__(self):
        return len(self.test_files)


class inferenceSample(object):
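    """One test frame together with up to max_shift previous frames of the same
    scene, their intrinsics and their poses expressed relative to the test frame.

    The scene folder is expected to contain the '*.jpg' frames, a 'poses.txt'
    file with one flattened 3x4 pose per frame, and either a global
    'intrinsics.txt' or one '<frame>_intrinsics.txt' file per frame."""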
    def __init__(self, root, file, max_shift, timer, frame_transform=None):
        self.root = root
        self.file = file
        self.frame_transform = frame_transform
        self.timer = timer
        full_filepath = self.root / file
        scene = full_filepath.parent
        scene_files = sorted(scene.files("*.jpg"))
        # Poses are stored as one flattened 3x4 matrix per line
        poses = np.genfromtxt(scene / "poses.txt").reshape((-1, 3, 4))
        sample_id = scene_files.index(full_filepath)
        assert sample_id > max_shift, "Not enough previous frames in scene for {}".format(file)
        start_id = sample_id - max_shift
        # Frames ordered from the test frame (shift 0) back in time (shift max_shift)
        self.valid_frames = scene_files[start_id:sample_id + 1][::-1]
        valid_poses = np.flipud(poses[start_id:sample_id + 1])
        # Convert the 3x4 poses to 4x4 homogeneous matrices
        last_line = np.broadcast_to(np.array([0, 0, 0, 1]), (valid_poses.shape[0], 1, 4))
        valid_poses_full = np.concatenate([valid_poses, last_line], axis=1)
        # Express every pose relative to the most recent frame (shift 0)
        self.poses = (np.linalg.inv(valid_poses_full[0]) @ valid_poses_full)[:, :3]
        R = self.poses[:, :3, :3]
        self.rotation_angles = Rotation.from_matrix(R).magnitude()
        self.displacements = np.linalg.norm(self.poses[:, :, -1], axis=-1)

        if (scene / "intrinsics.txt").isfile():
            self.intrinsics = np.stack([np.genfromtxt(scene / "intrinsics.txt")]*max_shift)
        else:
            intrinsics_files = [f.stripext() + "_intrinsics.txt" for f in self.valid_frames]
            self.intrinsics = np.stack([np.genfromtxt(i) for i in intrinsics_files])

    def timer_decorator(func):
        """Pause the inference timer while data is being loaded, so that
        I/O done by the toolkit is not counted as inference time."""
        def wrapper(self, *args, **kwargs):
            if self.timer.running():
                self.timer.stop()
                res = func(self, *args, **kwargs)
                self.timer.start()
            else:
                res = func(self, *args, **kwargs)
            return res
        return wrapper

    @timer_decorator
    def get_frame(self, shift=0):
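        """Return (image, intrinsics, pose) for the frame `shift` steps before the
        test frame (shift=0 is the test frame itself)."""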
        file = self.valid_frames[shift]
        img = imread(file)
        if self.frame_transform is not None:
            img = self.frame_transform(img)
        return img, self.intrinsics[shift], self.poses[shift]

    @timer_decorator
    def get_previous_frame(self, shift=1, displacement=None, max_rot=1):
        """Return the previous frame `shift` steps back (or the one whose displacement
        is closest to `displacement`), restricted to rotations below max_rot radians."""
        if displacement is not None:
            # Pick the shift whose displacement is closest to the requested one
            shift = max(1, np.abs(self.displacements - displacement).argmin())
        rot_valid = self.rotation_angles < max_rot
        assert rot_valid[1:shift + 1].any(), "Rotation is always higher than {}".format(max_rot)
        # Highest shift (up to the requested one) with rotation below the max_rot threshold
        final_shift = np.where(rot_valid[1:shift + 1])[0][-1] + 1
        return self.get_frame(final_shift)

    @timer_decorator
    def get_previous_frames(self, shifts=[1], displacements=None, max_rot=1):
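        """Same as get_previous_frame, but for several shifts or displacements at once;
        returns the frames, intrinsics and poses grouped together (zipped)."""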
        if displacements is not None:
            frames = zip(*[self.get_previous_frame(displacement=d, max_rot=max_rot) for d in displacements])
        else:
            frames = zip(*[self.get_previous_frame(shift=s, max_rot=max_rot) for s in shifts])
        return frames


def inference_toolkit_example():
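    """Example usage: run a mock depth model on every frame of an evaluation
    list, then save the estimated depth maps and report the mean inference time."""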
    parser = ArgumentParser(description='Example usage of Inference toolkit',
                            formatter_class=ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dataset_root', metavar='DIR', type=Path)
    parser.add_argument('--depth_output', metavar='FILE', type=Path,
                        help='where to store the estimated depth maps, must be a npz file')
    parser.add_argument('--evaluation_list_path', metavar='PATH', type=Path,
                        help='File with list of images to test for depth evaluation')
    parser.add_argument('--scale-invariant', action='store_true',
                        help='If selected, will rescale depth map with ratio of medians')
    args = parser.parse_args()

    with open(args.evaluation_list_path) as f:
        evaluation_list = f.read().splitlines()

    def my_model(frame, previous, pose):
        # Mock-up model that uses the two frames and the translation magnitude
        # return frame[..., -1]
        return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1)
        # return np.exp(np.random.randn(frame.shape[0], frame.shape[1]))

    engine = inferenceFramework(args.dataset_root, evaluation_list,
                                frame_transform=lambda x: x.transpose(2, 0, 1).astype(np.float32)[None] / 255)
    for sample in tqdm(engine):
        latest_frame, latest_intrinsics, _ = sample.get_frame()
        previous_frame, previous_intrinsics, previous_pose = sample.get_previous_frame(displacement=0.3)
        engine.finish_frame(my_model(latest_frame, previous_frame, previous_pose))

    mean_time, _ = engine.finalize(args.depth_output)
    print("Mean time per sample : {:.2f}us".format(1e6 * mean_time))


if __name__ == '__main__':
    inference_toolkit_example()