local _ENV = mkmodule('utils')

local df = df

-- Resolve a "maybe callable" value: functions are invoked with the extra
-- arguments, every other value is handed back untouched.
function getval(obj, ...)
    if type(obj) ~= 'function' then
        return obj
    end
    return obj(...)
end

-- Three-way comparator built on the < and > operators:
-- -1 when a sorts first, 1 when b sorts first, 0 otherwise.
function compare(a,b)
    if a < b then return -1 end
    if a > b then return 1 end
    return 0
end

-- Three-way comparator for strings in which '' sorts after every
-- non-empty string; non-empty strings fall back to compare().
function compare_name(a,b)
    if b == '' then
        if a == '' then return 0 end
        return -1
    end
    if a == '' then return 1 end
    return compare(a,b)
end

-- Build a comparator that applies cmp (default: compare) to a[field] and
-- b[field]. Without a field, the base comparator itself is returned.
function compare_field(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(a,b)
        return base(a[field], b[field])
    end
end

-- Like compare_field, but only the left operand is indexed by field; the
-- right operand is used directly as the key (used by binary search).
function compare_field_key(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(a,b)
        return base(a[field], b)
    end
end

-- True when obj is a valid DF reference whose _kind is 'container'.
function is_container(obj)
    if df.isvalid(obj) ~= 'ref' then
        return false
    end
    return obj._kind == 'container'
end

-- Return the 1-based sequence {istart, istart+1, ..., iend}.
function make_index_sequence(istart,iend)
    local seq = {}
    local n = 0
    for value = istart, iend do
        n = n + 1
        seq[n] = value
    end
    return seq
end

--[[
  make_sort_order(data, ordering)

  Computes a sorted permutation of data without moving the data itself.
  ordering is a list of sort specs applied in priority order; each spec may
  contain:
   * key = function(value)
       Derives the comparison key from one data value; never called on nil.
   * key_table = function(data)
       Derives a whole table of keys from data in one call (checked before
       key).
   * compare = function(a,b)
       Three-way comparator for non-nil keys; defaults to compare above.
   * nil_first
       When true, nil keys sort before everything else (default: last).
   * reverse
       When true, non-nil keys sort in descending order.
  A spec with neither key nor key_table uses the data values as keys.
  Elements tied on every spec keep their original relative order. The
  element count is data.n when present, otherwise #data.

  Returns a table of integer indices into data.
--]]
function make_sort_order(data,ordering)
    local size = data.n or #data
    local keys, cmps = {}, {}

    -- Precompute one key column and one comparator per spec.
    for i = 1, #ordering do
        local spec = ordering[i]
        if spec.key_table then
            keys[i] = spec.key_table(data)
        elseif spec.key then
            local key_fn = spec.key
            local column = {}
            for j = 1, size do
                local value = data[j]
                if value ~= nil then
                    column[j] = key_fn(value)
                end
            end
            keys[i] = column
        else
            keys[i] = data
        end
        cmps[i] = spec.compare or compare
    end

    local index = make_index_sequence(1, size)

    table.sort(index, function(ia, ib)
        for i = 1, #keys do
            local spec = ordering[i]
            local column = keys[i]
            local ka, kb = column[ia], column[ib]
            if ka == nil then
                if kb ~= nil then
                    return spec.nil_first
                end
                -- both keys nil: tie on this spec, try the next one
            elseif kb == nil then
                return not spec.nil_first
            else
                local diff = cmps[i](ka, kb)
                if spec.reverse then
                    diff = -diff
                end
                if diff < 0 then
                    return true
                elseif diff > 0 then
                    return false
                end
            end
        end
        -- Full tie: fall back to original position so the sort is stable.
        return ia < ib
    end)

    return index
end

-- Step function for listpairs: advance to the following node of a DF
-- linked 'list' structure, yielding the node and its item.
local function next_df_list(_, node)
    local following = node.next
    if following then
        return following, following.item
    end
end

-- Iterate a 'list' structure, e.g. df.global.world.job_list.
-- The head node itself is not yielded; iteration starts at head.next.
function listpairs(list)
    return next_df_list, nil, list
end

-- Recursively copy src into tgt and return tgt. DF references are delegated
-- to df.assign; for lua tables, nested tables in src are merged into the
-- matching sub-table of tgt (created on demand). Any other target type raises
-- an error.
function assign(tgt,src)
    if df.isvalid(tgt) == 'ref' then
        df.assign(tgt, src)
        return tgt
    end
    if type(tgt) ~= 'table' then
        error('Invalid assign target type: '..tostring(tgt))
    end
    for key, value in pairs(src) do
        if type(value) ~= 'table' then
            tgt[key] = value
        else
            local child = tgt[key]
            if child == nil then
                child = {}
                tgt[key] = child
            end
            assign(child, value)
        end
    end
    return tgt
end

-- Helper for clone(): nil fields become NULL. In deep mode, a field whose
-- _field(k) accessor equals the read value is cloned recursively; otherwise
-- the value is returned as is.
local function copy_field(obj,k,v,deep)
    if v == nil then
        return NULL
    end
    if deep and obj:_field(k) == v then
        return clone(v,deep)
    end
    return v
end

-- Copy the object as lua data structures.
function clone(obj,deep)
    -- Lua tables: deep copy via assign() or shallow copy via copyall().
    -- Valid DF references are converted to plain lua values/tables according
    -- to their _kind; any other value is returned unchanged.
    if type(obj) == 'table' then
        if deep then
            return assign({},obj)
        else
            return copyall(obj)
        end
    elseif df.isvalid(obj) == 'ref' then
        local kind = obj._kind
        if kind == 'primitive' then
            return obj.value
        elseif kind == 'bitfield' then
            local rv = {}
            for k,v in pairs(obj) do
                rv[k] = v
            end
            return rv
        elseif kind == 'container' then
            -- ipairs over a DF container starts at 0; shift to a 1-based
            -- lua sequence.
            local rv = {}
            for k,v in ipairs(obj) do
                rv[k+1] = copy_field(obj,k,v,deep)
            end
            return rv
        else -- struct
            local rv = {}
            for k,v in pairs(obj) do
                rv[k] = copy_field(obj,k,v,deep)
            end
            return rv
        end
    else
        return obj
    end
end

-- Look up the default for `key`: a table default is indexed by key, then
-- falls back to its _default entry, then to `base`. A non-table default
-- applies to every key.
local function get_default(default,key,base)
    if type(default) == 'table' then
        local dv = default[key]
        if dv == nil then
            dv = default._default
        end
        if dv == nil then
            dv = base
        end
        return dv
    else
        return default
    end
end

-- Copy the object as lua data structures, skipping values matching defaults.
-- Returns nil when nothing differs from the defaults, unless `force` is set,
-- in which case an empty table is returned for aggregates.
function clone_with_default(obj,default,force)
    local rv = nil
    -- Lazily create the result table on the first non-nil entry.
    local function setrv(k,v)
        if v ~= nil then
            if rv == nil then
                rv = {}
            end
            rv[k] = v
        end
    end
    if default == nil then
        return nil
    elseif type(obj) == 'table' then
        for k,v in pairs(obj) do
            setrv(k, clone_with_default(v, get_default(default,k)))
        end
    elseif df.isvalid(obj) == 'ref' then
        local kind = obj._kind
        if kind == 'primitive' then
            return clone_with_default(obj.value,default,force)
        elseif kind == 'bitfield' then
            for k,v in pairs(obj) do
                setrv(k, clone_with_default(v, get_default(default,k,false)))
            end
        elseif kind == 'container' then
            for k,v in ipairs(obj) do
                setrv(k+1, clone_with_default(v, default, true))
            end
        else -- struct
            for k,v in pairs(obj) do
                setrv(k, clone_with_default(v, get_default(default,k)))
            end
        end
    elseif obj == default and not force then
        return nil
    elseif obj == nil then
        return NULL
    else
        return obj
    end
    if force and rv == nil then
        rv = {}
    end
    return rv
end

-- Parse an integer value into a bitfield table
-- (flag name -> true for each set bit; nil when value is 0).
function parse_bitfield_int(value, type_ref)
    if value == 0 then
        return nil
    end
    local res = {}
    for i,v in ipairs(type_ref) do
        if bit32.extract(value, i) ~= 0 then
            res[v] = true
        end
    end
    return res
end

-- List the enabled flag names in the bitfield table
-- (appended to `list` when given; order follows pairs()).
function list_bitfield_flags(bitfield, list)
    list = list or {}
    if bitfield then
        for name,val in pairs(bitfield) do
            if val then
                table.insert(list, name)
            end
        end
    end
    return list
end

-- Sort a vector or lua table in place (by `field` when given, using the
-- three-way comparator `cmp`). DF containers are sorted through a deep lua
-- copy that is then assigned back. Returns the vector.
function sort_vector(vector,field,cmp)
    local fcmp = compare_field(field,cmp)
    local scmp = function(a,b)
        return fcmp(a,b) < 0
    end
    if df.isvalid(vector) then
        if vector._kind ~= 'container' then
            error('Container expected: '..tostring(vector))
        end
        local items = clone(vector, true)
        table.sort(items, scmp)
        vector:assign(items)
    else
        table.sort(vector, scmp)
    end
    return vector
end

-- Linear search: returns the index and element whose value (or value[field])
-- equals key, or nil. DF vectors are scanned 0..#vector-1, lua tables 1..#vector.
function linear_index(vector,key,field)
    local min,max
    if df.isvalid(vector) then
        min,max = 0,#vector-1
    else
        min,max = 1,#vector
    end
    if field then
        for i=min,max do
            local obj = vector[i]
            if obj[field] == key then
                return i, obj
            end
        end
    else
        for i=min,max do
            local obj = vector[i]
            if obj == key then
                return i, obj
            end
        end
    end
    return nil
end

-- Binary search in a vector or lua table.
-- Returns item, true, index on a hit; otherwise nil, false, insert_position.
-- min/max are exclusive bounds and default to the whole vector.
function binsearch(vector,key,field,cmp,min,max)
    if not(min and max) then
        if df.isvalid(vector) then
            min = -1
            max = #vector
        else
            min = 0
            max = #vector+1
        end
    end
    local mf = math.floor
    local fcmp = compare_field_key(field,cmp)
    while true do
        local mid = mf((min+max)/2)
        if mid <= min then
            return nil, false, max
        end
        local item = vector[mid]
        local cv = fcmp(item, key)
        if cv == 0 then
            return item, true, mid
        elseif cv < 0 then
            min = mid
        else
            max = mid
        end
    end
end

-- Binary search and insert.
-- Returns added, element_at_pos, pos; nothing is inserted if the key exists.
function insert_sorted(vector,item,field,cmp)
    local key = item
    if field and item then
        key = item[field]
    end
    local cur,found,pos = binsearch(vector,key,field,cmp)
    if found then
        return false,cur,pos
    else
        if df.isvalid(vector) then
            vector:insert(pos, item)
        else
            table.insert(vector, pos, item)
        end
        return true,vector[pos],pos
    end
end

-- Binary search, then insert or overwrite
function insert_or_update(vector,item,field,cmp)
    local added,cur,pos = insert_sorted(vector,item,field,cmp)
    if not added then
        vector[pos] = item
        cur = vector[pos]
    end
    return added,cur,pos
end

-- Binary search and erase
function erase_sorted_key(vector,key,field,cmp)
    local cur,found,pos = binsearch(vector,key,field,cmp)
    if found then
        if df.isvalid(vector) then
            vector:erase(pos)
        else
            table.remove(vector, pos)
        end
    end
    return found,cur,pos
end

-- Erase `item` (matched by item[field] when field is given).
function erase_sorted(vector,item,field,cmp)
    local key = item
    if field and item then
        key = item[field]
    end
    return erase_sorted_key(vector,key,field,cmp)
end

FILTER_FULL_TEXT = false

-- True when every token of search_tokens (a table or a string split on
-- whitespace) occurs in text. Unless FILTER_FULL_TEXT is set, a token must
-- start at a word or punctuation boundary.
function search_text(text, search_tokens)
    text = dfhack.toSearchNormalized(text)
    if type(search_tokens) ~= 'table' then
        search_tokens = search_tokens:split()
    end
    for _,search_token in ipairs(search_tokens) do
        if search_token == '' then goto continue end
        search_token = dfhack.toSearchNormalized(search_token:escape_pattern())
        -- the separate checks for non-space or non-punctuation allows
        -- punctuation itself to be matched if that is useful (e.g.
        -- filenames or parameter names)
        if not FILTER_FULL_TEXT and
            not text:match('%f[^%p\x00]'..search_token) and
            not text:match('%f[^%s\x00]'..search_token)
        then
            return false
        elseif FILTER_FULL_TEXT and not text:find(search_token) then
            return false
        end
        ::continue::
    end
    return true
end

-- Calls a method with a string temporary
-- (obj[methodname](obj, str, ...)) and returns the string's value.
function call_with_string(obj,methodname,...)
    return dfhack.with_temp_object(
        df.new "string",
        function(str,obj,methodname,...)
            obj[methodname](obj,str,...)
            return str.value
        end,
        obj,methodname,...
    )
end

function getBuildingName(building)
    return call_with_string(building, 'getName')
end

function getBuildingCenter(building)
    return xyz2pos(building.centerx, building.centery, building.z)
end

function getItemDescription(item,mode)
    return call_with_string(item, 'getItemDescription', mode or 0)
end

function getItemDescriptionPrefix(item,mode)
    return call_with_string(item, 'getItemDescriptionPrefix', mode or 0)
end

-- Split the string by the given delimiter
function split_string(self, delimiter)
    return self:split(delimiter)
end

-- Ask a yes-no question
-- (empty input returns `default` when one is given; 'abort' raises qerror).
function prompt_yes_no(msg,default)
    local prompt = msg
    if default == nil then
        prompt = prompt..' (y/n): '
    elseif default then
        prompt = prompt..' (y/n)[y]: '
    else
        prompt = prompt..' (y/n)[n]: '
    end
    while true do
        local rv,err = dfhack.lineedit(prompt)
        if not rv then
            qerror(err);
        elseif string.match(rv,'^[Yy]') then
            return true
        elseif string.match(rv,'^[Nn]') then
            return false
        elseif rv == 'abort' then
            qerror('User abort')
        elseif rv == '' and default ~= nil then
            return default
        end
    end
end

-- Ask for input with check function
-- (re-prompts until check's first result is truthy, then returns the rest).
function prompt_input(prompt,check,quit_str)
    quit_str = quit_str or '~~~'
    while true do
        local rv,err = dfhack.lineedit(prompt)
        if not rv then
            qerror(err);
        end
        if rv == quit_str then
            qerror('User abort')
        end
        local rtbl = table.pack(check(rv))
        if rtbl[1] then
            return table.unpack(rtbl,2,rtbl.n)
        end
    end
end

-- Check function for prompt_input: returns success flag and the number.
function check_number(text)
    local nv = tonumber(text)
    return nv ~= nil, nv
end

-- Normalize directory separator slashes across platforms to '/' and collapse
-- adjacent slashes into a single slash.
local PLATFORM_SLASH = package.config:sub(1,1)
function normalizePath(path)
    -- string.gsub returns (result, count); the outer parentheses keep only
    -- the string so callers never receive a stray replacement count.
    return (path:gsub(PLATFORM_SLASH, '/'):gsub('/+', '/'))
end

-- Return a new table mapping each value of tab back to its key.
function invert(tab)
    local result = {}
    for k,v in pairs(tab) do
        result[v]=k
    end
    return result
end

-- processArgs() and processArgsGetopt() have been moved to argparse.lua.
-- The 'require' statements are within the functions to avoid adding hard
-- dependencies to utils.lua (which could lead to circular dependency issues).
function processArgs(args, validArgs)
    return require('argparse').processArgs(args, validArgs)
end

function processArgsGetopt(args, optionActions)
    return require('argparse').processArgsGetopt(args, optionActions)
end

-- Shallow-copy every key/value pair of table2 into table1.
function fillTable(table1,table2)
    for k,v in pairs(table2) do
        table1[k] = v
    end
end

-- Remove from table1 every key that is present in table2.
function unfillTable(table1,table2)
    for k in pairs(table2) do
        table1[k] = nil
    end
end

-- Resolve a shortcut name: GUI selection aliases first, then a df.global
-- field with that exact name, then a lua global.
function df_shortcut_var(k)
    if k == 'scr' or k == 'screen' then
        return dfhack.gui.getCurViewscreen()
    elseif k == 'bld' or k == 'building' then
        return dfhack.gui.getSelectedBuilding()
    elseif k == 'item' then
        return dfhack.gui.getSelectedItem()
    elseif k == 'job' then
        return dfhack.gui.getSelectedJob()
    elseif k == 'wsjob' or k == 'workshop_job' then
        return dfhack.gui.getSelectedWorkshopJob()
    elseif k == 'unit' then
        return dfhack.gui.getSelectedUnit()
    elseif k == 'plant' then
        return dfhack.gui.getSelectedPlant()
    else
        -- Scan field names rather than index df.global[k] directly, so an
        -- unknown name falls through to _G instead of hitting df.global.
        for g in pairs(df.global) do
            if g == k then
                return df.global[k]
            end
        end
        return _G[k]
    end
end

-- Environment table whose lookups are resolved by df_shortcut_var.
function df_shortcut_env()
    local env = {}
    setmetatable(env, {__index = function(self, k) return df_shortcut_var(k) end})
    return env
end

df_env = df_shortcut_env()

-- Turn an expression like `unit.status["foo"][3]` into a reference. Bracket
-- indices are rewritten to dotted form, then the path is walked from df_env.
-- For the last component, obj:_field(key) is returned when that call succeeds;
-- otherwise the plain indexed value is returned.
function df_expr_to_ref(expr)
    expr = expr:gsub('%["(.-)"%]', function(field) return '.' .. field end)
        :gsub('%[\'(.-)\'%]', function(field) return '.' .. field end)
        :gsub('%[(%-?%d+)%]', function(field) return '.' .. field end)
    local parts = expr:split('.', true)
    local obj = df_env[parts[1]]
    for i = 2, #parts do
        local key = tonumber(parts[i]) or parts[i]
        if i == #parts then
            local ok, ret = pcall(function() return obj:_field(key) end)
            if ok then
                return ret
            end
        end
        obj = obj[key]
    end
    return obj
end

-- Return the second result of df.sizeof(obj), the object's address.
function addressof(obj)
    return select(2, df.sizeof(obj))
end

-- A table whose pairs() iterates keys in first-insertion order.
-- Keys assigned nil are skipped during iteration; a key that is re-assigned
-- later keeps its original position.
function OrderedTable()
    -- store values in a separate table to ensure that __index and __newindex
    -- run on every table index operation
    local t = {}
    local key_to_index = {}
    local index_to_key = {}

    local mt = {}

    function mt:__index(k)
        return t[k]
    end

    function mt:__newindex(k, v)
        if not key_to_index[k] then
            table.insert(index_to_key, k)
            key_to_index[k] = #index_to_key
        end
        t[k] = v
    end

    function mt:__pairs()
        return function(_, k)
            local i = k and key_to_index[k] + 1 or 1
            local key = index_to_key[i]
            -- Skip slots whose value was deleted (or never set) so that, like
            -- plain pairs(), iteration never yields a nil value.
            while key ~= nil and t[key] == nil do
                i = i + 1
                key = index_to_key[i]
            end
            if key ~= nil then
                return key, t[key]
            end
        end, nil, nil
    end

    local self = {}
    setmetatable(self, mt)
    return self
end

return _ENV