tagged union support for lua (#1818)

develop
Ben Lubar 2021-03-30 15:55:06 -05:00 committed by GitHub
parent 5e7653bbf5
commit c06d1f8e52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 183 additions and 51 deletions

@ -159,6 +159,14 @@ enum_identity::enum_identity(size_t size,
} }
} }
enum_identity::enum_identity(enum_identity *base_enum, type_identity *override_base_type)
: enum_identity(override_base_type->byte_size(), base_enum->getScopeParent(),
base_enum->getName(), override_base_type, base_enum->first_item_value,
base_enum->last_item_value, base_enum->keys, base_enum->complex,
base_enum->attrs, base_enum->attr_type)
{
}
enum_identity::ComplexData::ComplexData(std::initializer_list<int64_t> values) enum_identity::ComplexData::ComplexData(std::initializer_list<int64_t> values)
{ {
size_t i = 0; size_t i = 0;
@ -455,18 +463,22 @@ void DFHack::flagarrayToString(std::vector<std::string> *pvec, const void *p,
} }
} }
static const struct_field_info *find_union_tag_candidate(const struct_field_info *fields, const struct_field_info *union_field) static const struct_field_info *find_union_tag_candidate(struct_identity *structure, const struct_field_info *union_field)
{ {
if (union_field->extra && union_field->extra->union_tag_field) if (union_field->extra && union_field->extra->union_tag_field)
{ {
auto defined_field_name = union_field->extra->union_tag_field; auto defined_field_name = union_field->extra->union_tag_field;
for (auto field = fields; field->mode != struct_field_info::END; field++) for (auto p = structure; p; p = p->getParent())
{ {
if (!strcmp(field->name, defined_field_name)) for (auto field = p->getFields(); field && field->mode != struct_field_info::END; field++)
{ {
return field; if (!strcmp(field->name, defined_field_name))
{
return field;
}
} }
} }
return nullptr; return nullptr;
} }
@ -476,32 +488,32 @@ static const struct_field_info *find_union_tag_candidate(const struct_field_info
name.erase(name.length() - 4, 4); name.erase(name.length() - 4, 4);
name += "type"; name += "type";
for (auto field = fields; field->mode != struct_field_info::END; field++) for (auto p = structure; p; p = p->getParent())
{ {
if (field->name == name) for (auto field = p->getFields(); field && field->mode != struct_field_info::END; field++)
{ {
return field; if (field->name == name)
{
return field;
}
} }
} }
} }
if (name.length() > 7 && return nullptr;
name.substr(name.length() - 7) == "_target" &&
fields != union_field &&
(union_field - 1)->name == name.substr(0, name.length() - 7))
{
return union_field - 1;
}
return union_field + 1;
} }
const struct_field_info *DFHack::find_union_tag(const struct_field_info *fields, const struct_field_info *union_field) const struct_field_info *DFHack::find_union_tag(struct_identity *structure, const struct_field_info *union_field)
{ {
CHECK_NULL_POINTER(fields); CHECK_NULL_POINTER(structure);
CHECK_NULL_POINTER(union_field); CHECK_NULL_POINTER(union_field);
auto tag_candidate = find_union_tag_candidate(fields, union_field); auto tag_candidate = find_union_tag_candidate(structure, union_field);
if (!tag_candidate)
{
return nullptr;
}
if (union_field->mode == struct_field_info::SUBSTRUCT && if (union_field->mode == struct_field_info::SUBSTRUCT &&
union_field->type && union_field->type &&

@ -616,6 +616,16 @@ static int meta_struct_index(lua_State *state)
if (!field) if (!field)
return 1; return 1;
read_field(state, field, ptr + field->offset); read_field(state, field, ptr + field->offset);
if (field->mode == struct_field_info::SUBSTRUCT || field->mode == struct_field_info::CONTAINER)
{
auto struct_type = (struct_identity*)get_object_identity(state, 1, "read", false);
if (auto tag_field = find_union_tag(struct_type, field))
{
get_object_ref_header(state, -1)->tag_ptr = ptr + tag_field->offset;
get_object_ref_header(state, -1)->tag_identity = tag_field->type;
get_object_ref_header(state, -1)->tag_attr = field->extra ? field->extra->union_tag_attr : nullptr;
}
}
return 1; return 1;
} }
@ -631,6 +641,16 @@ static int meta_struct_field_reference(lua_State *state)
if (!field) if (!field)
field_error(state, 2, "builtin property or method", "reference"); field_error(state, 2, "builtin property or method", "reference");
field_reference(state, field, ptr + field->offset); field_reference(state, field, ptr + field->offset);
if (field->mode == struct_field_info::SUBSTRUCT || field->mode == struct_field_info::CONTAINER)
{
auto struct_type = (struct_identity*)get_object_identity(state, 1, "reference", false);
if (auto tag_field = find_union_tag(struct_type, field))
{
get_object_ref_header(state, -1)->tag_ptr = ptr + tag_field->offset;
get_object_ref_header(state, -1)->tag_identity = tag_field->type;
get_object_ref_header(state, -1)->tag_attr = field->extra ? field->extra->union_tag_attr : nullptr;
}
}
return 1; return 1;
} }
@ -679,6 +699,81 @@ static int meta_struct_next(lua_State *state)
return 2; return 2;
} }
/**
* Metamethod: iterator for unions.
*/
static int meta_union_next(lua_State *state)
{
if (lua_gettop(state) < 2) lua_pushnil(state);
int len = lua_rawlen(state, UPVAL_FIELDTABLE);
int idx = cur_iter_index(state, len+1, 2, 0);
if (idx == len)
return 0;
auto header = get_object_ref_header(state, 1);
if (header->tag_ptr)
{
if (idx != 0)
return 0;
auto enum_id = (enum_identity*)header->tag_identity;
auto tag_val = *(int64_t*)header->tag_ptr;
size_t tag_shift = 64 - 8 * enum_id->byte_size();
tag_val <<= tag_shift;
tag_val >>= tag_shift;
size_t tag_index = tag_val - enum_id->getFirstItem();
if (auto complex = enum_id->getComplex())
tag_index = complex->value_index_map.count(tag_val) ? complex->value_index_map.at(tag_val) : size_t(-1);
if (tag_index >= size_t(enum_id->getCount()))
return 0;
const char *tag_name = nullptr;
if (header->tag_attr)
{
for (auto enum_field = enum_id->getAttrType()->getFields(); enum_field->mode != struct_field_info::END; enum_field++)
{
if (!strcmp(enum_field->name, header->tag_attr))
{
if (enum_field->type == df::identity_traits<const char*>::get())
{
auto attrs = ((uint8_t*)enum_id->getAttrs()) + (tag_index * enum_id->getAttrType()->byte_size());
tag_name = *(const char **)(attrs + enum_field->offset);
}
break;
}
}
}
else
{
tag_name = enum_id->getKeys()[tag_index];
}
if (!tag_name)
return 0;
lua_getfield(state, UPVAL_FIELDTABLE, tag_name);
if (lua_isnil(state, lua_gettop(state)))
{
lua_pop(state, 1);
return 0;
}
lua_pop(state, 1);
lua_pushstring(state, tag_name);
lua_getfield(state, 1, tag_name);
return 2;
}
lua_rawgeti(state, UPVAL_FIELDTABLE, idx+1);
lua_dup(state);
lua_gettable(state, 1);
return 2;
}
/** /**
* Field lookup for primitive refs: behave as a quasi-array with numeric indices. * Field lookup for primitive refs: behave as a quasi-array with numeric indices.
*/ */
@ -806,6 +901,25 @@ static int check_container_index(lua_State *state, int len,
return idx; return idx;
} }
static void attach_container_item_tagged_union(lua_State *state, int container, int item, int idx)
{
auto header = get_object_ref_header(state, container);
if (header->tag_ptr)
{
// TODO: handle bit vector for tag
auto tag_container = (container_identity*)header->tag_identity;
auto ref = get_object_ref_header(state, item);
// on both msvc and gcc, vectors have the same memory layout
auto item_type = tag_container->getItemType();
ref->tag_ptr = (*(uint8_t**)header->tag_ptr) + size_t(idx) * item_type->byte_size();
ref->tag_identity = item_type;
ref->tag_attr = header->tag_attr;
}
}
/** /**
* Metamethod: __index for containers. * Metamethod: __index for containers.
*/ */
@ -820,6 +934,7 @@ static int meta_container_index(lua_State *state)
int len = id->lua_item_count(state, ptr, container_identity::COUNT_READ); int len = id->lua_item_count(state, ptr, container_identity::COUNT_READ);
int idx = check_container_index(state, len, 2, iidx, "read"); int idx = check_container_index(state, len, 2, iidx, "read");
id->lua_item_read(state, 2, ptr, idx); id->lua_item_read(state, 2, ptr, idx);
attach_container_item_tagged_union(state, 1, -1, idx);
return 1; return 1;
} }
@ -837,6 +952,7 @@ static int meta_container_field_reference(lua_State *state)
int len = id->lua_item_count(state, ptr, container_identity::COUNT_WRITE); int len = id->lua_item_count(state, ptr, container_identity::COUNT_WRITE);
int idx = check_container_index(state, len, 2, iidx, "reference"); int idx = check_container_index(state, len, 2, iidx, "reference");
id->lua_item_reference(state, 2, ptr, idx); id->lua_item_reference(state, 2, ptr, idx);
attach_container_item_tagged_union(state, 1, -1, idx);
return 1; return 1;
} }
@ -873,6 +989,7 @@ static int meta_container_nexti(lua_State *state)
lua_pushinteger(state, idx); lua_pushinteger(state, idx);
id->lua_item_read(state, 2, ptr, idx); id->lua_item_read(state, 2, ptr, idx);
attach_container_item_tagged_union(state, 1, -1, idx);
return 2; return 2;
} }
@ -1194,7 +1311,6 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo
lua_pop(state, 1); lua_pop(state, 1);
bool add_to_enum = true; bool add_to_enum = true;
const struct_field_info *tag_field = nullptr;
// Handle the field // Handle the field
switch (fields[i].mode) switch (fields[i].mode)
@ -1208,16 +1324,11 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo
continue; continue;
case struct_field_info::POINTER: case struct_field_info::POINTER:
// Skip class-typed pointers within unions and other bad pointers // Skip potentially bad pointers
if ((pstruct->type() == IDTYPE_UNION || (fields[i].count & 2) != 0) && fields[i].type) if ((fields[i].count & 2) != 0 && fields[i].type)
add_to_enum = false; add_to_enum = false;
break; break;
case struct_field_info::SUBSTRUCT:
case struct_field_info::CONTAINER:
tag_field = find_union_tag(fields, &fields[i]);
break;
default: default:
break; break;
} }
@ -1229,17 +1340,6 @@ static void IndexFields(lua_State *state, int base, struct_identity *pstruct, bo
if (add_to_enum) if (add_to_enum)
AssociateId(state, base+3, ++cnt, name.c_str()); AssociateId(state, base+3, ++cnt, name.c_str());
if (tag_field)
{
// TODO: handle tagged unions
//
// tagged unions are treated as if they have at most one field,
// with the same name as the corresponding enumeration value.
//
// if no field's name matches the enumeration value's name,
// the tagged union is treated as a structure with no fields.
}
lua_pushlightuserdata(state, (void*)&fields[i]); lua_pushlightuserdata(state, (void*)&fields[i]);
lua_setfield(state, base+2, name.c_str()); lua_setfield(state, base+2, name.c_str());
} }
@ -1275,7 +1375,8 @@ void LuaWrapper::IndexStatics(lua_State *state, int meta_idx, int ftable_idx, st
* Make a struct-style object metatable. * Make a struct-style object metatable.
*/ */
static void MakeFieldMetatable(lua_State *state, struct_identity *pstruct, static void MakeFieldMetatable(lua_State *state, struct_identity *pstruct,
lua_CFunction reader, lua_CFunction writer, bool globals = false) lua_CFunction reader, lua_CFunction writer,
lua_CFunction iterator, bool globals = false)
{ {
int base = lua_gettop(state); int base = lua_gettop(state);
@ -1287,7 +1388,7 @@ static void MakeFieldMetatable(lua_State *state, struct_identity *pstruct,
IndexFields(state, base, pstruct, globals); IndexFields(state, base, pstruct, globals);
// Add the iteration metamethods // Add the iteration metamethods
PushStructMethod(state, base+1, base+3, meta_struct_next); PushStructMethod(state, base+1, base+3, iterator);
SetPairsMethod(state, base+1, "__pairs"); SetPairsMethod(state, base+1, "__pairs");
lua_pushnil(state); lua_pushnil(state);
SetPairsMethod(state, base+1, "__ipairs"); SetPairsMethod(state, base+1, "__ipairs");
@ -1434,7 +1535,15 @@ void bitfield_identity::build_metatable(lua_State *state)
void struct_identity::build_metatable(lua_State *state) void struct_identity::build_metatable(lua_State *state)
{ {
int base = lua_gettop(state); int base = lua_gettop(state);
MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex); MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_struct_next);
SetStructMethod(state, base+1, base+2, meta_struct_field_reference, "_field");
SetPtrMethods(state, base+1, base+2);
}
void union_identity::build_metatable(lua_State *state)
{
int base = lua_gettop(state);
MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_union_next);
SetStructMethod(state, base+1, base+2, meta_struct_field_reference, "_field"); SetStructMethod(state, base+1, base+2, meta_struct_field_reference, "_field");
SetPtrMethods(state, base+1, base+2); SetPtrMethods(state, base+1, base+2);
} }
@ -1442,7 +1551,7 @@ void struct_identity::build_metatable(lua_State *state)
void other_vectors_identity::build_metatable(lua_State *state) void other_vectors_identity::build_metatable(lua_State *state)
{ {
int base = lua_gettop(state); int base = lua_gettop(state);
MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex); MakeFieldMetatable(state, this, meta_struct_index, meta_struct_newindex, meta_struct_next);
EnableMetaField(state, base+2, "_enum"); EnableMetaField(state, base+2, "_enum");
@ -1464,7 +1573,7 @@ void other_vectors_identity::build_metatable(lua_State *state)
void global_identity::build_metatable(lua_State *state) void global_identity::build_metatable(lua_State *state)
{ {
int base = lua_gettop(state); int base = lua_gettop(state);
MakeFieldMetatable(state, this, meta_global_index, meta_global_newindex, true); MakeFieldMetatable(state, this, meta_global_index, meta_global_newindex, meta_struct_next, true);
SetStructMethod(state, base+1, base+2, meta_global_field_reference, "_field"); SetStructMethod(state, base+1, base+2, meta_global_field_reference, "_field");
SetPtrMethods(state, base+1, base+2); SetPtrMethods(state, base+1, base+2);
} }

@ -171,6 +171,9 @@ void LuaWrapper::push_object_ref(lua_State *state, void *ptr)
auto ref = (DFRefHeader*)lua_newuserdata(state, sizeof(DFRefHeader)); auto ref = (DFRefHeader*)lua_newuserdata(state, sizeof(DFRefHeader));
ref->ptr = ptr; ref->ptr = ptr;
ref->field_info = NULL; ref->field_info = NULL;
ref->tag_ptr = NULL;
ref->tag_identity = NULL;
ref->tag_attr = NULL;
lua_swap(state); lua_swap(state);
lua_setmetatable(state, -2); lua_setmetatable(state, -2);

@ -239,6 +239,7 @@ namespace DFHack
const char *const *keys, const char *const *keys,
const ComplexData *complex, const ComplexData *complex,
const void *attrs, struct_identity *attr_type); const void *attrs, struct_identity *attr_type);
enum_identity(enum_identity *enum_type, type_identity *override_base_type);
virtual identity_type type() { return IDTYPE_ENUM; } virtual identity_type type() { return IDTYPE_ENUM; }
@ -332,6 +333,8 @@ namespace DFHack
struct_identity *parent, const struct_field_info *fields); struct_identity *parent, const struct_field_info *fields);
virtual identity_type type() { return IDTYPE_UNION; } virtual identity_type type() { return IDTYPE_UNION; }
virtual void build_metatable(lua_State *state);
}; };
class DFHACK_EXPORT other_vectors_identity : public struct_identity { class DFHACK_EXPORT other_vectors_identity : public struct_identity {
@ -842,7 +845,7 @@ namespace DFHack {
* As a special case, a container-type union can have a tag field that is * As a special case, a container-type union can have a tag field that is
* a bit vector if it has exactly two members. * a bit vector if it has exactly two members.
*/ */
DFHACK_EXPORT const struct_field_info *find_union_tag(const struct_field_info *fields, const struct_field_info *union_field); DFHACK_EXPORT const struct_field_info *find_union_tag(struct_identity *structure, const struct_field_info *union_field);
} }
#define ENUM_ATTR(enum,attr,val) (df::enum_traits<df::enum>::attrs(val).attr) #define ENUM_ATTR(enum,attr,val) (df::enum_traits<df::enum>::attrs(val).attr)

@ -578,7 +578,7 @@ namespace df
#ifdef BUILD_DFHACK_LIB #ifdef BUILD_DFHACK_LIB
template<class Enum, class FT> struct identity_traits<enum_field<Enum,FT> > { template<class Enum, class FT> struct identity_traits<enum_field<Enum,FT> > {
static primitive_identity *get(); static enum_identity *get();
}; };
#endif #endif
@ -631,8 +631,9 @@ namespace df
#ifdef BUILD_DFHACK_LIB #ifdef BUILD_DFHACK_LIB
template<class Enum, class FT> template<class Enum, class FT>
inline primitive_identity *identity_traits<enum_field<Enum,FT> >::get() { inline enum_identity *identity_traits<enum_field<Enum,FT> >::get() {
return identity_traits<FT>::get(); static enum_identity identity(identity_traits<Enum>::get(), identity_traits<FT>::get());
return &identity;
} }
#endif #endif

@ -127,6 +127,9 @@ namespace LuaWrapper {
struct DFRefHeader { struct DFRefHeader {
void *ptr; void *ptr;
const struct_field_info *field_info; const struct_field_info *field_info;
const void *tag_ptr;
const type_identity *tag_identity;
const char *tag_attr;
}; };
/** /**

@ -131,7 +131,7 @@ private:
void dispatch_bitfield(const QueueItem &, const CheckedStructure &); void dispatch_bitfield(const QueueItem &, const CheckedStructure &);
void dispatch_enum(const QueueItem &, const CheckedStructure &); void dispatch_enum(const QueueItem &, const CheckedStructure &);
void dispatch_struct(const QueueItem &, const CheckedStructure &); void dispatch_struct(const QueueItem &, const CheckedStructure &);
void dispatch_field(const QueueItem &, const CheckedStructure &, const struct_field_info *, const struct_field_info *); void dispatch_field(const QueueItem &, const CheckedStructure &, struct_identity *, const struct_field_info *);
void dispatch_class(const QueueItem &, const CheckedStructure &); void dispatch_class(const QueueItem &, const CheckedStructure &);
void dispatch_buffer(const QueueItem &, const CheckedStructure &); void dispatch_buffer(const QueueItem &, const CheckedStructure &);
void dispatch_stl_ptr_vector(const QueueItem &, const CheckedStructure &); void dispatch_stl_ptr_vector(const QueueItem &, const CheckedStructure &);

@ -528,11 +528,11 @@ void Checker::dispatch_struct(const QueueItem & item, const CheckedStructure & c
for (auto field = fields; field->mode != struct_field_info::END; field++) for (auto field = fields; field->mode != struct_field_info::END; field++)
{ {
dispatch_field(item, cs, fields, field); dispatch_field(item, cs, identity, field);
} }
} }
} }
void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs, const struct_field_info *fields, const struct_field_info *field) void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs, struct_identity *identity, const struct_field_info *field)
{ {
if (field->mode == struct_field_info::OBJ_METHOD || if (field->mode == struct_field_info::OBJ_METHOD ||
field->mode == struct_field_info::CLASS_METHOD) field->mode == struct_field_info::CLASS_METHOD)
@ -544,7 +544,7 @@ void Checker::dispatch_field(const QueueItem & item, const CheckedStructure & cs
QueueItem field_item(item, field->name, field_ptr); QueueItem field_item(item, field->name, field_ptr);
CheckedStructure field_cs(field); CheckedStructure field_cs(field);
auto tag_field = find_union_tag(fields, field); auto tag_field = find_union_tag(identity, field);
if (tag_field) if (tag_field)
{ {
auto tag_ptr = PTR_ADD(item.ptr, tag_field->offset); auto tag_ptr = PTR_ADD(item.ptr, tag_field->offset);

@ -15,6 +15,7 @@ function test.unit_action_type()
expect.true_(name, "unit_action_type entry without name: " .. tostring(index)) expect.true_(name, "unit_action_type entry without name: " .. tostring(index))
local tag = df.unit_action_type.attrs[name].tag local tag = df.unit_action_type.attrs[name].tag
expect.true_(tag, "unit_action_type entry missing tag: name=" .. name) expect.true_(tag, "unit_action_type entry missing tag: name=" .. name)
action.type = index
expect.pairs_contains(action.data, tag, expect.pairs_contains(action.data, tag,
"unit_action_type entry missing from unit_action.data: name=" .. name) "unit_action_type entry missing from unit_action.data: name=" .. name)
end end