# eaf.py (2.12 KB) — recovered from a GitLab blame view; last committed by Julien Lin.
import argparse
from matplotlib import pyplot as plt
import numpy as np
from snp import main as run
import multiprocessing
from joblib import Parallel, delayed

# Number of joblib workers used when solver runs are launched in parallel:
# one per CPU reported by multiprocessing.cpu_count().
num_cores: int = multiprocessing.cpu_count()

def _str2bool(text):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    ``type=bool`` is a trap in argparse: ``bool("False")`` is ``True`` because
    every non-empty string is truthy. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a recognised boolean.
    """
    lowered = text.strip().lower()
    if lowered in ("1", "true", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def get_args_parse():
    """Build the command-line parser for the solver-run plotting script.

    Options:
        --nb-run: number of solver runs (default 10).
        --quality-threshold: one or more quality thresholds; always parsed as
            a list of floats (default ``[660.0]``).
        --not-parallel: run sequentially instead of through joblib. Accepts a
            bare flag (``--not-parallel``) or an explicit boolean value
            (``--not-parallel False``). Default False.
        --solver: solver name forwarded to the run function (default "num_random").

    Returns:
        argparse.ArgumentParser: the configured, not-yet-invoked parser.
    """
    can = argparse.ArgumentParser()

    can.add_argument("--nb-run", type=int, default=10, help="number of runs")

    # With nargs="+", argparse returns a list for user-supplied values but
    # leaves ``default`` untouched, so the default must itself be a list for
    # callers that iterate over the thresholds.
    can.add_argument(
        "--quality-threshold",
        type=float,
        default=[660.0],
        nargs="+",
        help="Quality threshold",
    )

    # nargs="?" + const=True: a bare flag means True, while an explicit value
    # such as "False" is parsed by _str2bool rather than by bool().
    can.add_argument(
        "--not-parallel",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="launch runs sequentialy",
    )

    can.add_argument("--solver", type=str, default="num_random", help="Solver to use.")

    return can


def _run_all(args, nb_run, parallel):
    """Call the solver entry point ``run(args)`` ``nb_run`` times and collect the results.

    When ``parallel`` is true, the calls are dispatched through joblib with
    ``num_cores`` workers; otherwise they run one after another.
    """
    if not parallel:
        return [run(args) for _ in range(nb_run)]
    return Parallel(n_jobs=num_cores)(delayed(run)(args) for _ in range(nb_run))


def _attainment_curves(results, thresholds):
    """Compute, for each threshold, the cumulative fraction of runs meeting it.

    ``results`` are the 4-tuples returned by ``run``, and only the first two
    fields are used. They are presumably (quality value, iteration count);
    TODO confirm against snp.main. Runs are ordered by that iteration count
    (stable sort).

    Returns:
        (iterations, curves): ``iterations`` is the sorted list of iteration
        counts. ``curves`` holds one array per threshold. Entry k of an array
        is the number of runs among the first k+1 (in iteration order) whose
        value is >= the threshold, divided by the total number of runs.
    """
    pairs = sorted(((val, it) for val, it, _, _ in results), key=lambda el: el[1])
    iterations = [it for _, it in pairs]
    qualities = np.array([val for val, _ in pairs])
    n_runs = len(qualities)
    curves = [np.cumsum(qualities >= threshold) / n_runs for threshold in thresholds]
    return iterations, curves


def main(eaf_args=None):
    """Run the selected solver repeatedly and plot one step curve per threshold.

    Args:
        eaf_args: argument list for ``get_args_parse()``; ``None`` means
            argparse reads ``sys.argv``.

    Each run invokes ``run`` with a fixed argument list (``--iters 2000``,
    ``--steady-delta 500``). Results are turned into curves by
    ``_attainment_curves`` and shown with matplotlib (``plt.show()`` blocks
    until the window is closed).
    """
    the = get_args_parse().parse_args(eaf_args)

    thresholds = the.quality_threshold
    if not isinstance(thresholds, (list, tuple)):
        # argparse does not wrap a scalar ``default`` even with nargs="+",
        # so normalise to a list before iterating.
        thresholds = [thresholds]

    args = ["--solver", the.solver, "--iters", "2000", "--steady-delta", "500"]
    results = _run_all(args, the.nb_run, parallel=not the.not_parallel)
    iterations, curves = _attainment_curves(results, thresholds)

    plt.figure()
    plt.title(f"{the.solver}, threshold {the.quality_threshold}")
    for curve in curves:
        plt.step(iterations, curve)
    plt.legend([f"threshold : {threshold}" for threshold in thresholds])
    plt.show()


if __name__ == "__main__":
    # Script entry point: main() is called with eaf_args=None, so argparse
    # parses the options from sys.argv.
    main()