diff --git a/library/LuaTypes.cpp b/library/LuaTypes.cpp index 845d273e3..d496ba56f 100644 --- a/library/LuaTypes.cpp +++ b/library/LuaTypes.cpp @@ -707,15 +707,18 @@ static type_identity *find_primitive_field(lua_State *state, int field, const ch */ static int meta_primitive_index(lua_State *state) { - const char *attr = lua_tostring(state, -1); - if (strcmp(attr, "ref_target") == 0) { - const struct_field_info *field_info = get_object_ref_header(state, 1)->field_info; - if (field_info && field_info->extra && field_info->extra->ref_target) { - LookupInTable(state, field_info->extra->ref_target, &DFHACK_TYPEID_TABLE_TOKEN); - } else { - lua_pushnil(state); + if (lua_type(state, -1) == LUA_TSTRING) + { + const char *attr = lua_tostring(state, -1); + if (strcmp(attr, "ref_target") == 0) { + const struct_field_info *field_info = get_object_ref_header(state, 1)->field_info; + if (field_info && field_info->extra && field_info->extra->ref_target) { + LookupInTable(state, field_info->extra->ref_target, &DFHACK_TYPEID_TABLE_TOKEN); + } else { + lua_pushnil(state); + } + return 1; } - return 1; } uint8_t *ptr = get_object_addr(state, 1, 2, "read"); diff --git a/test/structures/primitive_refs.lua b/test/structures/primitive_refs.lua new file mode 100644 index 000000000..9de5bc980 --- /dev/null +++ b/test/structures/primitive_refs.lua @@ -0,0 +1,92 @@ +utils = require('utils') + +function with_temp_ref(...) + local args = {...} + local dtype = 'int32_t' + if type(args[1]) == 'string' then + dtype = table.remove(args, 1) + end + local count = 1 + if type(args[1]) == 'number' then + count = table.remove(args, 1) + end + local func = table.remove(args, 1) + assert(type(func) == 'function') + dfhack.with_temp_object(df.new(dtype, count), func, table.unpack(args)) +end + +function test.zero_initialize() + with_temp_ref(function(x) + expect.eq(x.value, 0) + end) +end + +function test.zero_initialize_array() + with_temp_ref(2, function(x) + expect.eq(x:_displace(1).value, 0) + end) +end + +function test.displace_zero() + with_temp_ref(function(x) + expect.eq(x, x:_displace(0)) + end) +end + +function test.displace_nonzero() + with_temp_ref(2, function(x) + expect.ne(x, x:_displace(1)) + expect.eq(utils.addressof(x) + x:sizeof(), utils.addressof(x:_displace(1))) + end) +end + +function test.displace_negative() + with_temp_ref(2, function(x) + expect.true_(x:_displace(1):_displace(-1).value) + end) +end + +function test.index_read() + with_temp_ref(function(x) + expect.eq(x.value, x[0]) + end) +end + +function test.index_write() + with_temp_ref(function(x) + x[0] = 1 + expect.eq(x.value, 1) + expect.eq(x[0], 1) + x.value = 2 + expect.eq(x.value, 2) + expect.eq(x[0], 2) + end) +end + +function test.index_write_multi() + local len = 3 + with_temp_ref(len, function(x) + for i = 0, len - 1 do + x[i] = i * i + end + for i = 0, len - 1 do + expect.eq(x[i], i * i) + end + end) +end + +function test.index_read_negative() + with_temp_ref(function(x) + expect.error_match(function() + expect.true_(x:_displace(1)[-1]) + end, 'negative index') + end) +end + +function test.index_write_negative() + with_temp_ref(function(x) + expect.error_match(function() + x:_displace(1)[-1] = 7 + end, 'negative index') + end) +end