Skip to content

Update NG vs PTC Tests #303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/tests/test-ptc-maps/test-quad-maps.mad
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ local function testQUADhe ()
local cfg = ref_cfg "quadhe" {
elm = [[
QUADRUPOLE, at=0.75, l=1.5, k1=${k1}, k1s=${k1s}, fringe=7,
k0=${k0}, e1=${e1}, e2=${e2}, h1=${h1}, h2=${h2}, bend_fringe=true
k0=${k0}, e1=${e1}, e2=${e2}, h1=${h1}, h2=${h2}, bend_fringe=true,
fringe_max=${fringe_max}
]],

tol = 500,
Expand All @@ -262,9 +263,10 @@ local function testQUADhe ()
e2 = {0, 0.3, 0.8},
h1 = {0, 0.04, 0.05},
h2 = {0, 0.04, 0.05},
fringe_max = {2, 5},
alist = tblcat(
ref_cfg.alist,
{"k1", "k1s", "k0", "e1", "e2", "h1", "h2"}
{"k1", "k1s", "k0", "e1", "e2", "h1", "h2", "fringe_max"}
),

plot_info = {
Expand Down
290 changes: 33 additions & 257 deletions tests/tests/test-ptc-maps/trackvsptc.mad
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
-- locals ---------------------------------------------------------------------o
local matrix, mtable, tostring, damap, vector, object, plot in MAD
local eps in MAD.constant
local is_number, is_string, is_vector, is_table in MAD.typeid
local openfile, tblcpy, val2keys, fileexists, tblcat, strinter in MAD.utility
local pi, abs, floor, log, max in math
local round in MAD.gmath
local mtable, damap, object in MAD
local openfile, fileexists, tblcat, strinter in MAD.utility
local max in math

package.path = package.path .. ";../tools/?.mad"
local get_diff, save_results, gen_cfg, in_dir, out_dir, plt_dir,
plot_trk_res, get_prev_res, prnt_results, add_trk_gen_cols,
show_res in require "track-tool"

local create_dif = require "madl_dbgmap".cmpmdump
local plot_id = 0
-- unique plot_id generator
local function newSID ()
plot_id = plot_id % 25 + 1
return plot_id
end

local dum = damap() -- get a dummy damap object for reading
local X0s = {{x=0 , px=0 , y=0 , py=0 , t=0 , pt=0 }, -- zero
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=0 , pt=0 }, -- 4D
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=0 , pt=2e-5}, -- 5D
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=1e-5, pt=2e-5}} -- 6D
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=0 , pt=0 }, -- 4D
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=0 , pt=2e-5}, -- 5D
{x=3e-3, px=-2e-4, y=-2e-3, py=3e-4, t=1e-5, pt=2e-5}} -- 6D

local coord_str = [[
x0 = ${x};
Expand Down Expand Up @@ -48,19 +44,12 @@ energy = ${energy};
${test_ctx}
]]

local in_dir = \s -> 'input/' ..(s or '')
local ref_file = openfile(in_dir("ref.madx"), "r")
local madx_ref = ref_file:read("*a")
ref_file:close()
local ref_file = openfile(in_dir("ref.mad"), "r")
local mad_ref, mad_file = ref_file:read("*a")
ref_file:close()

local out_dir = \s -> 'output/'..(s or '')
local plt_dir = \s -> out_dir('plots/'..(s or ''))

os.execute("mkdir -p "..out_dir()) -- Create output dir if it doesn't exist
os.execute("mkdir -p "..plt_dir()) -- Create image dir if it doesn't exist
-------------------------------------------------------------------------------o

-- Write MAD-X script to elmseq.seq and generate a MAD-NG script and return it-o
Expand Down Expand Up @@ -110,253 +99,42 @@ local function do_trck(cfg)

-- Grab PTC last map from out file and get diff with mflw[1]
local ptc_res = getlastmap(out_dir(cfg.name .. "_p.txt")):fromptc()
local dif = mflw[1]:dif(ptc_res)

-- Setup max matrix (coords x order)
local max_difs = matrix(6, cfg.order+1)

for i, c in ipairs({"x", "px", "y", "py", "t", "pt"}) do
-- Get max idx for each coordinate at each order
local _, max_idxs = dif[c]:maxbyord()

-- Create dummy vector to store max values
local max_vals = vector(#max_idxs)

-- Get max value for each order
max_idxs:map(\x-> x~=0 and abs(dif[c]:get(x))/eps or 0, max_vals)

-- Add max values to row of matrix
max_difs:setrow(i, max_vals)

-- Get and add coordinate max to results table
res[c.."_eps"] = max_vals:max()
end

-- Get and add order max to results table
for i = 1, max_difs.ncol do
res["order"..i-1.."_eps"] = max_difs:getcol(i):max()
end
return res
return get_diff(mflw[1], ptc_res, cfg.order)
end

local function run_cfg (cfg, results)
-- Run track for a single configuration
local res = do_trck(cfg)

-- Add results and configuration to table
results:addrow{
__cfg=tblcpy(cfg.cur_cfg),
__res=res,
}

if cfg.doprnt then -- If print mode is on, print results
io.write(cfg.cur_cfg.cfgid, "\t")
-- Print max dif for each order
for i = 1, cfg.order+1 do
local ord_max_dif = res["order"..i-1.."_eps"]
io.write(
ord_max_dif > cfg.tol and (">e+" .. floor(log(ord_max_dif, 10)))
or string.format("%d", ord_max_dif),
"\t"
save_results(cfg, res, results)
prnt_results(cfg, res)
if not cfg.dodbg then return end-- If debug mode is on, stop when max dif is greater than tol
for i = 1, cfg.order+1 do
if res["order"..i-1.."_eps"] > cfg.tol then
io.write("Max dif greater than tolerance, stopping...\n")
-- Run mad in debug mode and set the program to stop
create_seq(cfg {debug = 6})
openfile(in_dir(cfg.name.."_ref.mad"), "w"):write(mad_ref%cfg):close()
os.execute(
'../mad '.. in_dir(cfg.name.."_ref.mad")
..' >' .. out_dir(cfg.name .. "_n.txt")
)
end
-- Print current configuration
for i, attr in ipairs(cfg.alist) do
io.write(attr, "=", tostring(cfg.cur_cfg[attr]), ", ")
end
io.write("\n")
end

if cfg.dodbg then -- If debug mode is on, stop when max dif is greater than tol
for i = 1, cfg.order+1 do
if res["order"..i-1.."_eps"] > cfg.tol then
print("Max dif greater than tolerance, stopping...")
-- Run mad in debug mode and set the program to stop
create_seq(cfg {debug = 6})
openfile(in_dir(cfg.name.."_ref.mad"), "w"):write(mad_ref%cfg):close()
os.execute(
'../mad '.. in_dir(cfg.name.."_ref.mad")
..' >' .. out_dir(cfg.name .. "_n.txt")
)
create_dif({nam=out_dir(cfg.name)})
cfg.stop = true
end
end
end
end

-- From cfg object, create every configuration through recursion --------------o
local function gen_cfg(cfg, idx, gen_fun)
if cfg.stop then return end -- Stop if the stop flag is set
local k = cfg.alist[idx]
if not k then
cfg.cur_cfg.cfgid = cfg.cur_cfg.cfgid + 1
return gen_fun() -- Could be changed to any function
end
for i, v in ipairs(cfg[k]) do
cfg.cur_cfg[k] = v
gen_cfg(cfg, idx+1, gen_fun) -- Index required as this needs to stay constant during each call
end
end

-- Add the generator columns to the table ---------------------------------------o
local function add_gen_cols(results, cfg)
-- Create the result column names as a list
local ord_lst = {}
for i = 0, cfg.order do ord_lst[i+1] = "order"..i.."_eps" end
results.res_cols = tblcat(
ord_lst, {"x_eps", "px_eps", "y_eps", "py_eps", "t_eps", "pt_eps"}
)

-- Add the cfg columns to the mtable
results:addcol("cfgid", \ri, m -> m.__cfg[ri].cfgid)
for _, k in ipairs(cfg.alist) do
results:addcol(k, \ri, m =>
local v = m.__cfg[ri][k]
return is_table(v) and MAD.tostring(v) or v
end)
end

-- Add the result columns to the mtable
for _, k in ipairs(results.res_cols) do
results:addcol(k, \ri, m -> round(m.__res[ri][k], 2))
end
end
-------------------------------------------------------------------------------o

-- Output results of test ----------------------------------------------------o
local function get_lower_bnds(res, tol)
if is_string(tol) then
local bnds_file = mtable:read(tol)
assert(
#bnds_file == #res,
"The tolerance file must have the same number of rows as the configuration file"
)
return \o, ri->bnds_file[ri]["order"..o.."_eps"]
elseif is_number(tol) then
return \->tol
else
return \o->tol[o]
end
end

local function show_res(res, cfg, attr_cols, tol)
local tol = get_lower_bnds(res, tol)
local col_tbl = {}; for i = 0, res.max_order do col_tbl[i] = {} end
local dum_tbl = mtable(tblcpy(attr_cols))
dum_tbl.novector = true

io.write("For each order, the number of configurations that failed:\n")
for o = 0, res.max_order do
local err_tbl = dum_tbl:copy()
local max_err = 0
res:foreach(\r, ri =>
max_err = max(max_err, abs(res[ri]["order"..o.."_eps"]))
if res[ri]["order"..o.."_eps"] > tol(o, ri) then
for i, v in ipairs(attr_cols) do
err_tbl[v][ri] = cfg[ri][v]
end
end end)

-- Printing
io.write("\norder ", o, " (max error = ", max_err, ", tol = ", cfg.run_tol, "):\n")
for _, col_name in ipairs(attr_cols) do
if not (col_name == "cfgid") then
local _, key_cnt = val2keys(err_tbl:getcol(col_name))
io.write(col_name, "\t= ", tostring(key_cnt), "\n")
end
create_dif({nam=out_dir(cfg.name)})
cfg.stop = true
return
end
end
end

local function get_prev_res(test_name)
-- Read the previous results
local cfg = mtable:read(out_dir(test_name.."_cfg.tfs"))
local res = mtable:read(out_dir(test_name.."_res.tfs"))
return cfg, res
end
-------------------------------------------------------------------------------o

-- Plot the results -----------------------------------------------------------o
local to_text = \str-> str:gsub("{", ""):gsub("}", ""):gsub("%$", "")
local plot_template = plot {
legendpos = "right top",
prolog = "reset",
ylabel = "Error",
yrange = {1e-16, 1e1},
wsizex = 1080,
wsizey = 720,
exec = false
}

local colours = {
"red", "blue", "green", "orange", "purple", "brown", "pink", "grey", "black"
}

local function plot_res(res_cfg_tbl, cfg, cfg_tbl_)
local plot_info = cfg.plot_info or {}
local x_axis_attr = plot_info.x_axis_attr or "${cfgid}"
local cfg_plot = plot_template {
sid = newSID(),
title = cfg.name,
output = plot_info.filename and plt_dir(plot_info.filename) or 2,
xlabel = to_text(x_axis_attr),
exec = false
}
local cfg_tbl = cfg_tbl_ or res_cfg_tbl
local res_tbl = res_cfg_tbl
local n_points = #res_tbl
local series = plot_info.series or {"${cfgid} > 0"}
local n_series = #series
local x_data = table.new(n_series, 0)
local y_data = table.new(n_series, 0)
for i = 1, n_series do
x_data[i], y_data[i] = {}, {}
local cnt = 1
for j = 1, n_points do
if loadstring("return " .. series[i] % cfg_tbl[j])() then
x_data[i][cnt] = loadstring("return " .. x_axis_attr % cfg_tbl[j])()
local y_value = 0
for i = 1, res_tbl.max_order do
y_value = max(y_value, res_tbl[j]["order"..i.."_eps"])
end
y_value = y_value * eps
y_data[i][cnt] = y_value > 1e-16 and y_value or 1e-16
cnt = cnt + 1
end
end
end

cfg_plot.data, cfg_plot.datastyles, cfg_plot.x1y1, cfg_plot.legend = {}, {}, {}, {}
for i = 1, n_series do
cfg_plot.data["x"..i] = x_data[i]
cfg_plot.data["y"..i] = y_data[i]
cfg_plot.x1y1["x"..i] = "y"..i
cfg_plot.datastyles["y"..i] = {
style = "points",
pointtype = i,
color = colours[i],
}
cfg_plot.legend["y"..i] = to_text(series[i])
end

plot_info.plotcfg =
[[
set logscale y
]] .. (plot_info.plotcfg or "")

cfg_plot (plot_info)
end

-- Run test -------------------------------------------------------------------o
local function run_test(cfg)
-- If the user does not want to run the test,
-- just show results from previous run
if not cfg.dorun then
local cfg_tbl, res_tbl = get_prev_res(cfg.name)
local cfg_tbl, res_tbl = get_prev_res(cfg.name, out_dir)
res_tbl.max_order = cfg.order
if cfg.doprnt then show_res(res_tbl, cfg_tbl, cfg_tbl:colnames(), cfg.tol) end
if cfg.doplot then plot_res(res_tbl, cfg, cfg_tbl) end
if cfg.doplot then plot_trk_res(res_tbl, cfg, plt_dir, cfg_tbl) end
return
end

Expand All @@ -378,18 +156,16 @@ local function run_test(cfg)
}

if cfg.doprnt then
io.write("Running ", cfg.name, " (tol = ", cfg.tol, ")\n")
-- Print the header
io.write("cfgid\t")
for i = 0, cfg.order do io.write("order "..i.."\t") end
io.write("Running ", cfg.name, " (tol = ", cfg.tol, ")\n", "cfgid\t")
for i = 0, cfg.order do io.write("order ", i, "\t") end
io.write("\n")
end

-- Fill the mtable with the cfg and results
gen_cfg(cfg, 1, \-> run_cfg(cfg, results))

-- Add the generator columns to the results table
add_gen_cols(results, cfg)
add_trk_gen_cols(results, cfg)

-- Decide whether to save the results
local dosave = cfg.dosave or not (
Expand Down Expand Up @@ -419,7 +195,7 @@ local function run_test(cfg)

-- Plot the results
if cfg.doplot then
plot_res(results, cfg)
plot_trk_res(results, cfg, plt_dir)
end

-- Cleanup excess files if the program is not stopped mid-test
Expand Down
Loading