diff --git a/library/DataDefs.cpp b/library/DataDefs.cpp index 28761efa3..f376edc6a 100644 --- a/library/DataDefs.cpp +++ b/library/DataDefs.cpp @@ -159,6 +159,14 @@ enum_identity::enum_identity(size_t size, } } +enum_identity::enum_identity(enum_identity *base_enum, type_identity *override_base_type) + : enum_identity(override_base_type->byte_size(), base_enum->getScopeParent(), + base_enum->getName(), override_base_type, base_enum->first_item_value, + base_enum->last_item_value, base_enum->keys, base_enum->complex, + base_enum->attrs, base_enum->attr_type) +{ +} + enum_identity::ComplexData::ComplexData(std::initializer_list values) { size_t i = 0; @@ -455,18 +463,22 @@ void DFHack::flagarrayToString(std::vector *pvec, const void *p, } } -static const struct_field_info *find_union_tag_candidate(const struct_field_info *fields, const struct_field_info *union_field) +static const struct_field_info *find_union_tag_candidate(struct_identity *structure, const struct_field_info *union_field) { if (union_field->extra && union_field->extra->union_tag_field) { auto defined_field_name = union_field->extra->union_tag_field; - for (auto field = fields; field->mode != struct_field_info::END; field++) + for (auto p = structure; p; p = p->getParent()) { - if (!strcmp(field->name, defined_field_name)) + for (auto field = p->getFields(); field && field->mode != struct_field_info::END; field++) { - return field; + if (!strcmp(field->name, defined_field_name)) + { + return field; + } } } + return nullptr; } @@ -476,32 +488,32 @@ static const struct_field_info *find_union_tag_candidate(const struct_field_info name.erase(name.length() - 4, 4); name += "type"; - for (auto field = fields; field->mode != struct_field_info::END; field++) + for (auto p = structure; p; p = p->getParent()) { - if (field->name == name) + for (auto field = p->getFields(); field && field->mode != struct_field_info::END; field++) { - return field; + if (field->name == name) + { + return field; + } } } } - if (name.length() > 7 && - name.substr(name.length() - 7) == "_target" && - fields != union_field && - (union_field - 1)->name == name.substr(0, name.length() - 7)) - { - return union_field - 1; - } - - return union_field + 1; + return nullptr; } -const struct_field_info *DFHack::find_union_tag(const struct_field_info *fields, const struct_field_info *union_field) +const struct_field_info *DFHack::find_union_tag(struct_identity *structure, const struct_field_info *union_field) { - CHECK_NULL_POINTER(fields); + CHECK_NULL_POINTER(structure); CHECK_NULL_POINTER(union_field); - auto tag_candidate = find_union_tag_candidate(fields, union_field); + auto tag_candidate = find_union_tag_candidate(structure, union_field); + + if (!tag_candidate) + { + return nullptr; + } if (union_field->mode == struct_field_info::SUBSTRUCT && union_field->type && diff --git a/library/LuaTypes.cpp b/library/LuaTypes.cpp index 740b3f689..ef69958e0 100644 --- a/library/LuaTypes.cpp +++ b/library/LuaTypes.cpp @@ -616,6 +616,16 @@ static int meta_struct_index(lua_State *state) if (!field) return 1; read_field(state, field, ptr + field->offset); + if (field->mode == struct_field_info::SUBSTRUCT || field->mode == struct_field_info::CONTAINER) + { + auto struct_type = (struct_identity*)get_object_identity(state, 1, "read", false); + if (auto tag_field = find_union_tag(struct_type, field)) + { + get_object_ref_header(state, -1)->tag_ptr = ptr + tag_field->offset; + get_object_ref_header(state, -1)->tag_identity = tag_field->type; + get_object_ref_header(state, -1)->tag_attr = field->extra ? field->extra->union_tag_attr : nullptr; + } + } return 1; } @@ -631,6 +641,16 @@ static int meta_struct_field_reference(lua_State *state) if (!field) field_error(state, 2, "builtin property or method", "reference"); field_reference(state, field, ptr + field->offset); + if (field->mode == struct_field_info::SUBSTRUCT || field->mode == struct_field_info::CONTAINER) + { + auto struct_type = (struct_identity*)get_object_identity(state, 1, "reference", false); + if (auto tag_field = find_union_tag(struct_type, field)) + { + get_object_ref_header(state, -1)->tag_ptr = ptr + tag_field->offset; + get_object_ref_header(state, -1)->tag_identity = tag_field->type; + get_object_ref_header(state, -1)->tag_attr = field->extra ? field->extra->union_tag_attr : nullptr; + } + } return 1; } @@ -679,6 +699,81 @@ static int meta_struct_next(lua_State *state) return 2; } +/** + * Metamethod: iterator for unions. + */ +static int meta_union_next(lua_State *state) +{ + if (lua_gettop(state) < 2) lua_pushnil(state); + + int len = lua_rawlen(state, UPVAL_FIELDTABLE); + int idx = cur_iter_index(state, len+1, 2, 0); + if (idx == len) + return 0; + + auto header = get_object_ref_header(state, 1); + if (header->tag_ptr) + { + if (idx != 0) + return 0; + + auto enum_id = (enum_identity*)header->tag_identity; + auto tag_val = *(int64_t*)header->tag_ptr; + size_t tag_shift = 64 - 8 * enum_id->byte_size(); + tag_val <<= tag_shift; + tag_val >>= tag_shift; + + size_t tag_index = tag_val - enum_id->getFirstItem(); + if (auto complex = enum_id->getComplex()) + tag_index = complex->value_index_map.count(tag_val) ? complex->value_index_map.at(tag_val) : size_t(-1); + + if (tag_index >= size_t(enum_id->getCount())) + return 0; + + const char *tag_name = nullptr; + if (header->tag_attr) + { + for (auto enum_field = enum_id->getAttrType()->getFields(); enum_field->mode != struct_field_info::END; enum_field++) + { + if (!strcmp(enum_field->name, header->tag_attr)) + { + if (enum_field->type == df::identity_traits::get()) + { + auto attrs = ((uint8_t*)enum_id->getAttrs()) + (tag_index * enum_id->getAttrType()->byte_size()); + tag_name = *(const char **)(attrs + enum_field->offset); + } + break; + } + } + } + else + { + tag_name = enum_id->getKeys()[tag_index]; + } + + if (!tag_name) + return 0; + + lua_getfield(state, UPVAL_FIELDTABLE, tag_name); + if (lua_isnil(state, lua_gettop(state))) + { + lua_pop(state, 1); + return 0; + } + + lua_pop(state, 1); + lua_pushstring(state, tag_name); + lua_getfield(state, 1, tag_name); + + return 2; + } + + lua_rawgeti(state, UPVAL_FIELDTABLE, idx+1); + lua_dup(state); + lua_gettable(state, 1); + return 2; +} + /** * Field lookup for primitive refs: behave as a quasi-array with numeric indices. */ @@ -806,6 +901,25 @@ static int check_container_index(lua_State *state, int len, return idx; } +static void attach_container_item_tagged_union(lua_State *state, int container, int item, int idx) +{ + auto header = get_object_ref_header(state, container); + if (header->tag_ptr) + { + // TODO: handle bit vector for tag + + auto tag_container = (container_identity*)header->tag_identity; + + auto ref = get_object_ref_header(state, item); + + // on both msvc and gcc, vectors have the same memory layout + auto item_type = tag_container->getItemType(); + ref->tag_ptr = (*(uint8_t**)header->tag_ptr) + size_t(idx) * item_type->byte_size(); + ref->tag_identity = item_type; + ref->tag_attr = header->tag_attr; + } +} + /** * Metamethod: __index for containers. */ @@ -820,6 +934,7 @@ static int meta_container_index(lua_State *state) int len = id->lua_item_count(state, ptr, container_identity::COUNT_READ); int idx = check_container_index(state, len, 2, iidx, "read"); id->lua_item_read(state, 2, ptr, idx); + attach_container_item_tagged_union(state, 1, -1, idx); return 1; } @@ -837,6 +952,7 @@ static int meta_container_field_reference(lua_State *state) int len = id->lua_item_count(state, ptr, container_identity::COUNT_WRITE); int idx = check_container_index(state, len, 2, iidx, "reference"); id->lua_item_reference(state, 2, ptr, idx); + attach_container_item_tagged_union(state, 1, -1, idx); return 1; } @@ -873,6 +989,7 @@ static int meta_container_nexti(lua_State *state) lua_pushinteger(state, idx); id->lua_item_read(state, 2, ptr, idx); + attach_container_item_tagged_union(state, 1, -1, idx); return 2; } @@ -1194,7 +1311,6 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo lua_pop(state, 1); bool add_to_enum = true; - const struct_field_info *tag_field = nullptr; // Handle the field switch (fields[i].mode) @@ -1208,16 +1324,11 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo continue; case struct_field_info::POINTER: - // Skip class-typed pointers within unions and other bad pointers - if ((pstruct->type() == IDTYPE_UNION || (fields[i].count & 2) != 0) && fields[i].type) + // Skip potentially bad pointers + if ((fields[i].count & 2) != 0 && fields[i].type) add_to_enum = false; break; - case struct_field_info::SUBSTRUCT: - case struct_field_info::CONTAINER: - tag_field = find_union_tag(fields, &fields[i]); - break; - default: break; } @@ -1229,17 +1340,6 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo if (add_to_enum) AssociateId(state, base+3, ++cnt, name.c_str()); - if (tag_field) - { - // TODO: handle tagged unions - // - // tagged unions are treated as if they have at most one field, - // with the same name as the corresponding enumeration value. - // - // if no field's name matches the enumeration value's name, - // the tagged union is treated as a structure with no fields. - } - lua_pushlightuserdata(state, (void*)&fields[i]); lua_setfield(state, base+2, name.c_str()); } @@ -1275,7 +1375,8 @@ void LuaWrapper::IndexStatics(lua_State *state, int meta_idx, int ftable_idx, st * Make a struct-style object metatable. */ static void MakeFieldMetatable(lua_State *state, struct_identity *pstruct, - lua_CFunction reader, lua_CFunction writer, bool globals = false) + lua_CFunction reader, lua_CFunction writer, + lua_CFunction iterator, bool globals = false) { int base = lua_gettop(state); @@ -1287,7 +1388,7 @@ static void MakeFieldMetatable(lua_State *state, struct_identity *pstruct, IndexFields(state, base, pstruct, globals); // Add the iteration metamethods - PushStructMethod(state, base+1, base+3, meta_struct_next); + PushStructMethod(state, base+1, base+3, iterator); SetPairsMethod(state, base+1, "__pairs"); lua_pushnil(state); SetPairsMethod(state, base+1, "__ipairs"); @@ -1434,7 +1535,15 @@ void bitfield_identity::build_metatable(lua_State *state) void struct_identity::build_metatable(lua_State *state) { int base = lua_gettop(state); - MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex); + MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_struct_next); + SetStructMethod(state, base+1, base+2, meta_struct_field_reference, "_field"); + SetPtrMethods(state, base+1, base+2); +} + +void union_identity::build_metatable(lua_State *state) +{ + int base = lua_gettop(state); + MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_union_next); SetStructMethod(state, base+1, base+2, meta_struct_field_reference, "_field"); SetPtrMethods(state, base+1, base+2); } @@ -1442,7 +1551,7 @@ void struct_identity::build_metatable(lua_State *state) void other_vectors_identity::build_metatable(lua_State *state) { int base = lua_gettop(state); - MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex); + MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_struct_next); EnableMetaField(state, base+2, "_enum"); @@ -1464,7 +1573,7 @@ void other_vectors_identity::build_metatable(lua_State *state) void global_identity::build_metatable(lua_State *state) { int base = lua_gettop(state); - MakeFieldMetatable(state, this, meta_global_index, meta_global_newindex, true); + MakeFieldMetatable(state, this, meta_global_index, meta_global_newindex, meta_struct_next, true); SetStructMethod(state, base+1, base+2, meta_global_field_reference, "_field"); SetPtrMethods(state, base+1, base+2); } diff --git a/library/LuaWrapper.cpp b/library/LuaWrapper.cpp index 05c4db789..59bd96732 100644 --- a/library/LuaWrapper.cpp +++ b/library/LuaWrapper.cpp @@ -171,6 +171,9 @@ void LuaWrapper::push_object_ref(lua_State *state, void *ptr) auto ref = (DFRefHeader*)lua_newuserdata(state, sizeof(DFRefHeader)); ref->ptr = ptr; ref->field_info = NULL; + ref->tag_ptr = NULL; + ref->tag_identity = NULL; + ref->tag_attr = NULL; lua_swap(state); lua_setmetatable(state, -2); diff --git a/library/include/DataDefs.h b/library/include/DataDefs.h index f4eff8510..174156e76 100644 --- a/library/include/DataDefs.h +++ b/library/include/DataDefs.h @@ -239,6 +239,7 @@ namespace DFHack const char *const *keys, const ComplexData *complex, const void *attrs, struct_identity *attr_type); + enum_identity(enum_identity *enum_type, type_identity *override_base_type); virtual identity_type type() { return IDTYPE_ENUM; } @@ -332,6 +333,8 @@ namespace DFHack struct_identity *parent, const struct_field_info *fields); virtual identity_type type() { return IDTYPE_UNION; } + + virtual void build_metatable(lua_State *state); }; class DFHACK_EXPORT other_vectors_identity : public struct_identity { @@ -842,7 +845,7 @@ namespace DFHack { * As a special case, a container-type union can have a tag field that is * a bit vector if it has exactly two members. */ - DFHACK_EXPORT const struct_field_info *find_union_tag(const struct_field_info *fields, const struct_field_info *union_field); + DFHACK_EXPORT const struct_field_info *find_union_tag(struct_identity *structure, const struct_field_info *union_field); } #define ENUM_ATTR(enum,attr,val) (df::enum_traits::attrs(val).attr) diff --git a/library/include/DataIdentity.h b/library/include/DataIdentity.h index 88a96f9f3..711712af7 100644 --- a/library/include/DataIdentity.h +++ b/library/include/DataIdentity.h @@ -578,7 +578,7 @@ namespace df #ifdef BUILD_DFHACK_LIB template struct identity_traits > { - static primitive_identity *get(); + static enum_identity *get(); }; #endif @@ -631,8 +631,9 @@ namespace df #ifdef BUILD_DFHACK_LIB template - inline primitive_identity *identity_traits >::get() { - return identity_traits::get(); + inline enum_identity *identity_traits >::get() { + static enum_identity identity(identity_traits::get(), identity_traits::get()); + return &identity; } #endif diff --git a/library/include/LuaWrapper.h b/library/include/LuaWrapper.h index b23e1f912..c5daed0fc 100644 --- a/library/include/LuaWrapper.h +++ b/library/include/LuaWrapper.h @@ -127,6 +127,9 @@ namespace LuaWrapper { struct DFRefHeader { void *ptr; const struct_field_info *field_info; + const void *tag_ptr; + const type_identity *tag_identity; + const char *tag_attr; }; /** diff --git a/plugins/devel/check-structures-sanity/check-structures-sanity.h b/plugins/devel/check-structures-sanity/check-structures-sanity.h index da790edf0..b606e9d86 100644 --- a/plugins/devel/check-structures-sanity/check-structures-sanity.h +++ b/plugins/devel/check-structures-sanity/check-structures-sanity.h @@ -131,7 +131,7 @@ private: void dispatch_bitfield(const QueueItem &, const CheckedStructure &); void dispatch_enum(const QueueItem &, const CheckedStructure &); void dispatch_struct(const QueueItem &, const CheckedStructure &); - void dispatch_field(const QueueItem &, const CheckedStructure &, const struct_field_info *, const struct_field_info *); + void dispatch_field(const QueueItem &, const CheckedStructure &, struct_identity *, const struct_field_info *); void dispatch_class(const QueueItem &, const CheckedStructure &); void dispatch_buffer(const QueueItem &, const CheckedStructure &); void dispatch_stl_ptr_vector(const QueueItem &, const CheckedStructure &); diff --git a/plugins/devel/check-structures-sanity/dispatch.cpp b/plugins/devel/check-structures-sanity/dispatch.cpp index f6342d03c..cd288ebfe 100644 --- a/plugins/devel/check-structures-sanity/dispatch.cpp +++ b/plugins/devel/check-structures-sanity/dispatch.cpp @@ -528,11 +528,11 @@ void Checker::dispatch_struct(const QueueItem & item, const CheckedStructure & c for (auto field = fields; field->mode != struct_field_info::END; field++) { - dispatch_field(item, cs, fields, field); + dispatch_field(item, cs, identity, field); } } } -void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs, const struct_field_info *fields, const struct_field_info *field) +void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs, struct_identity *identity, const struct_field_info *field) { if (field->mode == struct_field_info::OBJ_METHOD || field->mode == struct_field_info::CLASS_METHOD) @@ -544,7 +544,7 @@ void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs QueueItem field_item(item, field->name, field_ptr); CheckedStructure field_cs(field); - auto tag_field = find_union_tag(fields, field); + auto tag_field = find_union_tag(identity, field); if (tag_field) { auto tag_ptr = PTR_ADD(item.ptr, tag_field->offset); diff --git a/test/structures/unions.lua b/test/structures/unions.lua index be469da6f..714786a96 100644 --- a/test/structures/unions.lua +++ b/test/structures/unions.lua @@ -15,6 +15,7 @@ function test.unit_action_type() expect.true_(name, "unit_action_type entry without name: " .. tostring(index)) local tag = df.unit_action_type.attrs[name].tag expect.true_(tag, "unit_action_type entry missing tag: name=" .. name) + action.type = index expect.pairs_contains(action.data, tag, "unit_action_type entry missing from unit_action.data: name=" .. name) end