-- Source: dfhack/library/lua/utils.lua
-- (Page metadata carried over from the repository web view: 780 lines, 20 KiB, Lua.)

local _ENV = mkmodule('utils')
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
local df = df
-- (web-view blame timestamp: 2021-01-13 00:42:53 -07:00)
local getopt = require('3rdparty.alt_getopt')
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
-- Three-way comparator.
-- Returns -1 when a < b, 1 when a > b, and 0 when neither holds.
-- The operands only need to support the '<' and '>' operators.
function compare(a,b)
    if a < b then
        return -1
    end
    if a > b then
        return 1
    end
    return 0
end
-- Three-way comparator for names: the empty string sorts after every
-- non-empty value, and two empty strings compare equal.
-- Non-empty operands are handed to compare().
function compare_name(a,b)
    local a_empty = (a == '')
    local b_empty = (b == '')
    if a_empty and b_empty then
        return 0
    end
    if a_empty then
        return 1
    end
    if b_empty then
        return -1
    end
    return compare(a,b)
end
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
-- Make a field comparator.
--   field: key whose values are compared; when it is nil/false the plain
--          comparator is returned unchanged.
--   cmp:   three-way comparator; defaults to compare.
-- Returns a function(a,b) that evaluates cmp(a[field], b[field]).
function compare_field(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(a,b)
        return base(a[field], b[field])
    end
end
-- Make a comparator of an object's field against a bare key.
--   field: key read from the first operand; when it is nil/false the plain
--          comparator is returned unchanged.
--   cmp:   three-way comparator; defaults to compare.
-- Returns a function(a,b) that evaluates cmp(a[field], b).
function compare_field_key(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(a,b)
        return base(a[field], b)
    end
end
-- Returns true when obj is a DF reference (df.isvalid(obj) == 'ref') whose
-- _kind is 'container', and false otherwise.
function is_container(obj)
    if df.isvalid(obj) ~= 'ref' then
        return false
    end
    return obj._kind == 'container'
end
-- Make the sequence of consecutive integers istart..iend.
-- The result is a 1-based Lua table: element 1 is istart, element 2 is
-- istart+1, and so on up to iend. It is empty when iend < istart.
function make_index_sequence(istart,iend)
    local index = {}
    for i=istart,iend do
        -- shift so that the first value always lands at index 1
        index[i-istart+1] = i
    end
    return index
end
--[[
    Compute the order in which the items of data would appear when sorted
    according to ordering; data itself is not modified.

    The number of items is data.n when present, otherwise #data.

    Each ordering spec is a table with possible fields:

    * key = function(value)
        Computes the comparison key from a data value. Not called on nil.
    * key_table = function(data)
        Computes a whole key table from the data table in one go.
    * compare = function(a,b)
        Three-way comparison function. Defaults to compare.
        Called on non-nil keys only; nil sorts last.
    * nil_first
        If true, nil keys are sorted first instead of last.
    * reverse
        If true, non-nil keys are sorted in descending order.

    Specs are consulted in order; a later spec only matters when all the
    earlier ones consider two items equal.

    Returns a table of integer indices into data.
--]]
function make_sort_order(data,ordering)
    local size = data.n or #data

    -- One key table and one comparator per ordering spec
    local keys = {}
    local cmps = {}
    for i=1,#ordering do
        local spec = ordering[i]
        local spec_keys
        if spec.key_table then
            spec_keys = spec.key_table(data)
        elseif spec.key then
            local key_fn = spec.key
            spec_keys = {}
            for j=1,size do
                local value = data[j]
                if value ~= nil then
                    spec_keys[j] = key_fn(value)
                end
            end
        else
            -- no key extraction requested: the values are their own keys
            spec_keys = data
        end
        keys[i] = spec_keys
        cmps[i] = spec.compare or compare
    end

    -- Sort a table of indices rather than the data itself
    local index = make_index_sequence(1,size)
    table.sort(index, function(ia,ib)
        for i=1,#keys do
            local ka = keys[i][ia]
            local kb = keys[i][ib]
            if ka ~= nil and kb ~= nil then
                local cmpv = cmps[i](ka,kb)
                if ordering[i].reverse then
                    cmpv = -cmpv
                end
                if cmpv < 0 then
                    return true
                end
                if cmpv > 0 then
                    return false
                end
            elseif ka == nil and kb ~= nil then
                -- only a is nil: it goes first exactly when nil_first is set
                return ordering[i].nil_first
            elseif ka ~= nil then
                -- only b is nil
                return not ordering[i].nil_first
            end
            -- equal (or both nil) under this spec: try the next one
        end
        return ia < ib -- ties keep their original order, making the sort stable
    end)
    return index
end
--[[
    Iterate a 'list' structure, e.g. df.global.world.job_list

    Usage: for link,item in listpairs(list) do ... end

    The list head itself is not visited; iteration starts at list.next and
    follows the .next fields until one is nil/false. Each step yields the
    link and its .item field.
--]]
local function next_df_list(_,link)
    local following = link.next
    if not following then
        return
    end
    return following, following.item
end

function listpairs(list)
    -- generic-for triple: iterator, unused state, initial control value
    local iterator, state, head = next_df_list, nil, list
    return iterator, state, head
end
--[[
    Recursively assign data into a table.

    * If tgt is a DF reference (df.isvalid(tgt) == 'ref'), the work is
      delegated to df.assign(tgt, src).
    * If tgt is a Lua table, every pair of src is merged into it: table
      values are merged recursively into the existing entry (an empty table
      is created first when tgt has none), all other values overwrite.
    * Any other target raises an error.

    Returns tgt.
--]]
function assign(tgt,src)
    if df.isvalid(tgt) == 'ref' then
        df.assign(tgt, src)
    elseif type(tgt) == 'table' then
        for k,v in pairs(src) do
            if type(v) == 'table' then
                local cv = tgt[k]
                if cv == nil then
                    cv = {}
                    tgt[k] = cv
                end
                assign(cv, v)
            else
                tgt[k] = v
            end
        end
    else
        error('Invalid assign target type: '..tostring(tgt))
    end
    return tgt
end
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
-- Helper for clone(): pick the value to store for field k of obj.
--   * a nil value is represented by NULL;
--   * when deep is set and obj:_field(k) compares equal to v, v is itself
--     cloned (recursively);
--   * otherwise v is stored as is.
local function copy_field(obj,k,v,deep)
    if v == nil then
        return NULL
    end
    if deep and obj:_field(k) == v then
        return clone(v,deep)
    end
    return v
end
-- Copy the object as lua data structures.
--   * Lua table: deep copy through assign({}, obj) when deep is set,
--     otherwise copyall(obj).
--   * DF reference (df.isvalid(obj) == 'ref'), by obj._kind:
--       'primitive' -> obj.value
--       'bitfield'  -> table of the pairs(obj) entries
--       'container' -> table of the ipairs(obj) entries, keys shifted by +1
--       anything else (struct) -> table of the pairs(obj) entries
--     Container and struct entries go through copy_field.
--   * Anything else is returned unchanged.
function clone(obj,deep)
    if type(obj) == 'table' then
        if deep then
            return assign({},obj)
        end
        return copyall(obj)
    end
    if df.isvalid(obj) ~= 'ref' then
        return obj
    end

    local kind = obj._kind
    if kind == 'primitive' then
        return obj.value
    end

    local rv = {}
    if kind == 'bitfield' then
        for k,v in pairs(obj) do
            rv[k] = v
        end
    elseif kind == 'container' then
        for k,v in ipairs(obj) do
            rv[k+1] = copy_field(obj,k,v,deep)
        end
    else -- struct
        for k,v in pairs(obj) do
            rv[k] = copy_field(obj,k,v,deep)
        end
    end
    return rv
end
-- Helper for clone_with_default(): find the default that applies to key.
-- A non-table default applies to every key. For a table, default[key] is
-- tried first, then default._default, then base.
local function get_default(default,key,base)
    if type(default) ~= 'table' then
        return default
    end
    local dv = default[key]
    if dv == nil then
        dv = default._default
    end
    if dv == nil then
        dv = base
    end
    return dv
end
-- Copy the object as lua data structures, skipping values matching defaults.
--   obj:     Lua value or DF reference to copy.
--   default: default value, or table of per-key defaults (looked up with
--            get_default); when nil the result is nil.
--   force:   when set, a scalar equal to its default is still returned, and
--            an aggregate with no surviving entries yields {} instead of nil.
-- Returns the copy, NULL for a nil scalar whose default is not nil, or nil
-- when nothing differs from the defaults.
function clone_with_default(obj,default,force)
    if default == nil then
        return nil
    end

    local rv = nil
    -- store v under k, creating the result table on first use
    local function store(k,v)
        if v == nil then
            return
        end
        rv = rv or {}
        rv[k] = v
    end

    if type(obj) == 'table' then
        for k,v in pairs(obj) do
            store(k, clone_with_default(v, get_default(default,k)))
        end
    elseif df.isvalid(obj) == 'ref' then
        local kind = obj._kind
        if kind == 'primitive' then
            return clone_with_default(obj.value,default,force)
        elseif kind == 'container' then
            -- every element shares the same default and is always kept
            for k,v in ipairs(obj) do
                store(k+1, clone_with_default(v, default, true))
            end
        else
            -- bitfield flags fall back to false; struct fields have no fallback
            local base = nil
            if kind == 'bitfield' then
                base = false
            end
            for k,v in pairs(obj) do
                store(k, clone_with_default(v, get_default(default,k,base)))
            end
        end
    elseif obj == default and not force then
        return nil
    elseif obj == nil then
        return NULL
    else
        return obj
    end

    if force and rv == nil then
        rv = {}
    end
    return rv
end
-- Parse an integer value into a bitfield table.
--   value:    the integer to decode.
--   type_ref: iterated with ipairs to obtain (bit position, flag name) pairs.
-- Returns nil when value is 0; otherwise a table mapping the name of every
-- set bit to true.
function parse_bitfield_int(value, type_ref)
    if value == 0 then
        return nil
    end
    local flags = {}
    for pos,name in ipairs(type_ref) do
        local is_set = bit32.extract(value, pos) ~= 0
        if is_set then
            flags[name] = true
        end
    end
    return flags
end
-- List the enabled flag names in the bitfield table.
--   bitfield: table (or other pairs-iterable) of name -> flag; may be nil.
--   list:     optional table to append to; a new one is created if omitted.
-- Returns the list with the name of every truthy flag appended.
function list_bitfield_flags(bitfield, list)
    local names = list or {}
    if not bitfield then
        return names
    end
    for name,enabled in pairs(bitfield) do
        if enabled then
            table.insert(names, name)
        end
    end
    return names
end
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
-- Sort a vector or lua table in place.
--   vector: Lua table, or DF container (anything else DF-valid is an error).
--   field:  optional key to compare by (see compare_field).
--   cmp:    optional three-way comparator.
-- A DF container is cloned into a Lua table, sorted there, and written back
-- with vector:assign. Returns vector.
function sort_vector(vector,field,cmp)
    local fcmp = compare_field(field,cmp)
    -- table.sort wants a strict "less than" predicate
    local function before(a,b)
        return fcmp(a,b) < 0
    end

    if not df.isvalid(vector) then
        table.sort(vector, before)
        return vector
    end

    if vector._kind ~= 'container' then
        error('Container expected: '..tostring(vector))
    end
    local items = clone(vector, true)
    table.sort(items, before)
    vector:assign(items)
    return vector
end
-- Linear search.
--   vector: DF object (scanned over indices 0..#vector-1) or Lua table
--           (scanned over 1..#vector).
--   key:    value to look for.
--   field:  when given, each element's [field] is compared with key instead
--           of the element itself.
-- Returns index, element for the first match, or nil when there is none.
function linear_index(vector,key,field)
    local first,last = 1,#vector
    if df.isvalid(vector) then
        first,last = 0,#vector-1
    end
    for i=first,last do
        local obj = vector[i]
        local candidate = obj
        if field then
            candidate = obj[field]
        end
        if candidate == key then
            return i, obj
        end
    end
    return nil
end
-- (web-view blame timestamp: 2012-05-08 02:55:06 -06:00)
-- Binary search in a vector or lua table.
--   vector:  sorted DF object or Lua table.
--   key:     value searched for.
--   field:   optional; elements are compared through element[field].
--   cmp:     optional three-way comparator (see compare_field_key).
--   min,max: optional exclusive bounds of the search window. Unless both are
--            given, the whole vector is searched: (-1, #vector) for DF
--            objects, (0, #vector+1) for Lua tables.
-- Returns item, true, index when found; otherwise nil, false, and the upper
-- bound at which the search stopped (insert_sorted inserts at that position).
function binsearch(vector,key,field,cmp,min,max)
    if not min or not max then
        if df.isvalid(vector) then
            min,max = -1,#vector
        else
            min,max = 0,#vector+1
        end
    end
    local fcmp = compare_field_key(field,cmp)
    local floor = math.floor

    local mid = floor((min+max)/2)
    -- the window is empty once the midpoint collapses onto the lower bound
    while mid > min do
        local item = vector[mid]
        local cv = fcmp(item, key)
        if cv == 0 then
            return item, true, mid
        end
        if cv < 0 then
            min = mid
        else
            max = mid
        end
        mid = floor((min+max)/2)
    end
    return nil, false, max
end
-- Binary search and insert.
-- Looks item up in the sorted vector (by item[field] when field is given)
-- and inserts it at the position reported by binsearch unless an equal
-- entry already exists.
-- Returns false, existing entry, position when an equal entry was found;
-- otherwise true, vector[position], position after the insertion.
function insert_sorted(vector,item,field,cmp)
    local key = item
    if field and item then
        key = item[field]
    end

    local cur,found,pos = binsearch(vector,key,field,cmp)
    if found then
        return false,cur,pos
    end

    if df.isvalid(vector) then
        vector:insert(pos, item)
    else
        table.insert(vector, pos, item)
    end
    return true,vector[pos],pos
end
-- Binary search, then insert or overwrite.
-- Works like insert_sorted, except that an existing equal entry is replaced
-- by item instead of being left alone.
-- Returns added, entry, position: added is true when item was newly
-- inserted and false when it overwrote an existing entry; entry is read back
-- from vector[position].
function insert_or_update(vector,item,field,cmp)
    local added,cur,pos = insert_sorted(vector,item,field,cmp)
    if not added then
        vector[pos] = item
        -- read the slot back so the caller sees what the vector now holds
        cur = vector[pos]
    end
    return added,cur,pos
end
-- Binary search and erase.
-- Looks key up in the sorted vector and removes the matching entry, if any
-- (vector:erase for DF objects, table.remove for Lua tables).
-- Returns the three results of binsearch in the order found, entry, position.
function erase_sorted_key(vector,key,field,cmp)
    local cur,found,pos = binsearch(vector,key,field,cmp)
    if not found then
        return found,cur,pos
    end
    if df.isvalid(vector) then
        vector:erase(pos)
    else
        table.remove(vector, pos)
    end
    return found,cur,pos
end
-- Erase the entry matching item from the sorted vector.
-- The search key is item[field] when both field and item are given,
-- otherwise item itself. Returns the results of erase_sorted_key.
function erase_sorted(vector,item,field,cmp)
    local key
    if field and item then
        key = item[field]
    else
        key = item
    end
    return erase_sorted_key(vector,key,field,cmp)
end
-- Calls a method with a string temporary.
-- Invokes obj[methodname](obj, str, ...) where str is a temporary object
-- created with df.new "string" and handed in by dfhack.with_temp_object;
-- the temporary's .value after the call is the result.
function call_with_string(obj,methodname,...)
    local function invoke(str,target,method,...)
        target[method](target,str,...)
        return str.value
    end
    return dfhack.with_temp_object(df.new "string", invoke, obj, methodname, ...)
end
-- Returns the string produced by the building's getName method, obtained
-- through call_with_string.
function getBuildingName(building)
    local method = 'getName'
    return call_with_string(building, method)
end
-- Returns xyz2pos applied to the building's centerx, centery and z fields.
function getBuildingCenter(building)
    local x, y, z = building.centerx, building.centery, building.z
    return xyz2pos(x, y, z)
end
-- Returns the string produced by the item's getItemDescription method,
-- obtained through call_with_string. mode defaults to 0.
function getItemDescription(item,mode)
    local effective_mode = mode or 0
    return call_with_string(item, 'getItemDescription', effective_mode)
end
-- Returns the string produced by the item's getItemDescriptionPrefix method,
-- obtained through call_with_string. mode defaults to 0.
function getItemDescriptionPrefix(item,mode)
    local effective_mode = mode or 0
    return call_with_string(item, 'getItemDescriptionPrefix', effective_mode)
end
-- Split the string by the given delimiter.
--   self:      the string to split.
--   delimiter: a Lua pattern (not a plain string), so magic characters must
--              be escaped, e.g. '%.' to split on dots.
-- Returns a table of the pieces between delimiter matches; it always holds
-- at least one element, and adjacent or trailing delimiters produce empty
-- pieces.
-- Empty matches of the delimiter (e.g. '' or ',*') are ignored rather than
-- treated as split points, so such patterns cannot stall the scan.
function split_string(self, delimiter)
    local result = {}
    local from = 1      -- start of the piece currently being collected
    local search = 1    -- position at which the next delimiter search starts
    -- string.len (unlike #) also accepts numbers, as string.find/sub do
    local limit = string.len(self) + 1
    while search <= limit do
        local delim_from, delim_to = string.find(self, delimiter, search)
        if not delim_from then
            break
        end
        if delim_to < delim_from then
            -- empty match: nothing to split on here, look one character on
            search = delim_from + 1
        else
            table.insert(result, string.sub(self, from, delim_from-1))
            from = delim_to + 1
            search = from
        end
    end
    table.insert(result, string.sub(self, from))
    return result
end
-- Ask a yes-no question.
--   msg:     question text; ' (y/n): ' is appended, with '[y]' or '[n]'
--            inserted before the colon when a default is given.
--   default: optional boolean returned when the user enters an empty line.
-- Keeps prompting through dfhack.lineedit until it gets a usable answer:
--   * input starting with y/Y returns true, with n/N returns false;
--   * the input 'abort' raises qerror('User abort');
--   * empty input returns default if one was given.
-- A failed dfhack.lineedit call raises qerror with the error it reported.
function prompt_yes_no(msg,default)
    local prompt = msg
    if default == nil then
        prompt = prompt..' (y/n): '
    elseif default then
        prompt = prompt..' (y/n)[y]: '
    else
        prompt = prompt..' (y/n)[n]: '
    end
    while true do
        local rv,err = dfhack.lineedit(prompt)
        if not rv then
            qerror(err)
        elseif string.match(rv,'^[Yy]') then
            return true
        elseif string.match(rv,'^[Nn]') then
            return false
        elseif rv == 'abort' then
            qerror('User abort')
        elseif rv == '' and default ~= nil then
            return default
        end
        -- anything else: ask again
    end
end
-- Ask for input with check function.
--   prompt:   text passed to dfhack.lineedit.
--   check:    function(text) whose first result says whether the input is
--             acceptable; its remaining results become the return values.
--   quit_str: input that aborts the prompt; defaults to '~~~'.
-- Keeps prompting until check accepts the input. A failed dfhack.lineedit
-- call raises qerror with the reported error, and entering quit_str raises
-- qerror('User abort').
function prompt_input(prompt,check,quit_str)
    local quit_token = quit_str or '~~~'
    while true do
        local line,err = dfhack.lineedit(prompt)
        if not line then
            qerror(err)
        end
        if line == quit_token then
            qerror('User abort')
        end
        -- pack so that trailing nil results survive the hand-over
        local results = table.pack(check(line))
        if results[1] then
            return table.unpack(results,2,results.n)
        end
    end
end
-- Check function for prompt_input: accepts text that tonumber can convert.
-- Returns true and the number on success, false and nil otherwise.
function check_number(text)
    local parsed = tonumber(text)
    if parsed == nil then
        return false, nil
    end
    return true, parsed
end
-- Returns a new table mapping every value of tab back to its key.
-- When several keys share a value, whichever pairs() visits last wins.
function invert(tab)
    local inverse = {}
    for key,value in pairs(tab) do
        inverse[value] = key
    end
    return inverse
end
-- Standardized argument processing for scripts. Accepted forms:
--     -argName value
--     -argName [list of values]
--     -argName [list of [nested values] -that can be [whatever] format of matched square brackets]
--     -arg1 \-arg3
--         escape sequences: a leading backslash is dropped, so the value may
--         itself start with '-'
--   args:      list of argument strings.
--   validArgs: optional table; when given, every argName must be a key with
--              a truthy value.
-- Returns a table keyed by argName. The value is '' for an argument followed
-- by nothing or by another '-name', the string for a single value, or a list
-- of the tokens between the outermost brackets (nested brackets are kept as
-- '[' and ']' tokens).
-- Raises an error for an invalid or duplicate argName, and for a token that
-- is neither a '-name' nor a value of the pending argument.
function processArgs(args, validArgs)
    local result = {}
    local argName
    local bracketDepth = 0

    -- Record a value for the pending argument: outside brackets it is the
    -- argument's single value, inside brackets it joins the argument's list.
    local function store(value)
        if bracketDepth == 0 then
            result[argName] = value
            argName = nil
        else
            table.insert(result[argName], value)
        end
    end

    for i,arg in ipairs(args) do
        if not argName then
            -- no argument pending: this token has to name one
            if string.sub(arg,1,1) ~= '-' then
                error('error parsing arg ' .. i .. ': ' .. arg)
            end
            argName = string.sub(arg,2)
            if validArgs and not validArgs[argName] then
                error('error: invalid arg: ' .. i .. ': ' .. argName)
            end
            if result[argName] then
                error('duplicate arg: ' .. i .. ': ' .. argName)
            end
            if i+1 > #args or string.sub(args[i+1],1,1) == '-' then
                -- flag without a value
                result[argName] = ''
                argName = nil
            else
                result[argName] = {}
            end
        elseif arg == '[' then
            -- the outermost bracket is a delimiter, nested ones are data
            if bracketDepth > 0 then
                table.insert(result[argName], arg)
            end
            bracketDepth = bracketDepth+1
        elseif arg == ']' then
            bracketDepth = bracketDepth-1
            if bracketDepth > 0 then
                table.insert(result[argName], arg)
            else
                argName = nil
            end
        elseif string.sub(arg,1,1) == '\\' then
            store(string.sub(arg,2))
        else
            store(arg)
        end
    end
    return result
end
-- processes commandline options according to optionActions and returns all
-- argument strings that are not options. Options and non-option strings can
-- appear in any order, and single-letter options that do not take arguments
-- can be combined into a single option string (e.g. '-abc' is the same as
-- '-a -b -c' if options 'a' and 'b' do not take arguments).
--
-- Numbers cannot be options and negative numbers (e.g. -10) will be interpreted
-- as positional parameters and returned in the nonoptions list.
--
-- optionActions is a vector with elements in the following format:
--   {shortOptionName, longOptionAlias, hasArg=boolean, handler=fn}
-- shortOptionName and handler are required. If the option takes an argument,
-- it will be passed to the handler function.
-- longOptionAlias is optional.
-- hasArg defaults to false.
--
-- An error is raised when shortOptionName is not a one-character string or
-- when handler is missing.
--
-- example usage:
--
-- local filename = nil
-- local open_readonly = false
-- local nonoptions = processArgsGetopt(args, {
--     {'r', handler=function() open_readonly = true end},
--     {'f', 'filename', hasArg=true,
--      handler=function(optarg) filename = optarg end}
--   })
--
-- when args is {'first', '-f', 'fname', 'second'} or, equivalently,
-- {'first', '--filename', 'fname', 'second'} (note the double dash in front of
-- the long option alias), then filename will be fname and nonoptions will
-- contain {'first', 'second'}.
function processArgsGetopt(args, optionActions)
    local sh_opts, long_opts = '', {}
    local handlers = {}
    for _,optionAction in ipairs(optionActions) do
        local sh_opt,long_opt = optionAction[1], optionAction[2]
        if type(sh_opt) ~= 'string' or #sh_opt ~= 1 then
            error('optionAction missing option letter at index 1')
        end
        if not optionAction.handler then
            error(string.format('handler missing for option "%s"', sh_opt))
        end
        -- build the getopt option string; ':' marks an option with argument
        sh_opts = sh_opts .. sh_opt
        if optionAction.hasArg then sh_opts = sh_opts .. ':' end
        handlers[sh_opt] = optionAction.handler
        if long_opt then
            long_opts[long_opt] = sh_opt
            handlers[long_opt] = optionAction.handler
        end
    end
    local opts, optargs, nonoptions =
        getopt.get_ordered_opts(args, sh_opts, long_opts)
    -- run the handlers in the order the options were reported
    for i,v in ipairs(opts) do
        handlers[v](optargs[i])
    end
    return nonoptions
end
-- (web-view blame timestamp: 2014-07-03 04:01:58 -06:00)
-- Copy every key/value pair of table2 into table1, overwriting keys that
-- table1 already has. table1 is modified in place; nothing is returned.
function fillTable(table1,table2)
    for k,v in pairs(table2) do
        table1[k] = v
    end
end
-- Remove from table1 every key that is present in table2 (whatever its
-- value there). table1 is modified in place; nothing is returned.
function unfillTable(table1,table2)
    for key in pairs(table2) do
        table1[key] = nil
    end
end
-- Shortcut names and the dfhack.gui function that supplies their value.
local df_shortcut_getters = {
    scr = 'getCurViewscreen',
    screen = 'getCurViewscreen',
    bld = 'getSelectedBuilding',
    building = 'getSelectedBuilding',
    item = 'getSelectedItem',
    job = 'getSelectedJob',
    wsjob = 'getSelectedWorkshopJob',
    workshop_job = 'getSelectedWorkshopJob',
    unit = 'getSelectedUnit',
    plant = 'getSelectedPlant',
}

-- Resolve a shortcut variable name.
-- Lookup order: the selection shortcuts above (evaluated by calling the
-- matching dfhack.gui function), then the keys of df.global, then _G.
function df_shortcut_var(k)
    local getter = df_shortcut_getters[k]
    if getter then
        return dfhack.gui[getter]()
    end
    for g in pairs(df.global) do
        if g == k then
            return df.global[k]
        end
    end
    return _G[k]
end
-- Create an environment table whose missing keys are resolved on demand
-- through df_shortcut_var. The table itself starts out empty; every lookup
-- of an absent key calls df_shortcut_var(key) via the __index metamethod.
function df_shortcut_env()
    local env = {}
    setmetatable(env, {__index = function(self, k) return df_shortcut_var(k) end})
    return env
end
-- Module-level shortcut environment created once at load time;
-- df_expr_to_ref resolves the first component of an expression through it.
df_env = df_shortcut_env()
-- Resolve a textual expression such as 'unit.name["first_name"]' or
-- 'world.units.active[0]' to an object.
--   * Bracket subscripts (["x"], ['x'] and [123]) are rewritten to dotted
--     form, and the expression is split on '.'.
--   * The first component is looked up in df_env.
--   * Every further component indexes the current object, using a number
--     as the key when the component parses as one.
--   * For the last component, when the indexed value is not a userdata
--     with a metatable, obj:_field(key) is returned instead of the value.
--     NOTE(review): _field presumably yields a reference to the field
--     itself rather than a copy of its value -- confirm against the DFHack
--     Lua API documentation.
function df_expr_to_ref(expr)
    expr = expr:gsub('%["(.-)"%]', function(field) return '.' .. field end)
        :gsub('%[\'(.-)\'%]', function(field) return '.' .. field end)
        :gsub('%[(%d+)]', function(field) return '.' .. field end)
    local parts = split_string(expr, '%.')
    local obj = df_env[parts[1]]
    for i = 2, #parts do
        local key = tonumber(parts[i]) or parts[i]
        local cur = obj[key]
        local is_object = type(cur) == 'userdata' and getmetatable(cur) ~= nil
        if i == #parts and not is_object then
            obj = obj:_field(key)
        else
            obj = cur
        end
    end
    return obj
end
-- Returns the results of df.sizeof(obj) from the second one onward
-- (select(2, ...) drops only the first result).
-- NOTE(review): the second result is presumably the object's address, as the
-- function name suggests -- confirm against the df.sizeof documentation.
function addressof(obj)
return select(2, df.sizeof(obj))
end
-- Create a table that remembers the order in which its keys were first
-- assigned and whose __pairs metamethod iterates in that order.
-- The returned table stays empty: values live in a separate table so that
-- __index and __newindex run on every table index operation.
-- Assigning nil to a key keeps its place in the order; iteration then
-- yields that key with a nil value.
function OrderedTable()
    local values = {}      -- key -> current value
    local position = {}    -- key -> place in the insertion order
    local order = {}       -- place in the insertion order -> key

    -- stateless iterator step: the key following k, or the first key
    local function advance(_, k)
        local following
        if k then
            following = order[position[k] + 1]
        else
            following = order[1]
        end
        if following then
            return following, values[following]
        end
    end

    local mt = {
        __index = function(_, k)
            return values[k]
        end,
        __newindex = function(_, k, v)
            if not position[k] then
                table.insert(order, k)
                position[k] = #order
            end
            values[k] = v
        end,
        __pairs = function()
            return advance, nil, nil
        end,
    }

    return setmetatable({}, mt)
end
return _ENV