Commit 21f00ec0 authored by Julien Lin's avatar Julien Lin
Browse files

implement eaf

parent 88acb958
import argparse
from matplotlib import pyplot as plt
import numpy as np
from snp import main as run
import multiprocessing
from joblib import Parallel, delayed
num_cores = multiprocessing.cpu_count()
def get_args_parse():
can = argparse.ArgumentParser()
can.add_argument("--nb-run", type=int, default=10, help="number of runs")
can.add_argument("--quality-threshold", type=float, default=600, help="Quality threshold")
can.add_argument("--not-parallel", type=bool, default=False, help="launch runs sequentialy")
return can
def main():
can = get_args_parse()
the = can.parse_args()
if the.not_parallel:
results = []
for i in range(the.nb_run):
print(f"{i}th run")
results.append(
run(["--solver", "num_random", "--iters", "2000", "--steady-delta", "500"])
)
else:
def f(i):
print(f"{i}th run")
return run(["--solver", "num_random", "--iters", "2000", "--steady-delta", "500"])
results = Parallel(n_jobs=num_cores)(delayed(f)(i) for i in range(the.nb_run))
fig = plt.figure()
ax3 = fig.add_subplot(111)
values = [(val, iter) for val, iter, _, _ in results]
values = sorted(values, key=lambda el: el[1])
t = [iter for _,iter in values]
values = np.array([ val for val, _ in values])
values = np.cumsum( values >= the.quality_threshold) / len(values)
ax3.step(t, values)
plt.show()
if __name__ == "__main__":
main()
......@@ -296,6 +296,32 @@ def run_algorithm(the, iters_func):
return val, sol, sensors
def plot_sol(the, history, sensors, shape):
fig = plt.figure()
if the.nb_sensors == 1 and the.domain_width <= 50:
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122)
f = make.func(
num.cover_sum,
domain_width=the.domain_width,
sensor_range=the.sensor_range,
dim=d * the.nb_sensors,
)
plot.surface(ax1, shape, f)
plot.path(ax1, shape, history)
else:
ax2 = fig.add_subplot(121)
domain = np.zeros(shape)
domain = pb.coverage(domain, sensors, the.sensor_range * the.domain_width)
domain = plot.highlight_sensors(domain, sensors)
ax2.imshow(domain)
plt.show()
def main(args=None):
can = get_args_parser()
......@@ -348,32 +374,6 @@ def main(args=None):
return val, len(history), sol, sensors
def plot_sol(the, history, sensors, shape):
fig = plt.figure()
if the.nb_sensors == 1 and the.domain_width <= 50:
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122)
f = make.func(
num.cover_sum,
domain_width=the.domain_width,
sensor_range=the.sensor_range,
dim=d * the.nb_sensors,
)
plot.surface(ax1, shape, f)
plot.path(ax1, shape, history)
else:
ax2 = fig.add_subplot(121)
domain = np.zeros(shape)
domain = pb.coverage(domain, sensors, the.sensor_range * the.domain_width)
domain = plot.highlight_sensors(domain, sensors)
ax2.imshow(domain)
plt.show()
if __name__ == "__main__":
main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment