import time
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

import numpy as np
from imageio import imread
from path import Path
from scipy.spatial.transform import Rotation
from tqdm import tqdm


class Timer:
    """Cumulative stopwatch: accumulates elapsed time over start/stop cycles."""

    def __init__(self):
        self._start_time = None   # perf_counter value at last start(); None while stopped
        self._elapsed_time = 0    # total seconds accumulated over completed runs

    def running(self):
        """Return True while the timer is started and not yet stopped."""
        return self._start_time is not None

    def start(self):
        """Start a new timing interval; no-op if already running."""
        if self._start_time is not None:
            return

        self._start_time = time.perf_counter()

    def stop(self):
        """Stop the timer, adding the current interval to the elapsed total; no-op if stopped."""
        if self._start_time is None:
            return

        self._elapsed_time += time.perf_counter() - self._start_time
        self._start_time = None

    def get_elapsed(self):
        """Return the total accumulated time in seconds (a still-running interval is excluded)."""
        return self._elapsed_time

    def reset(self):
        """Return the timer to its freshly-constructed state."""
        # Reset the attributes directly instead of re-invoking __init__.
        self._start_time = None
        self._elapsed_time = 0


class inferenceFramework(object):
    """Iterates over a list of test frames, timing the caller's work on each sample."""

    def __init__(self, root, test_files, min_depth=1e-3, max_depth=80, max_shift=50, frame_transform=None):
        """
        root: dataset root directory
        test_files: relative paths of the frames to evaluate
        min_depth, max_depth: depth clamp range (stored for downstream use)
        max_shift: maximum number of past frames available per sample
        frame_transform: optional callable applied to each loaded image
        """
        self.root = Path(root)
        self.test_files = test_files
        self.min_depth, self.max_depth = min_depth, max_depth
        self.max_shift = max_shift
        self.frame_transform = frame_transform

    def __getitem__(self, i):
        # Each sample owns its own timer, started the moment it is handed out
        # so that all caller-side computation is measured.
        timer = Timer()
        sample = inferenceSample(self.root, self.test_files[i], self.max_shift, timer, self.frame_transform)
        sample.timer.start()
        return sample

    def finish_frame(self, sample):
        """Stop the sample's timer and return the total measured time in seconds."""
        sample.timer.stop()
        return sample.timer.get_elapsed()

    def __len__(self):
        # BUG FIX: previously returned len(self.img_files), but no such attribute
        # exists — the constructor stores the list as self.test_files.
        return len(self.test_files)


class inferenceSample(object):
    """One evaluation sample: a target frame plus up to ``max_shift`` previous frames
    of the same scene, with poses relative to the target, rotation magnitudes,
    displacements and per-frame intrinsics.

    Frame accessors are wrapped by ``timer_decorator`` so the sample's timer is
    paused during toolkit-side disk I/O and only measures the caller's computation.
    """

    def __init__(self, root, file, max_shift, timer, frame_transform=None):
        """
        root: dataset root (a path.Path)
        file: path of the target frame, relative to ``root``
        max_shift: number of frames before the target to keep
        timer: Timer owned by the inference framework
        frame_transform: optional callable applied to each loaded image
        """
        self.root = root
        self.file = file
        self.frame_transform = frame_transform
        self.timer = timer
        full_filepath = self.root / file
        scene = full_filepath.parent
        scene_files = sorted(scene.files("*.jpg"))

        poses = np.genfromtxt(scene / "poses.txt").reshape((-1, 3, 4))
        sample_id = scene_files.index(full_filepath)
        # The target must have at least max_shift frames before it.
        assert(sample_id > max_shift)
        start_id = sample_id - max_shift
        # Newest first: valid_frames[0] is the target frame and
        # valid_frames[shift] is the frame ``shift`` steps in the past.
        self.valid_frames = scene_files[start_id:sample_id + 1][::-1]
        valid_poses = np.flipud(poses[start_id:sample_id + 1])
        # Promote the 3x4 pose matrices to 4x4 homogeneous form before inversion.
        last_line = np.broadcast_to(np.array([0, 0, 0, 1]), (valid_poses.shape[0], 1, 4))
        valid_poses_full = np.concatenate([valid_poses, last_line], axis=1)

        # Re-express every pose relative to the target frame's pose.
        self.poses = (np.linalg.inv(valid_poses_full[0]) @ valid_poses_full)[:, :3]
        R = self.poses[:, :3, :3]
        self.rotation_angles = Rotation.from_matrix(R).magnitude()
        self.displacements = np.linalg.norm(self.poses[:, :, -1], axis=-1)

        if (scene / "intrinsics.txt").isfile():
            # A single intrinsics matrix shared by the whole scene.
            # BUG FIX: replicate it max_shift + 1 times to match
            # len(self.valid_frames); the original stacked only max_shift
            # copies, so get_frame(max_shift) raised an IndexError.
            self.intrinsics = np.stack([np.genfromtxt(scene / "intrinsics.txt")] * (max_shift + 1))
        else:
            intrinsics_files = [f.stripext() + "_intrinsics.txt" for f in self.valid_frames]
            self.intrinsics = np.stack([np.genfromtxt(i) for i in intrinsics_files])

    def timer_decorator(func, *args, **kwargs):
        # Pause the sample's timer around toolkit-side work (disk reads) so only
        # the caller's own computation is measured.
        def wrapper(self, *args, **kwargs):
            if self.timer.running():
                self.timer.stop()
                res = func(self, *args, **kwargs)
                self.timer.start()
            else:
                res = func(self, *args, **kwargs)
            return res
        return wrapper

    @timer_decorator
    def get_frame(self, shift=0):
        """Return (image, intrinsics, pose) for the frame ``shift`` steps before the target."""
        file = self.valid_frames[shift]
        img = imread(file)
        if self.frame_transform is not None:
            img = self.frame_transform(img)
        return img, self.intrinsics[shift], self.poses[shift]

    @timer_decorator
    def get_previous_frame(self, shift=1, displacement=None, max_rot=1):
        """Return a previous frame: the one whose displacement is closest to
        ``displacement`` when given, otherwise the one at ``shift``, restricted
        to frames whose rotation w.r.t. the target is below ``max_rot`` (radians)."""
        if displacement is not None:
            shift = max(1, np.abs(self.displacements - displacement).argmin())
        rot_valid = self.rotation_angles < max_rot
        assert sum(rot_valid[1:shift+1] > 0), "Rotation is always higher than {}".format(max_rot)
        # Highest shift that has rotation below the max_rot threshold.
        # NOTE(review): this slice inspects the *last* shift+1 entries of
        # rot_valid, i.e. the oldest frames (arrays are newest-first);
        # rot_valid[:shift + 1] may have been intended — confirm before
        # changing. Behavior kept as-is.
        final_shift = np.where(rot_valid[-1 - shift:])[0][-1]
        return self.get_frame(final_shift)

    @timer_decorator
    def get_previous_frames(self, shifts=[1], displacements=None, max_rot=1):
        """Batched get_previous_frame: returns zipped (images, intrinsics, poses)."""
        if displacements is not None:
            frames = zip(*[self.get_previous_frame(displacement=d, max_rot=max_rot) for d in displacements])
        else:
            frames = zip(*[self.get_previous_frame(shift=s, max_rot=max_rot) for s in shifts])
        return frames

def inference_toolkit_example():
    """Example pipeline: run a mock depth model over an evaluation list and save
    the estimated depth maps, reporting the mean measured time per sample."""
    parser = ArgumentParser(description='Example usage of Inference toolkit',
                            formatter_class=ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dataset_root', metavar='DIR', type=Path)
    parser.add_argument('--depth_output', metavar='FILE', type=Path,
                        help='where to store the estimated depth maps, must be a npz file')
    parser.add_argument('--evaluation_list_path', metavar='PATH', type=Path,
                        help='File with list of images to test for depth evaluation')
    parser.add_argument('--scale-invariant', action='store_true',
                        help='If selected, will rescale depth map with ratio of medians')
    args = parser.parse_args()

    with open(args.evaluation_list_path) as f:
        # BUG FIX: strip the trailing newline explicitly; line[:-1] chopped the
        # last character of a file that does not end with '\n'.
        evaluation_list = [line.rstrip('\n') for line in f]

    def my_model(frame, previous, pose):
        # Mock-up model that uses two frames and the translation magnitude.
        return np.linalg.norm(pose[:, -1]) * np.linalg.norm(frame - previous, axis=-1)

    def frame_transform(img):
        # HWC uint8 image -> 1CHW float32 in [0, 1].
        return img.transpose(2, 0, 1).astype(np.float32)[None] / 255

    # BUG FIX: the transform used to be passed positionally and silently landed
    # in the `min_depth` parameter; it must be the `frame_transform` keyword.
    engine = inferenceFramework(args.dataset_root, evaluation_list, frame_transform=frame_transform)
    estimated_depth_maps = {}
    mean_time = []
    for sample, image_path in zip(engine, tqdm(evaluation_list)):
        latest_frame, latest_intrinsics, _ = sample.get_frame()
        previous_frame, previous_intrinsics, previous_pose = sample.get_previous_frame(displacement=0.3)
        estimated_depth_maps[image_path] = my_model(latest_frame, previous_frame, previous_pose)
        mean_time.append(engine.finish_frame(sample))

    print("Mean time per sample : {:.2f}us".format(1e6 * sum(mean_time) / len(mean_time)))
    np.savez(args.depth_output, **estimated_depth_maps)
# Script entry point: run the example pipeline when executed directly.
if __name__ == '__main__':
    inference_toolkit_example()