Fixed interface conversion panics

graph-rework-2
noah metz 2023-07-24 17:07:27 -06:00
parent fc2e36043f
commit 7d04923b3b
7 changed files with 100 additions and 53 deletions

@ -755,15 +755,16 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
thread.Users = map[NodeID]*User{} thread.Users = map[NodeID]*User{}
for _, id_str := range(j.Users) { for _, id_str := range(j.Users) {
id, err := ParseID(id_str) ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str)
user_id, err := ParseID(id_str)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := LoadNodeRecurse(ctx, id, nodes) user, err := LoadNodeRecurse(ctx, user_id, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
thread.Users[id] = user.(*User) thread.Users[user_id] = user.(*User)
} }
err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes) err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes)

@ -19,8 +19,8 @@ import (
) )
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test", "signal", "policy", "thread"}) ctx := logTestContext(t, []string{"test", "signal", "policy", "thread", "db"})
l1 := NewListener(RandID(), "Test Lockable 1") l1 := NewListener(RandID(), "Test Listener 1")
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
@ -58,13 +58,13 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String()) ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String())
ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String()) ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String())
err = AttachPolicies(ctx, &gql.SimpleNode, &p1, &p2) err = AttachPolicies(ctx, &gql, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, &l1.SimpleNode, &p1, &p2) err = AttachPolicies(ctx, &l1, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, &t1.SimpleNode, &p1, &p2) err = AttachPolicies(ctx, &t1, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, &u1.SimpleNode, &p1, &p2) err = AttachPolicies(ctx, &u1, &p1, &p2)
fatalErr(t, err) fatalErr(t, err)
info := NewParentThreadInfo(true, "start", "restore") info := NewParentThreadInfo(true, "start", "restore")
@ -72,7 +72,7 @@ func TestGQLDBLoad(t * testing.T) {
err = UpdateStates(context, &gql, NewLockMap( err = UpdateStates(context, &gql, NewLockMap(
NewLockInfo(&gql, []string{"users"}), NewLockInfo(&gql, []string{"users"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
gql.Users[KeyID(&u1_key.PublicKey)] = &u1 gql.Users[u1.ID()] = &u1
err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{ err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{
"parent": &info, "parent": &info,
@ -80,7 +80,7 @@ func TestGQLDBLoad(t * testing.T) {
if err != nil { if err != nil {
return err return err
} }
return LinkLockables(context, &gql, &gql, []LockableNode{&l1}) return LinkLockables(context, &gql, &l1, []LockableNode{&gql})
}) })
fatalErr(t, err) fatalErr(t, err)
@ -144,10 +144,10 @@ func TestGQLAuth(t * testing.T) {
gql_t := &gql_t_r gql_t := &gql_t_r
l1 := NewListener(RandID(), "GQL Thread") l1 := NewListener(RandID(), "GQL Thread")
err = AttachPolicies(ctx, &l1.SimpleNode, p1) err = AttachPolicies(ctx, &l1, p1)
fatalErr(t, err) fatalErr(t, err)
err = AttachPolicies(ctx, &gql_t.SimpleNode, p1) err = AttachPolicies(ctx, gql_t, p1)
done := make(chan error, 1) done := make(chan error, 1)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)

@ -2,6 +2,7 @@ package graphvent
import ( import (
"fmt" "fmt"
"reflect"
"encoding/json" "encoding/json"
) )
@ -15,6 +16,7 @@ func (node *Listener) Type() NodeType {
} }
func (node *Listener) Process(context *StateContext, signal GraphSignal) error { func (node *Listener) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "LISTENER_PROCESS: %s", node.ID())
select { select {
case node.Chan <- signal: case node.Chan <- signal:
default: default:
@ -149,11 +151,9 @@ func (lockable *Lockable) RecordLock(l LockableNode, last_owner LockableNode) {
// Assumed that lockable is already locked for signal // Assumed that lockable is already locked for signal
func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error { func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) error {
err := lockable.SimpleNode.Process(context, signal) context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", lockable.ID())
if err != nil {
return err
}
var err error
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, lockable, err = UseStates(context, lockable,
@ -191,7 +191,10 @@ func (lockable *Lockable) Process(context *StateContext, signal GraphSignal) err
default: default:
return fmt.Errorf("invalid signal direction %d", signal.Direction()) return fmt.Errorf("invalid signal direction %d", signal.Direction())
} }
if err != nil {
return err return err
}
return lockable.SimpleNode.Process(context, signal)
} }
// Removes requirement as a requirement from lockable // Removes requirement as a requirement from lockable
@ -251,7 +254,7 @@ func LinkLockables(context *StateContext, princ Node, lockable_node LockableNode
} }
return UpdateStates(context, princ, NewLockMap( return UpdateStates(context, princ, NewLockMap(
NewLockInfo(lockable, []string{"requirements"}), NewLockInfo(lockable_node, []string{"requirements"}),
LockList(requirements, []string{"dependencies"}), LockList(requirements, []string{"dependencies"}),
), func(context *StateContext) error { ), func(context *StateContext) error {
// Check that all the requirements can be added // Check that all the requirements can be added
@ -520,6 +523,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No
if ok == false { if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", dep_node) return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
} }
ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep))
lockable.Dependencies[dep_id] = dep lockable.Dependencies[dep_id] = dep
} }

@ -14,6 +14,10 @@ import (
// IDs are how nodes are uniquely identified, and can be serialized for the database // IDs are how nodes are uniquely identified, and can be serialized for the database
type NodeID uuid.UUID type NodeID uuid.UUID
func (id NodeID) MarshalJSON() ([]byte, error) {
str := id.String()
return json.Marshal(&str)
}
var ZeroUUID = uuid.UUID{} var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID) var ZeroID = NodeID(ZeroUUID)
@ -62,6 +66,7 @@ type Node interface {
UnlockState(write bool) UnlockState(write bool)
Process(context *StateContext, signal GraphSignal) error Process(context *StateContext, signal GraphSignal) error
Policies() []Policy Policies() []Policy
NodeHandle() *SimpleNode
} }
type SimpleNode struct { type SimpleNode struct {
@ -70,6 +75,10 @@ type SimpleNode struct {
policies map[NodeID]Policy policies map[NodeID]Policy
} }
func (node *SimpleNode) NodeHandle() *SimpleNode {
return node
}
func NewSimpleNode(id NodeID) SimpleNode { func NewSimpleNode(id NodeID) SimpleNode {
return SimpleNode{ return SimpleNode{
id: id, id: id,
@ -82,7 +91,7 @@ type SimpleNodeJSON struct {
} }
func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error { func (node *SimpleNode) Process(context *StateContext, signal GraphSignal) error {
context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %+v", node.id, signal) context.Graph.Log.Logf("signal", "SIMPLE_NODE_SIGNAL: %s - %s", node.id, signal)
return nil return nil
} }
@ -214,11 +223,11 @@ func Signal(context *StateContext, node Node, princ Node, signal GraphSignal) er
return node.Process(context, signal) return node.Process(context, signal)
} }
func AttachPolicies(ctx *Context, node *SimpleNode, policies ...Policy) error { func AttachPolicies(ctx *Context, node Node, policies ...Policy) error {
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error { return UpdateStates(context, node, NewLockInfo(node, []string{"policies"}), func(context *StateContext) error {
for _, policy := range(policies) { for _, policy := range(policies) {
node.policies[policy.ID()] = policy node.NodeHandle().policies[policy.ID()] = policy
} }
return nil return nil
}) })
@ -252,23 +261,6 @@ func NewDBHeader(node_type NodeType) DBHeader {
} }
} }
// Internal function to serialize a node and wrap it with the DB Header
func getNodeBytes(node Node) ([]byte, error) {
if node == nil {
return nil, fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node")
}
ser, err := node.Serialize()
if err != nil {
return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
}
header := NewDBHeader(node.Type())
db_data := append(header.Serialize(), ser...)
return db_data, nil
}
// Write multiple nodes to the database in a single transaction // Write multiple nodes to the database in a single transaction
func WriteNodes(context *StateContext) error { func WriteNodes(context *StateContext) error {
err := ValidateStateContext(context, "write", true) err := ValidateStateContext(context, "write", true)
@ -282,15 +274,26 @@ func WriteNodes(context *StateContext) error {
serialized_ids := make([][]byte, len(context.Locked)) serialized_ids := make([][]byte, len(context.Locked))
i := 0 i := 0
for _, node := range(context.Locked) { for _, node := range(context.Locked) {
node_bytes, err := getNodeBytes(node) if node == nil {
context.Graph.Log.Logf("db", "DB_WRITE: %+v", node) return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node")
}
ser, err := node.Serialize()
if err != nil {
return fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
}
header := NewDBHeader(node.Type())
db_data := append(header.Serialize(), ser...)
context.Graph.Log.Logf("db", "DB_WRITING_TYPE: %s - %+v %+v: %+v", node.ID(), node.Type(), header, node)
if err != nil { if err != nil {
return err return err
} }
id_ser := node.ID().Serialize() id_ser := node.ID().Serialize()
serialized_bytes[i] = node_bytes serialized_bytes[i] = db_data
serialized_ids[i] = id_ser serialized_ids[i] = id_ser
i++ i++
@ -342,7 +345,7 @@ func readNodeBytes(ctx * Context, id NodeID) (uint64, []byte, error) {
node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN) node_bytes := make([]byte, len(bytes) - NODE_DB_HEADER_LEN)
copy(node_bytes, bytes[NODE_DB_HEADER_LEN:]) copy(node_bytes, bytes[NODE_DB_HEADER_LEN:])
ctx.Log.Logf("db", "DB_READ: %s - %s", id, string(bytes)) ctx.Log.Logf("db", "DB_READ: %s %+v - %s", id, header, string(bytes))
return header.TypeHash, node_bytes, nil return header.TypeHash, node_bytes, nil
} }
@ -365,6 +368,7 @@ func LoadNodeRecurse(ctx * Context, id NodeID, nodes NodeMap) (Node, error) {
} }
node_type, exists := ctx.Types[type_hash] node_type, exists := ctx.Types[type_hash]
ctx.Log.Logf("db", "DB_LOADING_TYPE: %s - %+v", id, node_type)
if exists == false { if exists == false {
return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes) return nil, fmt.Errorf("0x%x is not a known node type: %+s", type_hash, bytes)
} }

@ -25,8 +25,8 @@ type BaseSignal struct {
FType string `json:"type"` FType string `json:"type"`
} }
func (state BaseSignal) String() string { func (signal BaseSignal) String() string {
ser, err := json.Marshal(state) ser, err := json.Marshal(signal)
if err != nil { if err != nil {
return "STATE_SER_ERR" return "STATE_SER_ERR"
} }
@ -69,6 +69,14 @@ type IDSignal struct {
ID NodeID `json:"id"` ID NodeID `json:"id"`
} }
func (signal IDSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
}
func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal {
return IDSignal{ return IDSignal{
BaseSignal: NewBaseSignal(_type, direction), BaseSignal: NewBaseSignal(_type, direction),
@ -78,7 +86,15 @@ func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal {
type StatusSignal struct { type StatusSignal struct {
IDSignal IDSignal
Status string Status string `json:"status"`
}
func (signal StatusSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
} }
func NewStatusSignal(status string, source NodeID) StatusSignal { func NewStatusSignal(status string, source NodeID) StatusSignal {

@ -10,11 +10,9 @@ import (
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (thread *Thread) Process(context *StateContext, signal GraphSignal) error { func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
err := thread.Lockable.Process(context, signal) context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", thread.ID())
if err != nil {
return err
}
var err error
switch signal.Direction() { switch signal.Direction() {
case Up: case Up:
err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, thread, NewLockInfo(thread, []string{"parent"}), func(context *StateContext) error {
@ -44,7 +42,7 @@ func (thread *Thread) Process(context *StateContext, signal GraphSignal) error {
} }
thread.Chan <- signal thread.Chan <- signal
return nil return thread.Lockable.Process(context, signal)
} }
// Requires thread and childs thread to be locked for write // Requires thread and childs thread to be locked for write
@ -197,7 +195,7 @@ func (thread *Thread) ThreadHandle() *Thread {
} }
func (thread *Thread) Type() NodeType { func (thread *Thread) Type() NodeType {
return NodeType("simple_thread") return NodeType("thread")
} }
func (thread *Thread) Serialize() ([]byte, error) { func (thread *Thread) Serialize() ([]byte, error) {
@ -319,7 +317,30 @@ func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) er
} }
var deserializers = map[InfoType]func(interface{})(interface{}, error) { var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{})
if ok == false {
return nil, fmt.Errorf("Failed to cast parent info to map")
}
start, ok := m["start"].(bool)
if ok == false {
return nil, fmt.Errorf("Failed to get start from parent info")
}
start_action, ok := m["start_action"].(string)
if ok == false {
return nil, fmt.Errorf("Failed to get start_action from parent info")
}
restore_action, ok := m["restore_action"].(string)
if ok == false {
return nil, fmt.Errorf("Failed to get restore_action from parent info")
}
return &ParentThreadInfo{
Start: start,
StartAction: start_action,
RestoreAction: restore_action,
}, nil
},
} }
func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) { func DeserializeChildInfo(ctx *Context, infos_raw map[string]interface{}) (map[InfoType]interface{}, error) {
@ -401,7 +422,7 @@ func ChildGo(ctx * Context, thread *Thread, child ThreadNode, first_action strin
defer thread.ChildWaits.Done() defer thread.ChildWaits.Done()
err := ThreadLoop(ctx, child, first_action) err := ThreadLoop(ctx, child, first_action)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err) ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %s", child.ID(), err)
} else { } else {
ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID()) ctx.Log.Logf("thread", "THREAD_CHILD_RUN_DONE: %s", child.ID())
} }

@ -46,6 +46,7 @@ func (user *User) Serialize() ([]byte, error) {
} }
func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { func LoadUser(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
ctx.Log.Logf("test", "LOADING_USER: %s", id)
var j UserJSON var j UserJSON
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {