-- Common startup file for all DFHack scripts and plugins with Lua support
|
|
-- The global dfhack table is already created by C++ init code.
|
|
|
|
-- Setup the global environment.
|
|
-- BASE_G is the original lua global environment,
|
|
-- preserved as a common denominator for all modules.
|
|
-- This file uses it instead of the new default one.
|
|
|
|
local dfhack = dfhack
|
|
local base_env = dfhack.BASE_G
|
|
local _ENV = base_env
|
|
|
|
-- Result codes (CR_*)
CR_LINK_FAILURE    = -3
CR_NEEDS_CONSOLE   = -2
CR_NOT_IMPLEMENTED = -1
CR_OK              = 0
CR_FAILURE         = 1
CR_WRONG_USAGE     = 2
CR_NOT_FOUND       = 3

-- Console color constants (COLOR_*); COLOR_RESET is the only negative one
COLOR_RESET        = -1
COLOR_BLACK        = 0
COLOR_BLUE         = 1
COLOR_GREEN        = 2
COLOR_CYAN         = 3
COLOR_RED          = 4
COLOR_MAGENTA      = 5
COLOR_BROWN        = 6
COLOR_GREY         = 7
COLOR_DARKGREY     = 8
COLOR_LIGHTBLUE    = 9
COLOR_LIGHTGREEN   = 10
COLOR_LIGHTCYAN    = 11
COLOR_LIGHTRED     = 12
COLOR_LIGHTMAGENTA = 13
COLOR_YELLOW       = 14
COLOR_WHITE        = 15
|
|
|
|
-- Events
|
|
|
|
-- State change event ids (SC_*). They are only defined when
-- dfhack.is_core_context is set. Note that the value 6 is not assigned here.
if dfhack.is_core_context then
    SC_WORLD_LOADED       = 0
    SC_WORLD_UNLOADED     = 1
    SC_MAP_LOADED         = 2
    SC_MAP_UNLOADED       = 3
    SC_VIEWSCREEN_CHANGED = 4
    SC_CORE_INITIALIZED   = 5
    SC_PAUSED             = 7
    SC_UNPAUSED           = 8
end
|
|
|
|
-- Error handling
|
|
|
|
safecall = dfhack.safecall
|
|
curry = dfhack.curry
|
|
|
|
-- Protected call of f(...) that uses dfhack.onerror as the xpcall message
-- handler. Returns exactly what xpcall returns.
function dfhack.pcall(f, ...)
    local handler = dfhack.onerror
    return xpcall(f, handler, ...)
end
|
|
|
|
-- Raises an error through dfhack.error, prefixing the message with the name of
-- the currently running script when that name is not already part of it.
--   msg:   error value; converted with tostring() when the prefix is added
--   level: stack level the error is attributed to (default 1, i.e. the caller
--          of qerror); one is added to skip qerror's own frame
-- The third argument passed to dfhack.error is always false.
function qerror(msg, level)
    local name = dfhack.current_script_name()
    if name then
        local text = tostring(msg)
        -- Plain substring search: script names routinely contain characters
        -- that are magic in Lua patterns (e.g. the '-' in "fix/dead-units"),
        -- which made a pattern match fail and the prefix get added twice.
        if not text:find(name, 1, true) then
            msg = name .. ': ' .. text
        end
    end
    dfhack.error(msg, (level or 1) + 1, false)
end
|
|
|
|
-- Shorthand for dfhack.call_with_finalizer(0, true, ...): the caller's
-- arguments are forwarded unchanged after those two leading values.
function dfhack.with_finalize(...)
    local leading_count, flag = 0, true
    return dfhack.call_with_finalizer(leading_count, flag, ...)
end
|
|
-- Shorthand for dfhack.call_with_finalizer(0, false, ...): identical to
-- dfhack.with_finalize except that the second argument is false.
function dfhack.with_onerror(...)
    local leading_count, flag = 0, false
    return dfhack.call_with_finalizer(leading_count, flag, ...)
end
|
|
|
|
-- Calls obj:delete() unless obj is nil/false. Returns nothing.
local function call_delete(obj)
    if not obj then
        return
    end
    obj:delete()
end
|
|
|
|
-- Runs fn(obj, ...) through dfhack.call_with_finalizer, registering
-- call_delete with obj as its single argument (leading values 1 and true).
-- NOTE(review): presumably this deletes obj once fn finishes or fails; that
-- depends on call_with_finalizer, which is not defined in this file.
function dfhack.with_temp_object(obj,fn,...)
    return dfhack.call_with_finalizer(
        1, true,
        call_delete, obj,
        fn, obj, ...
    )
end
|
|
|
|
dfhack.exception.__index = dfhack.exception
|
|
|
|
-- Module loading
|
|
|
|
-- Looks up the call stack for the loader that triggered the module currently
-- calling mkmodule() and returns debug.getlocal() for that loader's first
-- local (name, value). Returns nothing if neither expected loader is found.
-- The numeric stack levels assume this function is called directly from
-- mkmodule, which in turn is called directly from the module's main chunk.
local function find_required_module_arg()
    -- level 4: require -> module code -> mkmodule -> find_...
    local via_require = debug.getinfo(4,'f').func == require
    if via_require then
        return debug.getlocal(4, 1)
    end
    -- level 5: reload -> dofile -> module code -> mkmodule -> find_...
    local via_reload = debug.getinfo(5,'f').func == reload
    if via_reload then
        return debug.getlocal(5, 1)
    end
end
|
|
|
|
-- Creates (or reuses) the table for the module being loaded and returns it.
--   module: the name the module expects to be loaded under; it must equal the
--           name the enclosing require()/reload() call was given
--   env:    table used as the __index fallback of the module table;
--           base_env when nil
-- Raises an error when not called from a module being required/reloaded, when
-- the names disagree, or when package.loaded[module] holds a non-table value.
function mkmodule(module,env)
    -- Verify that the module name is correct. This must stay a direct call:
    -- find_required_module_arg counts stack levels.
    local _, rq_modname = find_required_module_arg()
    if not rq_modname then
        error('The mkmodule function must be used at the start of a module')
    end
    if rq_modname ~= module then
        error('Found module '..module..' during require '..rq_modname)
    end

    -- Reuse the already loaded module table
    local pkg = package.loaded[module]
    if pkg == nil then
        pkg = {}
    elseif type(pkg) ~= 'table' then
        error("Not a table in package.loaded["..module.."]")
    end

    -- Inject the plugin-exported functions when the module is named
    -- 'plugins.<name>'
    local plugname = module:match('^plugins%.([%w%-]+)$')
    if plugname then
        dfhack.open_plugin(pkg,plugname)
    end

    local fallback = env or base_env
    setmetatable(pkg, { __index = fallback })
    return pkg
end
|
|
|
|
-- Re-executes the source file of an already loaded module.
-- Raises an error if package.loaded[module] is not a table or if the module
-- cannot be located on package.path.
function reload(module)
    local loaded = package.loaded[module]
    if type(loaded) ~= 'table' then
        error("Module not loaded: "..module)
    end

    local path, err = package.searchpath(module, package.path)
    if not path then
        error(err)
    end

    -- Keep this a plain call statement: find_required_module_arg expects the
    -- reload -> dofile -> module code chain on the stack.
    dofile(path)
end
|
|
|
|
-- Trivial classes
|
|
|
|
-- Copies into target every key of source that target does not already hold.
-- Presence is tested and values are stored with rawget/rawset, so target's
-- metamethods are bypassed. Returns nothing.
function rawset_default(target,source)
    for key, default in pairs(source) do
        local missing = rawget(target, key) == nil
        if missing then
            rawset(target, key, default)
        end
    end
end
|
|
|
|
DEFAULT_NIL = DEFAULT_NIL or {} -- Unique token
|
|
|
|
-- Global shortcut for require('class').defclass(...). The 'class' module is
-- only required when the function is called.
function defclass(...)
    local class = require('class')
    return class.defclass(...)
end
|
|
|
|
-- Global shortcut for require('class').mkinstance(...). The 'class' module is
-- only required when the function is called.
function mkinstance(...)
    local class = require('class')
    return class.mkinstance(...)
end
|
|
|
|
-- Misc functions
|
|
|
|
NEWLINE = "\n"
|
|
COMMA = ","
|
|
PERIOD = "."
|
|
|
|
-- Wraps an iterator step function so that an error raised while advancing
-- ends the iteration instead of propagating.
--   next_fn: the step function, i.e. the first value returned by pairs(t) etc.
--   ...:     the remaining iterator values; returned unchanged
-- Returns the wrapped step function followed by the remaining values, ready
-- for use in a generic for loop.
local function _wrap_iterator(next_fn, ...)
    -- Forward the results of a successful pcall as varargs. The previous
    -- implementation packed them into a table and used table.remove and
    -- table.unpack, whose reliance on '#' is unspecified when a result is nil,
    -- and which allocated a table on every step.
    local function results_if_ok(ok, ...)
        if ok then
            return ...
        end
        -- on error: return nothing, which a generic for treats as the end
    end
    local wrapped_iter = function(...)
        return results_if_ok(pcall(next_fn, ...))
    end
    return wrapped_iter, ...
end
|
|
|
|
-- Iterator factory that never raises.
--   t:           the value to iterate over
--   iterator_fn: iterator constructor such as pairs or ipairs (default pairs)
-- If iterator_fn(t) raises, an iterator that yields nothing is returned.
-- Otherwise iterator_fn(t) is called again and its step function is wrapped
-- so an error during a step ends the loop.
function safe_pairs(t, iterator_fn)
    local make_iter = iterator_fn or pairs
    local can_iterate = pcall(make_iter, t)
    if not can_iterate then
        return function() end
    end
    return _wrap_iterator(make_iter(t))
end
|
|
|
|
-- calls elem_cb(k, v) for each element of the table
|
|
-- returns true if we iterated successfully, false if not
|
|
-- this differs from safe_pairs() above in that it only calls pcall() once per
|
|
-- full iteration and it returns whether iteration succeeded or failed.
|
|
-- Runs elem_cb(k, v) for every pair produced by iterator_fn(tbl) inside one
-- pcall covering the whole loop. Returns pcall's results: true on success,
-- or false plus the error value if iteration or the callback raised.
local function safe_iterate(tbl, iterator_fn, elem_cb)
    return pcall(function()
        for key, val in iterator_fn(tbl) do
            elem_cb(key, val)
        end
    end)
end
|
|
|
|
-- Prints one "key<TAB> = value" line via dfhack.println; the key is
-- left-aligned and padded to 23 characters.
local function print_element(k, v)
    local line = ("%-23s\t = %s"):format(tostring(k), tostring(v))
    dfhack.println(line)
end
|
|
|
|
-- Prints every key/value pair that pairs(table) yields, one per line.
-- Errors raised during iteration are swallowed by safe_iterate; nothing is
-- returned.
function printall(table)
    local iterator_fn = pairs
    safe_iterate(table, iterator_fn, print_element)
end
|
|
|
|
-- Prints every index/value pair that ipairs(table) yields, one per line.
-- Errors raised during iteration are swallowed by safe_iterate; nothing is
-- returned.
function printall_ipairs(table)
    local iterator_fn = ipairs
    safe_iterate(table, iterator_fn, print_element)
end
|
|
|
|
local do_print_recurse
|
|
|
|
-- Prints tostring(v) with printfn and returns the length of the printed text.
-- seen and indent are accepted for signature parity with the other printers
-- and are not used.
local function print_string(printfn, v, seen, indent)
    local text = tostring(v)
    printfn(text)
    return text:len()
end
|
|
|
|
-- Lazily filled cache of padding strings: fill_chars[n] is (23 - n) spaces
-- followed by ' = '. The table is its own metatable; __index computes a
-- missing entry and stores it with rawset so it is built only once.
local fill_chars = {}

fill_chars.__index = function(cache, width)
    local padding = string.rep(' ', 23 - width) .. ' = '
    rawset(cache, width, padding)
    return padding
end

setmetatable(fill_chars, fill_chars)
|
|
|
|
-- Prints every key/value pair of value (iterated with pairs), recursing into
-- keys and values through do_print_recurse at indent + 1.
--   * When value._kind is "bitfield", entries whose value is nil/false are
--     skipped.
--   * Consecutive numeric keys holding the same value as the previously
--     printed numeric entry are collapsed into a "<Repeated N times>" line.
--   * If iteration raises, a notice line is printed instead.
-- prefix is written before each line. Always returns 0.
local function print_fields(value, seen, indent, prefix)
    local NO_VALUE = "not a value"
    local last_value = NO_VALUE
    local repeated = 0

    -- emits the pending "<Repeated N times>" line, if any
    local function flush_repeats()
        if repeated > 0 then
            dfhack.println(prefix .. "<Repeated " .. repeated .. " times>")
            repeated = 0
        end
    end

    local function print_field(k, v)
        -- Only show set values of bitfields
        if value._kind == "bitfield" and not v then
            return
        end

        if type(k) == "number" then
            if last_value == v then
                repeated = repeated + 1
                return
            end
            last_value = v
        else
            last_value = NO_VALUE
        end

        flush_repeats()
        dfhack.print(prefix)
        local len = do_print_recurse(dfhack.print, k, seen, indent + 1)
        -- pad short keys to a common column; long keys get the minimum padding
        dfhack.print(fill_chars[len <= 23 and len or 23])
        do_print_recurse(dfhack.println, v, seen, indent + 1)
    end

    if safe_iterate(value, pairs, print_field) then
        flush_repeats()
    else
        dfhack.print(prefix)
        dfhack.println('<Type doesn\'t support iteration with pairs>')
    end
    return 0
end
|
|
|
|
-- This should be same as print_array but userdata doesn't compare equal even if
|
|
-- they hold same pointer.
|
|
-- Prints a userdata value and its fields. Cycle detection is keyed on
-- tostring(value) rather than on the value itself (see the comment above:
-- userdata wrapping the same pointer do not compare equal).
-- The header line is always written with dfhack.println; printfn is unused.
-- Always returns 0.
local function print_userdata(printfn, value, seen, indent)
    local prefix = string.rep(' ', indent)
    local key = tostring(value)
    dfhack.println(key)

    if not seen[key] then
        seen[key] = true
        return print_fields(value, seen, indent, prefix)
    end

    dfhack.print(prefix)
    dfhack.println('<Cyclic reference! Skipping fields>\n')
    return 0
end
|
|
|
|
-- Prints a table and its fields. Tables already recorded in seen are reported
-- as cyclic references and not descended into again.
-- The header line is always written with dfhack.println; printfn is unused.
-- Always returns 0.
local function print_array(printfn, value, seen, indent)
    local prefix = string.rep(' ', indent)
    dfhack.println(tostring(value))

    if not seen[value] then
        seen[value] = true
        return print_fields(value, seen, indent, prefix)
    end

    dfhack.print(prefix)
    dfhack.println('<Cyclic reference! skipping fields>\n')
    return 0
end
|
|
|
|
-- Dispatch table for do_print_recurse: maps type(value) to its printer.
-- Types without an entry (e.g. 'thread') are handled by do_print_recurse.
local recurse_type_map = {
    -- scalar-like values are printed with tostring()
    ['nil'] = print_string,
    boolean = print_string,
    number = print_string,
    string = print_string,
    ['function'] = print_string,
    -- containers are printed field by field
    table = print_array,
    userdata = print_userdata,
}
|
|
|
|
-- Prints value using the printer registered for its type in recurse_type_map.
--   printfn: function used to emit text for scalar values
--   seen:    set of containers already printed (cycle detection)
--   indent:  current nesting depth
-- Returns the number the selected printer returns (text length for scalars,
-- 0 for containers).
do_print_recurse = function(printfn, value, seen, indent)
    local t = type(value)
    local printer = recurse_type_map[t]
    if not printer then
        local text = "Unknown type " .. t .. " " .. tostring(value)
        printfn(text)
        -- Return the printed length like the scalar printer does. Returning
        -- nothing made print_fields fail on `len <= 23` when such a value was
        -- a table key, and the failure was then reported as an iteration
        -- error.
        return #text
    end
    return printer(printfn, value, seen, indent)
end
|
|
|
|
-- Recursively prints value and everything reachable from it.
--   seen: optional table recording containers that were already printed; a
--         fresh one is used when omitted
-- Returns nothing.
function printall_recurse(value, seen)
    if not seen then
        seen = {}
    end
    do_print_recurse(dfhack.println, value, seen, 0)
end
|
|
|
|
-- Returns a new table holding every key/value pair that pairs(table) yields.
-- The copy is shallow: nested tables are shared with the original.
function copyall(table)
    local copy = {}
    for key, val in pairs(table) do
        copy[key] = val
    end
    return copy
end
|
|
|
|
-- Unpacks a position into x, y, z. Returns nothing when pos is nil/false,
-- when pos.x is nil/false, or when pos.x is -30000.
function pos2xyz(pos)
    if not pos then
        return
    end
    local x = pos.x
    if not x or x == -30000 then
        return
    end
    return x, pos.y, pos.z
end
|
|
|
|
-- Builds a {x=, y=, z=} table. When x is nil/false, all three fields are set
-- to -30000 instead.
function xyz2pos(x,y,z)
    if not x then
        return {x=-30000,y=-30000,z=-30000}
    end
    return {x=x,y=y,z=z}
end
|
|
|
|
-- Compares the x, y and z fields of two positions. If a is nil/false it is
-- returned as is; otherwise if b is nil/false it is returned as is; otherwise
-- the result is a boolean.
function same_xyz(a,b)
    if not a then
        return a
    end
    if not b then
        return b
    end
    return a.x == b.x and a.y == b.y and a.z == b.z
end
|
|
|
|
-- Returns the i-th entries of path.x, path.y and path.z.
function get_path_xyz(path,i)
    local x = path.x[i]
    local y = path.y[i]
    local z = path.z[i]
    return x, y, z
end
|
|
|
|
-- Unpacks a position into x, y. Returns nothing when pos is nil/false, when
-- pos.x is nil/false, or when pos.x is -30000.
function pos2xy(pos)
    if not pos then
        return
    end
    local x = pos.x
    if not x or x == -30000 then
        return
    end
    return x, pos.y
end
|
|
|
|
-- Builds a {x=, y=} table. When x is nil/false, both fields are set to
-- -30000 instead.
function xy2pos(x,y)
    if not x then
        return {x=-30000,y=-30000}
    end
    return {x=x,y=y}
end
|
|
|
|
-- Compares the x and y fields of two positions. If a is nil/false it is
-- returned as is; otherwise if b is nil/false it is returned as is; otherwise
-- the result is a boolean.
function same_xy(a,b)
    if not a then
        return a
    end
    if not b then
        return b
    end
    return a.x == b.x and a.y == b.y
end
|
|
|
|
-- Returns the i-th entries of path.x and path.y.
function get_path_xy(path,i)
    local x = path.x[i]
    local y = path.y[i]
    return x, y
end
|
|
|
|
-- Follows a chain of keys: safe_index(obj, a, b, c) is obj[a][b][c], except
-- that nil is returned as soon as the current object or the next key is nil.
-- For userdata objects, a numeric key outside 0 .. #obj-1 also yields nil
-- instead of being used for indexing.
function safe_index(obj,idx,...)
    local keys = table.pack(idx, ...)
    for i = 1, keys.n do
        local key = keys[i]
        if obj == nil or key == nil then
            return nil
        end
        -- the range check is only relevant for c++ (userdata) containers
        local out_of_range = type(key) == 'number'
            and type(obj) == 'userdata'
            and (key < 0 or key >= #obj)
        if out_of_range then
            return nil
        end
        obj = obj[key]
    end
    return obj
end
|
|
|
|
-- Returns t[key], first initializing it when it is nil.
--   default_value: value stored for a missing key; a new empty table is used
--                  when default_value is nil
-- An explicit `false` default is stored as given. The previous and/or
-- expression treated it like nil and stored an empty table instead.
function ensure_key(t, key, default_value)
    if t[key] == nil then
        if default_value == nil then
            default_value = {}
        end
        t[key] = default_value
    end
    return t[key]
end
|
|
|
|
-- String class extensions
|
|
|
|
-- prefix is a literal string, not a pattern
|
|
-- True when the string begins with prefix, compared as plain text.
function string:startswith(prefix)
    local head = self:sub(1, #prefix)
    return head == prefix
end
|
|
|
|
-- suffix is a literal string, not a pattern
|
|
-- True when the string ends with suffix, compared as plain text. An empty
-- suffix always matches.
function string:endswith(suffix)
    local suffix_len = #suffix
    -- handled separately because sub(-0) selects the whole string
    if suffix_len == 0 then
        return true
    end
    return self:sub(-suffix_len) == suffix
end
|
|
|
|
-- Split a string by the given delimiter. If no delimiter is specified, space
|
|
-- (' ') is used. The delimiter is treated as a pattern unless a <plain> is
|
|
-- specified and set to true. To treat multiple successive delimiter characters
|
|
-- as a single delimiter, e.g. to avoid getting empty string elements, pass a
|
|
-- pattern like ' +'. Be aware that passing patterns that match empty strings
|
|
-- (like ' *') will result in improper string splits.
|
|
-- Returns a list of the substrings found between matches of delimiter.
--   delimiter: Lua pattern (or plain text when plain is true); default ' '
--   plain:     passed to string.find as its "plain" flag
-- The list always has at least one element. Scanning stops at the first
-- delimiter match that is empty, and the rest of the string becomes the last
-- element.
function string:split(delimiter, plain)
    local sep = delimiter or ' '
    local pieces = {}
    local pos = 1
    while true do
        local first, last = self:find(sep, pos, plain)
        -- first > last means the delimiter matched the empty string; going on
        -- would never advance
        if not first or first > last then
            break
        end
        pieces[#pieces + 1] = self:sub(pos, first - 1)
        pos = last + 1
    end
    pieces[#pieces + 1] = self:sub(pos)
    return pieces
end
|
|
|
|
-- Removes spaces (i.e. everything that matches '%s') from the start and end of
|
|
-- a string. Spaces between non-space characters are left untouched.
|
|
-- Returns the string without leading and trailing '%s' characters.
function string:trim()
    -- the lazy capture leaves trailing whitespace to the final %s*
    local content = self:match('^%s*(.-)%s*$')
    return content
end
|
|
|
|
-- Inserts newlines into a string so no individual line exceeds the given width.
|
|
-- Lines are split at space-separated word boundaries. Any existing newlines are
|
|
-- kept in place. If a single word is longer than width, it is split over
|
|
-- multiple lines. If width is not specified, 72 is used.
|
|
-- Returns the string with newlines inserted so that lines stay within width
-- characters (default 72). Raises an error when width <= 0.
-- Each existing line is processed separately. Words are runs of non-space
-- characters, and a word longer than width is cut into width-sized pieces.
function string:wrap(width)
    width = width or 72
    if width <= 0 then error('expected width > 0; got: '..tostring(width)) end

    local out_lines = {}
    for text_line in self:gmatch('[^\n]*') do
        -- position in text_line where the current output line begins
        local line_origin = 1
        -- each match is optional whitespace plus one word; the position
        -- captures give the word's start and the index just past its end
        local rewrapped = text_line:gsub(
            '%s*()(%S+)()',
            function(word_start, word, word_end)
                -- the word still fits: returning nothing keeps the match
                if word_end - line_origin <= width then return end

                -- move the whole word to a new line; the whitespace before it
                -- is replaced by the newline
                if #word <= width then
                    line_origin = word_start
                    return '\n' .. word
                end

                -- the word alone exceeds width: emit width-sized fragments
                local consumed = 0
                local pieces = word_start == 1 and '' or '\n'
                repeat
                    local frag = word:sub(consumed + 1, consumed + width)
                    pieces = pieces .. frag
                    consumed = consumed + #frag
                    if consumed < #word then
                        pieces = pieces .. '\n'
                    end
                    line_origin = word_start + consumed
                until word_end - line_origin <= width
                return pieces .. word:sub(consumed + 1)
            end)
        table.insert(out_lines, rewrapped)
    end
    return table.concat(out_lines, '\n')
end
|
|
|
|
-- Escapes regex special chars in a string. E.g. "a+b" -> "a%+b"
|
|
-- Character class matching every Lua pattern magic character. It is built by
-- prefixing each character of the list with '%' so the class is itself a
-- valid pattern.
local regex_chars_pattern = '(['..('%^$()[].*+-?'):gsub('(.)', '%%%1')..'])'

-- Returns a copy of the string in which every pattern magic character is
-- preceded by '%', so the result matches the original text literally when
-- used as a pattern.
function string:escape_pattern()
    -- gsub also returns the number of substitutions. Keep only the string so
    -- the result can be passed straight to functions whose behavior depends on
    -- the argument count (e.g. table.insert).
    local escaped = self:gsub(regex_chars_pattern, '%%%1')
    return escaped
end
|
|
|
|
-- String conversions
|
|
|
|
-- Formats the entry as <persistent ID:KEY="VALUE":I1,I2,...>, where the ints
-- list is self.ints joined with commas.
function dfhack.persistent:__tostring()
    local head = "<persistent " .. self.entry_id .. ":" .. self.key
    local body = '="' .. self.value .. '":' .. table.concat(self.ints, ",")
    return head .. body .. ">"
end
|
|
|
|
-- Formats the object as <material TYPE:INDEX TOKEN>, where TOKEN is the
-- result of self:getToken().
function dfhack.matinfo:__tostring()
    local ids = self.type .. ":" .. self.index
    return "<material " .. ids .. " " .. self:getToken() .. ">"
end
|
|
|
|
-- dfhack.random is its own __index table, so it can serve as a metatable
-- whose fields are found on instances.
dfhack.random.__index = dfhack.random

-- Fixed description; no per-instance state is shown.
function dfhack.random:__tostring()
    local description = "<random generator>"
    return description
end
|
|
|
|
-- dfhack.penarray is its own __index table, so it can serve as a metatable
-- whose fields are found on instances.
dfhack.penarray.__index = dfhack.penarray

-- Fixed description; the argument passed by tostring() is ignored.
function dfhack.penarray.__tostring()
    local description = "<penarray>"
    return description
end
|
|
|
|
-- Returns the x_count_block, y_count_block and z_count_block fields of
-- df.global.world.map, in that order.
function dfhack.maps.getSize()
    local map = df.global.world.map
    local bx, by, bz = map.x_count_block, map.y_count_block, map.z_count_block
    return bx, by, bz
end
|
|
|
|
-- Returns the x_count, y_count and z_count fields of df.global.world.map,
-- in that order.
function dfhack.maps.getTileSize()
    local map = df.global.world.map
    local tx, ty, tz = map.x_count, map.y_count, map.z_count
    return tx, ty, tz
end
|
|
|
|
-- Returns four numbers for a building: its width (x2 + 1 - x1), its height
-- (y2 + 1 - y1), and the offsets of centerx/centery from x1/y1.
function dfhack.buildings.getSize(bld)
    local left, top = bld.x1, bld.y1
    local width = bld.x2 + 1 - left
    local height = bld.y2 + 1 - top
    return width, height, bld.centerx - left, bld.centery - top
end
|
|
|
|
-- Walks from the current viewscreen up through the .parent links and returns
-- the first screen for which scr_type:is_instance(scr) is true.
--   n: when positive, at most n screens are examined (default 1, i.e. only
--      the current screen); when zero or negative, the whole chain is searched
-- Returns nil when the limit is hit and nothing when the chain runs out.
-- (translated from modules/Gui.cpp)
function dfhack.gui.getViewscreenByType(scr_type, n)
    if n == nil then
        n = 1
    end
    local limited = n > 0
    local screen = dfhack.gui.getCurViewscreen()
    while screen do
        if limited then
            n = n - 1
            if n < 0 then
                return nil
            end
        end
        if scr_type:is_instance(screen) then
            return screen
        end
        screen = screen.parent
    end
end
|
|
|
|
-- Interactive
|
|
|
|
local print_banner = true
|
|
|
|
-- Runs an interactive Lua prompt.
--   prompt: text shown between the brackets of the prompt ('lua' when nil)
--   hfile:  forwarded to dfhack.lineedit as its second argument
--   env:    table that the prompt's globals fall back to (_G when nil)
-- Returns nil, 'not interactive' when dfhack.is_interactive() is false;
-- otherwise loops until the user enters 'quit' or lineedit returns nil, then
-- returns true.
-- Lines starting with one of the shortcut characters (= ! ~ ^ @) are run as
-- 'return <rest>' and the results are passed to the matching handler. The
-- shortcut banner is printed only on the first call.
function dfhack.interpreter(prompt,hfile,env)
    if not dfhack.is_interactive() then
        return nil, 'not interactive'
    end

    print("Type quit to exit interactive lua interpreter.")

    if print_banner then
        print("Shortcuts:\n"..
              " '= foo' => '_1,_2,... = foo'\n"..
              " '! foo' => 'print(foo)'\n"..
              " '~ foo' => 'printall(foo)'\n"..
              " '^ foo' => 'printall_recurse(foo)'\n"..
              " '@ foo' => 'printall_ipairs(foo)'\n"..
              "All of these save the first result as '_'.")
        print_banner = false
    end

    local prompt_str = "["..(prompt or 'lua').."]# "
    -- continuation prompt, padded to the same width as prompt_str
    local prompt_cont = string.rep(' ',#prompt_str-4)..">>> "
    local prompt_env = {}
    local cmdlinelist = {}   -- lines of a statement still being entered
    local t_prompt = nil     -- set to prompt_cont while a statement is incomplete
    local vcnt = 1           -- next index for the _N variables created by '='

    -- data is table.pack(safecall(code)): data[1] is the success flag and the
    -- chunk's results start at data[2]
    local pfix_handlers = {
        ['!'] = function(data)
            print(table.unpack(data,2,data.n))
        end,
        ['~'] = function(data)
            print(table.unpack(data,2,data.n))
            printall(data[2])
        end,
        ['@'] = function(data)
            print(table.unpack(data,2,data.n))
            printall_ipairs(data[2])
        end,
        ['^'] = function(data)
            printall_recurse(data[2])
        end,
        ['='] = function(data)
            for i=2,data.n do
                local varname = '_'..vcnt
                prompt_env[varname] = data[i]
                dfhack.print(varname..' = ')
                safecall(print, data[i])
                vcnt = vcnt + 1
            end
        end
    }

    setmetatable(prompt_env, { __index = env or _G })

    while true do
        local cmdline = dfhack.lineedit(t_prompt or prompt_str, hfile)

        if cmdline == nil or cmdline == 'quit' then
            break
        elseif cmdline ~= '' then
            local pfix = string.sub(cmdline,1,1)

            -- shortcuts are only recognized on the first line of a statement
            if not t_prompt and pfix_handlers[pfix] then
                cmdline = 'return '..string.sub(cmdline,2)
            else
                pfix = nil
            end

            table.insert(cmdlinelist,cmdline)
            cmdline = table.concat(cmdlinelist,'\n')

            local code,err = load(cmdline, '=(interactive)', 't', prompt_env)

            if code == nil then
                -- a syntax error at <eof> means the statement is unfinished:
                -- keep the lines gathered so far and ask for more
                if not pfix and err:sub(-5)=="<eof>" then
                    t_prompt=prompt_cont
                else
                    dfhack.printerr(err)
                    cmdlinelist={}
                    t_prompt=nil
                end
            else
                cmdlinelist={}
                t_prompt=nil

                local data = table.pack(safecall(code))

                if data[1] and data.n > 1 then
                    prompt_env._ = data[2]
                    -- A chunk entered without a shortcut can still return
                    -- values (e.g. "return 5"). There is no handler in that
                    -- case, so skip the call rather than invoking nil.
                    local handler = pfix and pfix_handlers[pfix]
                    if handler then
                        safecall(handler, data)
                    end
                end
            end
        end
    end

    return true
end
|
|
|
|
-- Command scripts
|
|
|
|
local internal = dfhack.internal
|
|
|
|
-- Bookkeeping record for one script file: its path, the modification time
-- seen when it was last loaded, and the flags declared in its source.
-- The env and run fields are attached by dfhack.run_script_with_env.
Script = defclass(Script)

-- path: location of the script file
function Script:init(path)
    self.path = path
    self.mtime = dfhack.filesystem.mtime(path)
    self._flags = {}
end

-- True when the script has no environment yet or when the file's current
-- modification time differs from the recorded one.
function Script:needs_update()
    return (not self.env) or self.mtime ~= dfhack.filesystem.mtime(self.path)
end

-- Returns the table of flags declared in the script source. The file is
-- re-read only when its modification time has changed since the last parse.
-- Every "--@<code>" fragment in the file (up to the end of its line) is run as
-- Lua code with the flags table as its environment, so "--@ module = true"
-- sets flags.module. Fragments that do not compile are reported with
-- dfhack.printerr and skipped.
function Script:get_flags()
    local mtime = dfhack.filesystem.mtime(self.path)
    if self.flags_mtime ~= mtime then
        self.flags_mtime = mtime
        self._flags = {}
        local f = io.open(self.path)
        -- The file may have been removed or become unreadable since it was
        -- located. Treat that as "no flags" instead of failing on a nil
        -- handle.
        if f then
            local contents = f:read('*all') or ''
            f:close()
            for line in contents:gmatch('%-%-@([^\n]+)') do
                local chunk = load(line, self.path, 't', self._flags)
                if chunk then
                    chunk()
                else
                    dfhack.printerr('Parse error: ' .. line)
                end
            end
        end
    end
    return self._flags
end
|
|
|
|
internal.scripts = internal.scripts or {}
|
|
|
|
local hack_path = dfhack.getHackPath()
|
|
|
|
-- Appends '.lua' to name and returns whatever dfhack.internal.findScript
-- returns for that file name.
function dfhack.findScript(name)
    local filename = name .. '.lua'
    return dfhack.internal.findScript(filename)
end
|
|
|
|
-- Flags that callers may pass to dfhack.run_script_with_env.
--   required: true, or a function(flags) returning true, when the script
--             itself must declare the flag for the call to be allowed
--   error:    message raised when a required flag is not declared
local valid_script_flags = {
    enable = {
        required = true,
        error = 'Does not recognize enable/disable commands',
    },
    enable_state = {required = false},
    module = {
        -- required unless the caller explicitly passed module_strict = false
        required = function(flags)
            return flags.module_strict ~= false
        end,
        error = 'Cannot be used as a module',
    },
    module_strict = {required = false},
    alias = {required = false},
    alias_count = {required = false},
    scripts = {required = false},
}
|
|
|
|
-- Runs the named script with the given arguments, passing nil for both the
-- extra environment variables and the flags of dfhack.run_script_with_env.
-- Returns whatever that function returns.
function dfhack.run_script(name,...)
    local env_vars, flags = nil, nil
    return dfhack.run_script_with_env(env_vars, name, flags, ...)
end
|
|
|
|
-- Runs the named script with the flags enable=true and enable_state=state.
-- If that run fails, the error's message field is printed and a qerror
-- naming the script is raised. Returns nothing.
function dfhack.enable_script(name, state)
    local flags = {enable=true, enable_state=state}
    local ok, err = dfhack.pcall(dfhack.run_script_with_env, nil, name, flags)
    if ok then
        return
    end
    dfhack.printerr(err.message)
    local action = state and 'enable' or 'disable'
    qerror(('Cannot %s Lua script: %s'):format(action, name))
end
|
|
|
|
-- Strict form of dfhack.script_environment: same call with strict = true.
function dfhack.reqscript(name)
    local strict = true
    return dfhack.script_environment(name, strict)
end

-- also available as a plain global
reqscript = dfhack.reqscript
|
|
|
|
-- Returns the environment table of the named script, running the script in
-- module mode first when it has not been loaded yet or needs an update.
--   strict: when truthy, the script must declare the 'module' flag. For an
--           already loaded script this is checked here and an error is raised
--           if the flag is missing.
function dfhack.script_environment(name, strict)
    local path = dfhack.findScript(name)
    local cached = internal.scripts[path]

    if cached and not cached:needs_update() then
        if strict and not cached:get_flags().module then
            error(('%s: %s'):format(name, valid_script_flags.module.error))
        end
        return cached.env
    end

    local _, env = dfhack.run_script_with_env(nil, name, {
        module=true,
        -- always a boolean, so the key is present even if 'strict' is nil
        module_strict=(strict and true or false)
    })
    return env
end
|
|
|
|
-- Locates, loads if needed, and runs a script.
--   envVars: optional table of variables copied into the script environment
--   name:    script name without the '.lua' extension
--   flags:   optional table of flags (see valid_script_flags); any non-table
--            value is replaced by an empty table. flags.scripts, when set,
--            replaces internal.scripts as the script cache.
--   ...:     script arguments; each is converted with tostring()
-- Returns the script's first return value followed by its environment table.
-- Raises an error when the script cannot be found or loaded, when a flag is
-- unknown, when a required flag is not declared by the script, or when more
-- than 10 aliases are followed.
-- NOTE: dfhack.current_script_name() finds this function's frame on the call
-- stack and reads its local called `name`, so that parameter must keep its
-- name and the final call must not become a tail call.
function dfhack.run_script_with_env(envVars, name, flags, ...)
    if type(flags) ~= 'table' then flags = {} end

    local file = dfhack.findScript(name)
    if not file then
        error('Could not find script ' .. name)
    end

    local scripts = flags.scripts or internal.scripts
    if scripts[file] == nil then
        scripts[file] = Script(file)
    end
    local script = scripts[file]

    -- a script that declares an alias is replaced by the script it names
    local script_flags = script:get_flags()
    if script_flags.alias then
        flags.alias_count = (flags.alias_count or 0) + 1
        if flags.alias_count > 10 then
            error('Too many script aliases: ' .. flags.alias_count)
        end
        return dfhack.run_script_with_env(envVars, script_flags.alias, flags, ...)
    end

    -- every truthy flag must be known, and required ones must also be
    -- declared by the script itself
    for flag, value in pairs(flags) do
        if value then
            local spec = valid_script_flags[flag]
            if not spec then
                error('Invalid flag: ' .. flag)
            end
            local required = spec.required
            if type(required) == 'function' then
                required = required(flags)
            elseif type(required) ~= 'boolean' then
                required = false
            end
            if required and not script_flags[flag] then
                local msg = spec.error or ('Flag "' .. flag .. '" not recognized')
                error(name .. ': ' .. msg)
            end
        end
    end

    -- reuse the script's environment across runs
    local env = script.env
    if env == nil then
        env = setmetatable({}, { __index = base_env })
    end
    for key, val in pairs(envVars or {}) do
        env[key] = val
    end
    env.dfhack_flags = flags
    env.moduleMode = flags.module

    -- recompile when there is no compiled chunk yet or the file has changed
    local time = dfhack.filesystem.mtime(file)
    local script_code = script.run
    if time ~= script.mtime or not script_code then
        local perr
        script_code, perr = loadfile(file, 't', env)
        if not script_code then
            error(perr)
        end
        -- only reached on success, so mtime is not updated if loading failed
        script.mtime = time
    end
    script.env = env
    script.run = script_code

    local args = {...}
    for i, arg in ipairs(args) do
        args[i] = tostring(arg) -- ensure passed parameters are strings
    end
    return script_code(table.unpack(args)), env
end
|
|
|
|
-- Returns the name of the script currently being run, found by walking up
-- the call stack to the nearest dfhack.run_script_with_env frame and reading
-- its local variable called 'name'. Returns nothing when no such frame (or
-- no such local) exists.
function dfhack.current_script_name()
    local level = 1
    while true do
        local info = debug.getinfo(level, 'f')
        if not info then
            return
        end
        if info.func == dfhack.run_script_with_env then
            -- scan that frame's locals in order; stop at the first 'name'
            local slot = 1
            while true do
                local local_name, local_value = debug.getlocal(level, slot)
                if not local_name then
                    return
                end
                if local_name == 'name' then
                    return local_value
                end
                slot = slot + 1
            end
        end
        level = level + 1
    end
end
|
|
|
|
-- Returns require('helpdb').get_entry_long_help() for the given script, or
-- for the currently running script when script_name is nil/false.
-- The extension parameter is accepted but not used.
function dfhack.script_help(script_name, extension)
    local target = script_name or dfhack.current_script_name()
    local helpdb = require('helpdb')
    return helpdb.get_entry_long_help(target)
end
|
|
|
|
-- Normalizes the argument forms accepted by dfhack.run_command* and hands the
-- result to internal.runCommand together with use_console.
-- Accepted shapes of args (the packed varargs of the public function):
--   { {cmd, arg, ...} }    a single table, used as is
--   { cmd, {arg, ...} }    a command string plus a table of arguments
--   { "cmd arg ..." }      a single string, passed through unchanged
--   { cmd, arg, ... }      several strings
-- Raises 'Invalid arguments' for anything else.
local function _run_command(args, use_console)
    -- `command` used to be assigned without `local`, leaking it into the
    -- shared global environment.
    local command
    local first, second = args[1], args[2]
    if type(first) == 'table' then
        command = first
    elseif #args > 1 and type(second) == 'table' then
        -- {args[1]} + args[2], built as a new table: inserting into the
        -- caller's table would prepend the command again every time the same
        -- argument table is reused
        command = {first, table.unpack(second)}
    elseif #args == 1 and type(first) == 'string' then
        command = first
    elseif #args > 1 and type(first) == 'string' then
        command = args
    else
        error('Invalid arguments')
    end
    return internal.runCommand(command, use_console)
end
|
|
|
|
-- Runs a command without passing a use_console argument and returns two
-- values: the concatenation of the second field of every table entry in the
-- result, and result.status.
-- See _run_command for the accepted argument forms.
function dfhack.run_command_silent(...)
    local result = _run_command({...})
    local output = ""
    for _, entry in pairs(result) do
        -- non-table entries such as the status field are skipped
        if type(entry) == 'table' then
            output = output .. entry[2]
        end
    end
    return output, result.status
end
|
|
|
|
-- Runs a command with use_console = true and returns only result.status.
-- See _run_command for the accepted argument forms.
function dfhack.run_command(...)
    local use_console = true
    local result = _run_command({...}, use_console)
    return result.status
end
|
|
|
|
-- Per-save init file
|
|
|
|
-- Returns '<DF path>/data/save/<save_dir>' built from
-- df.global.world.cur_savegame.save_dir, or nothing when
-- dfhack.isWorldLoaded() is false.
function dfhack.getSavePath()
    if not dfhack.isWorldLoaded() then
        return
    end
    local save_dir = df.global.world.cur_savegame.save_dir
    return dfhack.getDFPath() .. '/data/save/' .. save_dir
end
|
|
|
|
-- Per-save init scripts, installed only when dfhack.is_core_context is set.
-- On SC_WORLD_LOADED, <save>/raw/init.lua and every *.lua file in
-- <save>/raw/init.d/ (in sorted order) are run, each in its own environment
-- that has SAVE_PATH set. Environments of scripts that ran successfully are
-- kept in internal.save_init; their onUnload functions are called on world
-- load/unload, and their onStateChange(op) functions on every other event.
if dfhack.is_core_context then
    -- Loads and runs one init file; remembers its environment on success.
    local function loadInitFile(path, name)
        local env = setmetatable({ SAVE_PATH = path }, { __index = base_env })
        local chunk, load_err = loadfile(name, 't', env)
        if chunk == nil then
            -- a missing file is not an error; report only files that exist
            -- but could not be loaded
            if dfhack.filesystem.exists(name) then
                dfhack.printerr(load_err)
            end
            return
        end
        if not safecall(chunk) then
            return
        end
        if not internal.save_init then
            internal.save_init = {}
        end
        table.insert(internal.save_init, env)
    end

    -- Calls onUnload for every remembered environment, then forgets them.
    local function unload_save_inits()
        local inits = internal.save_init
        if not inits then
            return
        end
        for _, init_env in ipairs(inits) do
            if init_env.onUnload then
                safecall(init_env.onUnload)
            end
        end
        internal.save_init = nil
    end

    -- Runs raw/init.lua and then the sorted raw/init.d/*.lua files.
    local function load_save_inits(path)
        loadInitFile(path, path..'/raw/init.lua')

        local dirlist = dfhack.internal.getDir(path..'/raw/init.d/')
        if not dirlist then
            return
        end
        table.sort(dirlist)
        for _, name in ipairs(dirlist) do
            if string.match(name,'%.lua$') then
                loadInitFile(path, path..'/raw/init.d/'..name)
            end
        end
    end

    dfhack.onStateChange.DFHACK_PER_SAVE = function(op)
        if op == SC_WORLD_LOADED or op == SC_WORLD_UNLOADED then
            unload_save_inits()

            local path = dfhack.getSavePath()
            if path and op == SC_WORLD_LOADED then
                load_save_inits(path)
            end
        elseif internal.save_init then
            for _, init_env in ipairs(internal.save_init) do
                if init_env.onStateChange then
                    safecall(init_env.onStateChange, op)
                end
            end
        end
    end
end
|
|
|
|
-- Feed the table back to the require() mechanism.
|
|
return dfhack
|