Commit df749e75 authored by Julien Lin's avatar Julien Lin
Browse files

refactor

parent 006ced51
......@@ -3,19 +3,18 @@ import math
from typing import Any
import numpy as np
import matplotlib.pyplot as plt
import argparse
from sho import make, algo, iters, plot, num, bit, pb
# Dimension of the search space.
d = 2
########################################################################
# Interface
########################################################################
if __name__ == "__main__":
import argparse
# Dimension of the search space.
d = 2
def get_args_parser():
can = argparse.ArgumentParser()
can.add_argument(
......@@ -147,41 +146,10 @@ if __name__ == "__main__":
help="Quality threshold. Used to plot the probability of a run being under the quality threshold.",
)
the = can.parse_args()
# Minimum checks.
assert 0 < the.nb_sensors
assert 0 < the.sensor_range <= math.sqrt(2)
assert 0 < the.domain_width
assert 0 < the.iters
# Do not forget the seed option,
# in case you would start "runs" in parallel.
np.random.seed(the.seed)
# Weird numpy way to ensure single line print of array.
np.set_printoptions(linewidth=np.inf) # type: ignore
# Common termination and checkpointing.
history: list[Any] = []
iters = make.iter(
iters.several,
agains=[
make.iter(iters.max, nb_it=the.iters),
make.iter(
iters.save, filename=the.solver + ".csv", fmt="{it} ; {val} ; {sol}\n"
),
make.iter(iters.log, fmt="\r{it} {val}"),
make.iter(iters.history, history=history),
make.iter(iters.target, target=the.target),
iters.steady(the.steady_delta, the.steady_epsilon),
],
)
return can
# Erase the previous file.
with open(the.solver + ".csv", "w") as fd:
fd.write("# {} {}\n".format(the.solver, the.domain_width))
def run_algorithm(the, iters_func):
val, sol, sensors = None, None, None
if the.solver == "num_greedy":
val, sol = algo.greedy(
......@@ -201,7 +169,7 @@ if __name__ == "__main__":
scale=the.variation_scale,
domain_width=the.domain_width,
),
iters,
iters_func,
)
sensors = num.to_sensors(sol)
......@@ -223,7 +191,7 @@ if __name__ == "__main__":
scale=the.variation_scale,
domain_width=the.domain_width,
),
iters,
iters_func,
)
sensors = bit.to_sensors(sol)
elif the.solver == "num_sim_anneal":
......@@ -248,12 +216,11 @@ if __name__ == "__main__":
make.temp(0.99),
make.proba(),
make.rand(),
iters,
iters_func,
)
sensors = num.to_sensors(sol)
elif the.solver == "bit_sim_anneal":
val, sol = algo.simulated_annealing(
make.func(
num.cover_sum,
......@@ -274,7 +241,7 @@ if __name__ == "__main__":
make.temp(0.99),
make.proba(),
make.rand(),
iters,
iters_func,
)
sensors = num.to_sensors(sol)
elif the.solver == "num_random":
......@@ -290,7 +257,7 @@ if __name__ == "__main__":
dim=d * the.nb_sensors,
scale=the.domain_width,
),
iters,
iters_func,
)
sensors = num.to_sensors(sol)
elif the.solver == "num_evolutionary":
......@@ -322,9 +289,51 @@ if __name__ == "__main__":
nb_next_generation=the.nb_selected,
strat=None,
),
iters,
iters_func,
)
sensors = num.to_sensors(sol)
return val,sensors
def main(args=None):
can = get_args_parser()
the = can.parse_args(args)
# Minimum checks.
assert 0 < the.nb_sensors
assert 0 < the.sensor_range <= math.sqrt(2)
assert 0 < the.domain_width
assert 0 < the.iters
# Do not forget the seed option,
# in case you would start "runs" in parallel.
np.random.seed(the.seed)
# Weird numpy way to ensure single line print of array.
np.set_printoptions(linewidth=np.inf) # type: ignore
# Common termination and checkpointing.
history: list[Any] = []
iters_func = make.iter(
iters.several,
agains=[
make.iter(iters.max, nb_it=the.iters),
make.iter(
iters.save, filename=the.solver + ".csv", fmt="{it} ; {val} ; {sol}\n"
),
make.iter(iters.log, fmt="\r{it} {val}"),
make.iter(iters.history, history=history),
make.iter(iters.target, target=the.target),
iters.steady(the.steady_delta, the.steady_epsilon),
],
)
# Erase the previous file.
with open(the.solver + ".csv", "w") as fd:
fd.write("# {} {}\n".format(the.solver, the.domain_width))
val, sensors = run_algorithm(the, iters_func)
# Fancy output.
print("\n{} : {}".format(val, sensors))
......@@ -361,3 +370,8 @@ if __name__ == "__main__":
ax2.imshow(domain)
plt.show()
if __name__ == "__main__":
main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment