dfhack/library/lua/utils.lua

302 lines
7.0 KiB
Lua

-- Module setup: the table returned by mkmodule becomes this chunk's _ENV,
-- so every function below declared without `local` is stored in it.
-- NOTE(review): mkmodule is defined outside this file; presumably it also
-- makes global lookups fall through to _G — confirm.
local _ENV = mkmodule('utils')
-- Cache df (looked up through the module environment) in a local upvalue.
local df = df
-- Three-way comparator: -1 when a < b, 1 when a > b, and 0 otherwise
-- (including values that are neither less nor greater, such as NaN).
function compare(a,b)
    if a < b then return -1 end
    if a > b then return 1 end
    return 0
end
-- String comparator that orders empty strings after every other string.
-- Two empty strings compare equal; non-empty strings use compare().
function compare_name(a,b)
    local a_empty = (a == '')
    local b_empty = (b == '')
    if a_empty and b_empty then
        return 0
    elseif a_empty then
        return 1
    elseif b_empty then
        return -1
    end
    return compare(a,b)
end
-- Build a comparator over records: the result compares a[field] with
-- b[field] using cmp (default: compare). Without a field, cmp itself
-- is returned.
function compare_field(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(lhs, rhs)
        return base(lhs[field], rhs[field])
    end
end
-- Build a record-vs-key comparator: the result compares a[field] with a
-- bare key using cmp (default: compare). Without a field, cmp itself
-- is returned.
function compare_field_key(field,cmp)
    local base = cmp or compare
    if not field then
        return base
    end
    return function(record, key)
        return base(record[field], key)
    end
end
-- True when df.isvalid reports obj as a 'ref' whose _kind is 'container';
-- false for everything else.
function is_container(obj)
    if df.isvalid(obj) ~= 'ref' then
        return false
    end
    return obj._kind == 'container'
end
-- Return the identity sequence {1, 2, ..., size} (empty when size < 1).
function make_index_sequence(size)
    local seq = {}
    for n = 1, size do
        seq[#seq + 1] = n
    end
    return seq
end
--[[
 Compute a sort order for the items in data according to ordering.

 ordering is a list of specs, consulted in turn as tie-breakers. Each spec
 is a table with optional fields:
  * key = function(value)
     Derives the comparison key from one data value. Never called on nil.
  * key_table = function(data)
     Derives the whole key table from data in one call (checked before key).
  * compare = function(a,b)
     Three-way comparator for keys; defaults to compare above.
     Only called on non-nil keys.
  * nil_first
     If true, items with a nil key come first; by default they come last.
  * reverse
     If true, non-nil keys are sorted in descending order.
 When a spec has neither key nor key_table, the data values are the keys.
 Items that tie on every spec keep their original relative order.
 data.n, when present, is used as the item count instead of #data.

 Returns a table of integer indices into data, in sorted order.
--]]
function make_sort_order(data,ordering)
    local count = data.n or #data
    local key_tables = {}
    local comparators = {}

    -- One key table and one comparator per spec
    for oi = 1, #ordering do
        local spec = ordering[oi]
        local keys
        if spec.key_table then
            keys = spec.key_table(data)
        elseif spec.key then
            keys = {}
            local derive = spec.key
            for j = 1, count do
                local value = data[j]
                if value ~= nil then
                    keys[j] = derive(value)
                end
            end
        else
            keys = data
        end
        key_tables[oi] = keys
        comparators[oi] = spec.compare or compare
    end

    local perm = make_index_sequence(count)

    -- Strict "ia goes before ib" predicate for table.sort: the first spec
    -- that distinguishes the two items decides.
    table.sort(perm, function(ia, ib)
        for oi = 1, #key_tables do
            local keys = key_tables[oi]
            local ka, kb = keys[ia], keys[ib]
            if ka ~= nil and kb ~= nil then
                local rel = comparators[oi](ka, kb)
                if ordering[oi].reverse then
                    rel = -rel
                end
                if rel < 0 then
                    return true
                elseif rel > 0 then
                    return false
                end
            elseif ka ~= nil then
                -- only ib has a nil key
                return not ordering[oi].nil_first
            elseif kb ~= nil then
                -- only ia has a nil key
                return ordering[oi].nil_first
            end
            -- equal (or both nil) under this spec: try the next one
        end
        -- Tie on every spec: fall back to the original position, which
        -- keeps the result deterministic and stable
        return ia < ib
    end)
    return perm
end
--[[
 Recursively assign data from src into tgt, then return tgt.
 * A DF object reference target is handed to df.assign.
 * A lua table target receives every key of src. Table values are merged
   recursively into tgt[key], which is created as {} when absent. Other
   values overwrite tgt[key].
 * Any other target raises an error.
--]]
function assign(tgt,src)
    if df.isvalid(tgt) == 'ref' then
        df.assign(tgt, src)
        return tgt
    end
    if type(tgt) ~= 'table' then
        error('Invalid assign target type: '..tostring(tgt))
    end
    for key, value in pairs(src) do
        if type(value) ~= 'table' then
            tgt[key] = value
        else
            local child = tgt[key]
            if child == nil then
                child = {}
                tgt[key] = child
            end
            assign(child, value)
        end
    end
    return tgt
end
-- Helper for clone(): convert the value v of field k of DF object obj.
-- A nil value becomes NULL. With deep set, v is cloned recursively when it
-- compares equal to obj:_field(k) (NOTE(review): presumably meaning the
-- field is embedded rather than a pointer — confirm). Otherwise v is
-- returned unchanged.
local function copy_field(obj,k,v,deep)
    if v == nil then
        return NULL
    end
    if not deep then
        return v
    end
    if obj:_field(k) ~= v then
        return v
    end
    return clone(v,deep)
end
-- Copy obj into plain lua data structures.
-- * lua tables: deep copy via assign() when deep is set, otherwise a
--   shallow copyall().
-- * DF references: a primitive yields its value. A bitfield, container or
--   struct becomes a new lua table. A container element at index k is
--   stored at k+1 (presumably mapping DF 0-based indices to lua 1-based).
--   Container and struct values go through copy_field, which handles nil
--   and deep cloning.
-- * Anything else is returned unchanged.
function clone(obj,deep)
    if type(obj) == 'table' then
        if not deep then
            return copyall(obj)
        end
        return assign({},obj)
    end
    if df.isvalid(obj) ~= 'ref' then
        return obj
    end

    local kind = obj._kind
    if kind == 'primitive' then
        return obj.value
    end

    local out = {}
    if kind == 'bitfield' then
        for name, flag in pairs(obj) do
            out[name] = flag
        end
    elseif kind == 'container' then
        for idx, item in ipairs(obj) do
            out[idx + 1] = copy_field(obj, idx, item, deep)
        end
    else -- struct
        for name, item in pairs(obj) do
            out[name] = copy_field(obj, name, item, deep)
        end
    end
    return out
end
-- Sort a DF container or lua table in place, then return it.
-- Items are ordered ascending by compare_field(field,cmp): by item[field]
-- when field is given, using cmp (default compare). A DF container is
-- cloned into a lua table, sorted there, and assigned back. Any other DF
-- object raises an error. table.sort is used, so equal items may be
-- reordered.
function sort_vector(vector,field,cmp)
    local order = compare_field(field,cmp)
    local function less(x, y)
        return order(x, y) < 0
    end

    if not df.isvalid(vector) then
        table.sort(vector, less)
        return vector
    end

    if vector._kind ~= 'container' then
        error('Container expected: '..tostring(vector))
    end
    local items = clone(vector, true)
    table.sort(items, less)
    vector:assign(items)
    return vector
end
--[[
 Binary search for key in a sorted DF container or lua table.
 field/cmp select the comparator as in compare_field_key: each probed item
 (or item[field] when field is given) is compared against key with cmp.
 min and max are exclusive bounds of the searched index range. Each bound
 that is omitted defaults independently to one step outside the valid
 indices: -1 and #vector for DF containers, 0 and #vector+1 for lua tables.
 Returns item, true, index when a match is found; otherwise nil, false, pos,
 where pos is the index at which key would be inserted to keep the order.
--]]
function binsearch(vector,key,field,cmp,min,max)
    if not (min and max) then
        local lo, hi
        if df.isvalid(vector) then
            lo, hi = -1, #vector
        else
            lo, hi = 0, #vector+1
        end
        -- Fill in only the missing bound(s); previously a lone min or max
        -- supplied by the caller was silently discarded.
        min = min or lo
        max = max or hi
    end
    local floor = math.floor
    local fcmp = compare_field_key(field,cmp)
    while true do
        local mid = floor((min+max)/2)
        if mid <= min then
            -- No index strictly between the bounds remains: not found,
            -- and max is the insertion point.
            return nil, false, max
        end
        local item = vector[mid]
        local cv = fcmp(item, key)
        if cv == 0 then
            return item, true, mid
        elseif cv < 0 then
            min = mid
        else
            max = mid
        end
    end
end
-- Insert item into a sorted DF container or lua table, unless an equal
-- entry already exists. The search key is item[field] when both field and
-- item are given, otherwise item itself.
-- Returns false, existing_entry, index when a match was found. Otherwise
-- item is inserted at pos and true, vector[pos], pos is returned.
function insert_sorted(vector,item,field,cmp)
    local key
    if field and item then
        key = item[field]
    else
        key = item
    end

    local existing, found, pos = binsearch(vector,key,field,cmp)
    if found then
        return false, existing, pos
    end

    if df.isvalid(vector) then
        vector:insert(pos, item)
    else
        table.insert(vector, pos, item)
    end
    return true, vector[pos], pos
end
-- Like insert_sorted, but an existing equal entry is overwritten with item.
-- Returns the same triple as insert_sorted. After an overwrite, the second
-- value is re-read from vector[pos].
function insert_or_update(vector,item,field,cmp)
    local inserted, entry, pos = insert_sorted(vector,item,field,cmp)
    if inserted then
        return inserted, entry, pos
    end
    vector[pos] = item
    return inserted, vector[pos], pos
end
-- Export the module table obtained from mkmodule('utils') at the top of the file.
return _ENV