generate_sky_masks.py 4.08 KB
Newer Older
Clément Pinard's avatar
Clément Pinard committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn.functional as F
import imageio
from model.enet import ENet
from path import Path
from tqdm import tqdm
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import numpy as np

# Cityscapes class names. The position of a name in this list is the class
# index that is compared against the network's per-pixel prediction.
cityscapes_labels = [
    'unlabeled',
    'road', 'sidewalk',
    'building', 'wall', 'fence',
    'pole', 'traffic_light', 'traffic_sign',
    'vegetation', 'terrain',
    'sky',
    'person', 'rider',
    'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle',
]

# Class index of the 'sky' label, used to build the sky masks.
sky_index = cityscapes_labels.index('sky')


def prepare_network():
    """Build the ENet segmentation network and load its trained weights.

    The network is created with one output class per entry of
    ``cityscapes_labels``. Weights are read from the ``'state_dict'`` entry
    of the checkpoint file ``model/ENet`` (path relative to the current
    working directory).

    Returns:
        The network, switched to evaluation mode and moved to the CUDA device.
    """
    n_classes = len(cityscapes_labels)
    model = ENet(n_classes)
    weights = torch.load('model/ENet')['state_dict']
    model.load_state_dict(weights)
    model = model.eval()
    return model.cuda()


def erosion(width, mask):
    """Erode a batch of masks with a square window.

    An output pixel is 1 only when every pixel of the
    ``(2 * width + 1) x (2 * width + 1)`` window centred on it is 1.
    The input is padded with ones, so pixels outside the image never cause
    erosion by themselves.

    Args:
        width: half-size of the square window, in pixels (0 leaves a 0/1
            mask unchanged).
        mask: float tensor of shape [B, H, W]; expected to hold 0/1 values.

    Returns:
        Float tensor of shape [B, 1, H, W] holding 0/1 values, on the same
        device and with the same dtype as ``mask``.
    """
    size = 2 * width + 1
    # Count the set pixels of each window with an all-ones kernel and compare
    # against the exact window area. Sums of 0/1 values are exact in floating
    # point, whereas averaging with a 1 / size**2 kernel and testing `== 1`
    # depends on rounding (and on reduced-precision convolution modes).
    kernel = torch.ones(1, 1, size, size).to(mask)
    padded = torch.nn.functional.pad(mask.unsqueeze(1), [width] * 4, value=1)
    counts = torch.nn.functional.conv2d(padded, kernel)
    return (counts == size * size).float()


@torch.no_grad()
def extract_sky_mask(network, image_paths, mask_folder):
    """Segment a batch of images and write one mask file per image.

    Each mask is written to ``mask_folder`` as ``<image basename>.png`` and
    is 0 on pixels classified as sky (after a 1-pixel erosion of the sky
    region at network resolution) and 255 everywhere else.

    Args:
        network: segmentation network; called on a CUDA tensor of shape
            [B, C, H, W] and expected to return per-class scores along dim 1.
        image_paths: paths of the images to process. They are stacked into a
            single array, so they must all have the same shape.
        mask_folder: folder in which the mask files are written.
    """
    frames = [imageio.imread(p) for p in image_paths]
    batch = np.stack(frames)
    if batch.ndim == 3:
        # Single-channel images: replicate the channel three times.
        batch = np.stack([batch, batch, batch], axis=-1)
    _, full_h, full_w, _ = batch.shape

    # [B, H, W, C] scaled by 1/255, then reordered to [B, C, H, W].
    net_input = (torch.from_numpy(batch).float() / 255).permute(0, 3, 1, 2)

    # The network is fed a 512-pixel-wide image; height follows the aspect ratio.
    small_w = 512
    small_h = int(small_w * full_h / full_w)
    net_input = F.interpolate(net_input, size=(small_h, small_w), mode='area')

    scores = network(net_input.cuda())
    _, labels = torch.max(scores, 1)
    sky = (labels == sky_index).float()

    # Shrink the sky region by one pixel, then go back to full resolution.
    sky = erosion(1, sky)
    sky = F.interpolate(sky, size=(full_h, full_w), mode='nearest')

    # Invert: the stored mask is 1 (255 once scaled) where there is no sky.
    keep = 1 - sky.permute(0, 2, 3, 1).cpu().numpy()

    for path, keep_mask in zip(image_paths, keep):
        out_file = mask_folder / (path.basename() + '.png')
        imageio.imwrite(out_file, (keep_mask * 255).astype(np.uint8))

Clément Pinard's avatar
Clément Pinard committed
60

61
def process_folder(folder_to_process, colmap_img_root, mask_path, pic_ext, verbose=False, batchsize=8, **env):
    """Generate sky masks for every picture under a folder tree.

    The folder itself and all of its sub-folders are visited. For each one,
    a mask folder is created under ``mask_path`` at the same position the
    folder has relative to ``colmap_img_root``. Images that already have a
    mask file (``<image basename>.png``) are skipped; the others are
    processed in batches of ``batchsize``.

    Args:
        folder_to_process: root folder to scan for pictures.
        colmap_img_root: folder against which relative mask paths are computed.
        mask_path: root folder where masks are written.
        pic_ext: iterable of file name endings to look for (e.g. ``'jpg'``).
        verbose: when true, print the current folder and show a progress bar.
        batchsize: number of images sent to the network at once.
        **env: extra keyword arguments, accepted and ignored.
    """
    network = prepare_network()

    for folder in [folder_to_process, *folder_to_process.walkdirs()]:
        mask_folder = mask_path / colmap_img_root.relpathto(folder)
        mask_folder.makedirs_p()

        images = [f for ext in pic_ext for f in folder.files('*{}'.format(ext))]
        if not images:
            continue
        if verbose:
            print("Generating masks for images in {}".format(str(folder)))
            images = tqdm(images)

        batch = []
        for image_file in images:
            mask_file = mask_folder / (image_file.basename() + '.png')
            if mask_file.isfile():
                # Mask already generated in a previous run.
                continue
            batch.append(image_file)
            if len(batch) == batchsize:
                extract_sky_mask(network, batch, mask_folder)
                batch = []
        # Flush the last, possibly incomplete, batch.
        if batch:
            extract_sky_mask(network, batch, mask_folder)

    # Drop the model and release cached GPU memory.
    del network
    torch.cuda.empty_cache()
Clément Pinard's avatar
Clément Pinard committed
86
87
88
89
90


# Command line interface of the script; defaults are shown in --help.
parser = ArgumentParser(
    description='sky mask generator using ENet trained on cityscapes',
    formatter_class=ArgumentDefaultsHelpFormatter,
)

# Folder scanned for pictures.
parser.add_argument('--img_dir', type=Path, default=Path("workspace/Pictures"), metavar='DIR',
                    help='path to image folder root')
# Folder against which the relative mask paths are computed.
parser.add_argument('--colmap_img_root', type=Path, default=Path("workspace/Pictures"), metavar='DIR',
                    help='image_path you will give to colmap when extracting feature')
# Output root for the generated masks.
parser.add_argument('--mask_root', type=Path, default=Path("workspace/Masks"), metavar='DIR',
                    help='where to store the generated_masks')
# Number of images sent to the network at once.
parser.add_argument("--batch_size", "-b", default=8, type=int)
Clément Pinard's avatar
Clément Pinard committed
98
99
100
101

if __name__ == '__main__':
    args = parser.parse_args()

    # Drop a single trailing slash from the image root. Slicing a str
    # subclass yields a plain str unless the class overrides it, so the
    # result is wrapped in Path again to keep the Path methods that
    # process_folder() calls. endswith() also tolerates an empty string.
    if args.img_dir.endswith("/"):
        args.img_dir = Path(args.img_dir[:-1])

    args.mask_root.makedirs_p()
    file_exts = ['jpg', 'JPG']

    # process_folder() loads the network itself, so it is not loaded here.
    # The option is declared as --batch_size, hence args.batch_size.
    process_folder(args.img_dir, args.colmap_img_root, args.mask_root, file_exts,
                   verbose=True, batchsize=args.batch_size)