From fa6142d88058898131cc32e1cdd139f89a41b863 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Wed, 26 Jul 2023 11:56:10 -0600 Subject: [PATCH] Started adding back gql tests --- context.go | 20 ++++- gql.go | 30 ++++--- gql_test.go | 250 +++++++++++++-------------------------------------- lockable.go | 51 ++++++++--- node.go | 59 +++++++++++- node_test.go | 4 +- policy.go | 10 +++ thread.go | 236 ++++++++++++++++++++++++++++-------------------- user.go | 14 ++- 9 files changed, 354 insertions(+), 320 deletions(-) diff --git a/context.go b/context.go index 82aecc2..41481c2 100644 --- a/context.go +++ b/context.go @@ -34,10 +34,14 @@ type Context struct { Nodes map[NodeID]*Node } -func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo { +func (ctx *Context) ExtByType(ext_type ExtType) *ExtensionInfo { type_hash := ext_type.Hash() - ext, _ := ctx.Extensions[type_hash] - return ext + ext, ok := ctx.Extensions[type_hash] + if ok == true { + return &ext + } else { + return nil + } } func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error { @@ -114,6 +118,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterExtension(ListenerExtType, LoadListenerExt, nil) + if err != nil { + return nil, err + } + err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, NewThreadExtContext()) if err != nil { return nil, err @@ -134,5 +143,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = RegisterGQLThread(ctx) + if err != nil { + return nil, err + } + return ctx, nil } diff --git a/gql.go b/gql.go index 3987748..2810018 100644 --- a/gql.go +++ b/gql.go @@ -30,6 +30,21 @@ import ( "encoding/pem" ) +const GQLThreadType = ThreadType("GQL") +func RegisterGQLThread(ctx *Context) error { + thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx) + if err != nil { + return err + } + + err = thread_ctx.RegisterThreadType(GQLThreadType, gql_actions, gql_handlers) + if err != nil { + return err + } + + return nil +} + type AuthReqJSON struct { Time time.Time `json:"time"` Pubkey []byte `json:"pubkey"` @@ -793,19 +808,10 @@ func LoadGQLExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - extension := GQLExt{ - Listen: j.Listen, - Key: key, - ECDH: ecdh_curve, - SubscribeListeners: []chan GraphSignal{}, - tls_key: j.TLSKey, - tls_cert: j.TLSCert, - } - - return &extension, nil + return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil } -func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLExt { +func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt { if tls_cert == nil || tls_key == nil { ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) if err != nil { @@ -845,7 +851,7 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_ tls_cert = ssl_cert_pem tls_key = ssl_key_pem } - return GQLExt{ + return &GQLExt{ Listen: listen, SubscribeListeners: []chan GraphSignal{}, Key: key, diff --git a/gql_test.go b/gql_test.go index 57e6698..9e514ee 100644 --- a/gql_test.go +++ b/gql_test.go @@ -3,89 +3,81 @@ package graphvent import ( "testing" "time" - "net/http" - "net" "errors" - "io" - "fmt" - "encoding/json" - "bytes" "crypto/rand" "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" - "crypto/tls" - "encoding/base64" ) func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"test"}) - l1 := NewListener(RandID(), "Test Listener 1") - ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String()) + ctx := logTestContext(t, []string{"test", "db"}) - t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) - ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String()) - listen_id := RandID() - ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String()) - - u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + ListenerNodeType := NodeType("LISTENER") + err := ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) fatalErr(t, err) - u1 := NewUser("Test User", time.Now(), &u1_key.PublicKey, []byte{}) - ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String()) + l1 := NewNode(RandID(), ListenerNodeType) + l1.Extensions[ACLExtType] = NewACLExt(nil) + listener_ext := NewListenerExt(10) + l1.Extensions[ListenerExtType] = listener_ext + l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) - ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String()) + ctx.Log.Logf("test", "L1_ID: %s", l1.ID) - // Policy to allow gql to perform all action on all resources - p1 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ - gql.ID(): NewNodeActions(nil, []string{"*"}), - }) - p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{ - "signal": []string{"status"}, - }, nil)) + TestThreadNodeType := NodeType("TEST_THREAD") + err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType}) + fatalErr(t, err) - context := NewWriteContext(ctx) - err = UpdateStates(context, &gql, LockMap{ - p1.ID(): LockInfo{&p1, nil}, - p2.ID(): LockInfo{&p2, nil}, - }, func(context *StateContext) error { - return nil - }) + t1 := NewNode(RandID(), TestThreadNodeType) + t1.Extensions[ACLExtType] = NewACLExt(nil) + t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) fatalErr(t, err) + t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) + + ctx.Log.Logf("test", "T1_ID: %s", t1.ID) - ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String()) - ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String()) - err = AttachPolicies(ctx, &gql, &p1, &p2) + TestUserNodeType := NodeType("TEST_USER") + err = ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType}) fatalErr(t, err) - err = AttachPolicies(ctx, &l1, &p1, &p2) + + u1 := NewNode(RandID(), TestUserNodeType) + u1.Extensions[ACLExtType] = NewACLExt(nil) + + ctx.Log.Logf("test", "U1_ID: %s", u1.ID) + + TestGQLNodeType := NodeType("TEST_GQL") + err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) - err = AttachPolicies(ctx, &t1, &p1, &p2) + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - err = AttachPolicies(ctx, &u1, &p1, &p2) + + gql := NewNode(RandID(), TestGQLNodeType) + gql.Extensions[ACLExtType] = NewACLExt(nil) + gql.Extensions[GroupExtType] = NewGroupExt(nil) + gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil) + gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil) fatalErr(t, err) + gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) - info := NewParentThreadInfo(true, "start", "restore") - context = NewWriteContext(ctx) - err = UpdateStates(context, &gql, NewLockMap( - NewLockInfo(&gql, []string{"users"}), - ), func(context *StateContext) error { - gql.UserMap[u1.ID()] = &u1 + ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{ - "parent": &info, + info := ParentInfo{true, "start", "restore"} + context := NewWriteContext(ctx) + err = UpdateStates(context, &gql, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error { + err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]Info{ + ParentInfoType: &info, }}) if err != nil { return err } - return LinkLockables(context, &gql, &l1, []LockableNode{&gql}) + return LinkLockables(context, &gql, &l1, []*Node{&gql}) }) fatalErr(t, err) context = NewReadContext(ctx) - err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID())) + err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID)) fatalErr(t, err) context = NewReadContext(ctx) err = Signal(context, &gql, &gql, AbortSignal) @@ -96,10 +88,10 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) } - (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") + (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") context = NewReadContext(ctx) - err = UseStates(context, &gql, LockList([]Node{&gql, &u1}, nil), func(context *StateContext) error { + err = UseStates(context, &gql, ACLList([]*Node{&gql, &u1}, nil), func(context *StateContext) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -107,150 +99,30 @@ func TestGQLDBLoad(t * testing.T) { return err }) - gql_loaded, err := LoadNode(ctx, gql.ID()) + // Clear all loaded nodes from the context so it loads them from the database + ctx.Nodes = NodeMap{} + gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) - var l1_loaded *Listener = nil context = NewReadContext(ctx) - err = UseStates(context, gql_loaded, NewLockInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error { + err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error { ser, err := gql_loaded.Serialize() + lockable_ext, err := GetExt[*LockableExt](gql_loaded) + if err != nil { + return err + } ctx.Log.Logf("test", "\n%s\n\n", ser) - dependency := gql_loaded.(*GQLThread).Thread.Dependencies[l1.ID()].(*Listener) - l1_loaded = dependency - u_loaded := gql_loaded.(*GQLThread).UserMap[u1.ID()] - err = UseStates(context, gql_loaded, NewLockInfo(u_loaded, nil), func(context *StateContext) error { - ser, err := u_loaded.Serialize() - ctx.Log.Logf("test", "\n%s\n\n", ser) + dependency := lockable_ext.Dependencies[l1.ID] + listener_ext, err = GetExt[*ListenerExt](dependency) + if err != nil { return err - }) + } Signal(context, gql_loaded, gql_loaded, StopSignal) return err }) - err = ThreadLoop(ctx, gql_loaded.(ThreadNode), "start") + err = ThreadLoop(ctx, gql_loaded, "start") fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, l1_loaded.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") + (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2") } -func TestGQLAuth(t * testing.T) { - ctx := logTestContext(t, []string{"test", "gql", "policy"}) - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - - p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{ - "signal": []string{"status"}, - }, nil)) - - l1 := NewListener(RandID(), "GQL Thread") - err = AttachPolicies(ctx, &l1, &p2) - fatalErr(t, err) - - p3 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{ - l1.ID(): NewNodeActions(nil, []string{"*"}), - }) - - gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) - err = AttachPolicies(ctx, &gql, &p2, &p3) - - context := NewWriteContext(ctx) - err = LinkLockables(context, &l1, &l1, []LockableNode{&gql}) - fatalErr(t, err) - - done := make(chan error, 1) - - go func(done chan error, thread ThreadNode) { - timeout := time.After(2*time.Second) - select { - case <-timeout: - ctx.Log.Logf("test", "TIMEOUT") - case <-done: - ctx.Log.Logf("test", "DONE") - } - context := NewReadContext(ctx) - err := Signal(context, thread, thread, StopSignal) - fatalErr(t, err) - }(done, &gql) - - go func(thread ThreadNode){ - (*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "server_started", 100*time.Millisecond, "Server didn't start") - port := gql.tcp_listener.Addr().(*net.TCPAddr).Port - ctx.Log.Logf("test", "GQL_PORT: %d", port) - - customTransport := &http.Transport{ - Proxy: http.DefaultTransport.(*http.Transport).Proxy, - DialContext: http.DefaultTransport.(*http.Transport).DialContext, - MaxIdleConns: http.DefaultTransport.(*http.Transport).MaxIdleConns, - IdleConnTimeout: http.DefaultTransport.(*http.Transport).IdleConnTimeout, - ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout, - TLSHandshakeTimeout: http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{Transport: customTransport} - url := fmt.Sprintf("https://localhost:%d/auth", port) - - id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - - auth_req, ec_key, err := NewAuthReqJSON(ecdh.P256(), id) - fatalErr(t, err) - - str, err := json.Marshal(auth_req) - fatalErr(t, err) - - b := bytes.NewBuffer(str) - req, err := http.NewRequest("PUT", url, b) - fatalErr(t, err) - - resp, err := client.Do(req) - fatalErr(t, err) - - body, err := io.ReadAll(resp.Body) - fatalErr(t, err) - - resp.Body.Close() - - var j AuthRespJSON - err = json.Unmarshal(body, &j) - fatalErr(t, err) - - shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key) - fatalErr(t, err) - - url = fmt.Sprintf("https://localhost:%d/gql", port) - ser, err := json.MarshalIndent(&GQLPayload{ - Query: "query { Self { Users { ID, Name } } }", - }, "", " ") - fatalErr(t, err) - - b = bytes.NewBuffer(ser) - req, err = http.NewRequest("GET", url, b) - fatalErr(t, err) - - req.SetBasicAuth(KeyID(&id.PublicKey).String(), base64.StdEncoding.EncodeToString(shared)) - resp, err = client.Do(req) - fatalErr(t, err) - - body, err = io.ReadAll(resp.Body) - fatalErr(t, err) - - resp.Body.Close() - - ctx.Log.Logf("test", "TEST_RESP: %s", body) - - req.SetBasicAuth(KeyID(&id.PublicKey).String(), "BAD_PASSWORD") - resp, err = client.Do(req) - fatalErr(t, err) - - body, err = io.ReadAll(resp.Body) - fatalErr(t, err) - - resp.Body.Close() - - ctx.Log.Logf("test", "TEST_RESP: %s", body) - - done <- nil - }(&gql) - - err = ThreadLoop(ctx, &gql, "start") - fatalErr(t, err) -} diff --git a/lockable.go b/lockable.go index 08a761d..ba2ea8f 100644 --- a/lockable.go +++ b/lockable.go @@ -6,21 +6,33 @@ import ( ) type ListenerExt struct { + Buffer int Chan chan GraphSignal } -func NewListenerExt(buffer int) ListenerExt { - return ListenerExt{ +func NewListenerExt(buffer int) *ListenerExt { + return &ListenerExt{ + Buffer: buffer, Chan: make(chan GraphSignal, buffer), } } +func LoadListenerExt(ctx *Context, data []byte) (Extension, error) { + var j int + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + return NewListenerExt(j), nil +} + const ListenerExtType = ExtType("LISTENER") func (listener ListenerExt) Type() ExtType { return ListenerExtType } -func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error { +func (ext ListenerExt) Process(context *StateContext, node *Node, signal GraphSignal) error { select { case ext.Chan <- signal: default: @@ -29,8 +41,8 @@ func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error return nil } -func (node ListenerExt) Serialize() ([]byte, error) { - return []byte{}, nil +func (ext ListenerExt) Serialize() ([]byte, error) { + return json.MarshalIndent(ext.Buffer, "", " ") } type LockableExt struct { @@ -61,6 +73,27 @@ func (ext *LockableExt) Serialize() ([]byte, error) { }, "", " ") } +func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt { + if requirements == nil { + requirements = NodeMap{} + } + + if dependencies == nil { + dependencies = NodeMap{} + } + + if locks_held == nil { + locks_held = NodeMap{} + } + + return &LockableExt{ + Owner: owner, + Requirements: requirements, + Dependencies: dependencies, + LocksHeld: locks_held, + } +} + func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { var j LockableExtJSON err := json.Unmarshal(data, &j) @@ -88,14 +121,8 @@ func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - extension := LockableExt{ - Owner: owner, - Requirements: requirements, - Dependencies: dependencies, - LocksHeld: locks_held, - } - return &extension, nil + return NewLockableExt(owner, requirements, dependencies, locks_held), nil } func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { diff --git a/node.go b/node.go index c1e1a58..d47b72f 100644 --- a/node.go +++ b/node.go @@ -78,12 +78,29 @@ type Node struct { Extensions map[ExtType]Extension } +func GetCtx[T Extension, C any](ctx *Context) (C, error) { + var zero T + var zero_ctx C + ext_type := zero.Type() + ext_info := ctx.ExtByType(ext_type) + if ext_info == nil { + return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type) + } + + ext_ctx, ok := ext_info.Data.(C) + if ok == false { + return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero)) + } + + return ext_ctx, nil +} + func GetExt[T Extension](node *Node) (T, error) { var zero T ext_type := zero.Type() ext, exists := node.Extensions[ext_type] if exists == false { - return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type) + return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions) } ret, ok := ext.(T) @@ -373,6 +390,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { node = &new_node ctx.Nodes[id] = node + found_extensions := []ExtType{} // Parse each of the extensions from the db for _, ext_db := range(node_db.Extensions) { type_hash := ext_db.Header.TypeHash @@ -385,7 +403,44 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } node.Extensions[def.Type] = extension - ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash) + found_extensions = append(found_extensions, def.Type) + ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x - %+v", id, type_hash, def.Type) + } + + missing_extensions := []ExtType{} + for _, ext := range(node_type.Extensions) { + found := false + for _, found_ext := range(found_extensions) { + if found_ext == ext { + found = true + break + } + } + if found == false { + missing_extensions = append(missing_extensions, ext) + } + } + + if len(missing_extensions) > 0 { + return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions) + } + + extra_extensions := []ExtType{} + for _, found_ext := range(found_extensions) { + found := false + for _, ext := range(node_type.Extensions) { + if ext == found_ext { + found = true + break + } + } + if found == false { + extra_extensions = append(extra_extensions, found_ext) + } + } + + if len(extra_extensions) > 0 { + return nil, fmt.Errorf("DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions) } ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) diff --git a/node_test.go b/node_test.go index 71e7e8e..563ca4c 100644 --- a/node_test.go +++ b/node_test.go @@ -7,7 +7,7 @@ import ( func TestNodeDB(t *testing.T) { ctx := logTestContext(t, []string{"test", "db", "node", "policy"}) node_type := NodeType("test") - err := ctx.RegisterNodeType(node_type, []ExtType{}) + err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"}) fatalErr(t, err) node := NewNode(RandID(), node_type) node.Extensions[ACLExtType] = &ACLExt{ @@ -18,7 +18,7 @@ func TestNodeDB(t *testing.T) { context := NewWriteContext(ctx) err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error { ser, err := node.Serialize() - ctx.Log.Logf("test", "NODE_SER: %s", ser) + ctx.Log.Logf("test", "NODE_SER: %+v", ser) return err }) fatalErr(t, err) diff --git a/policy.go b/policy.go index 9de62c0..9bc6560 100644 --- a/policy.go +++ b/policy.go @@ -65,6 +65,16 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { }, nil } +func NewACLExt(delegations NodeMap) *ACLExt { + if delegations == nil { + delegations = NodeMap{} + } + + return &ACLExt{ + Delegations: delegations, + } +} + func (ext *ACLExt) Serialize() ([]byte, error) { delegations := make([]string, len(ext.Delegations)) i := 0 diff --git a/thread.go b/thread.go index 67b00df..1464502 100644 --- a/thread.go +++ b/thread.go @@ -6,18 +6,113 @@ import ( "sync" "errors" "encoding/json" + "crypto/sha512" + "encoding/binary" ) +type ThreadAction func(*Context, *Node, *ThreadExt)(string, error) +type ThreadActions map[string]ThreadAction +type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error) +type ThreadHandlers map[string]ThreadHandler + +type InfoType string +func (t InfoType) String() string { + return string(t) +} + +type Info interface { + Serializable[InfoType] +} + +// Data required by a parent thread to restore it's children +type ParentInfo struct { + Start bool `json:"start"` + StartAction string `json:"start_action"` + RestoreAction string `json:"restore_action"` +} + +const ParentInfoType = InfoType("PARENT") +func (info *ParentInfo) Type() InfoType { + return ParentInfoType +} + +func (info *ParentInfo) Serialize() ([]byte, error) { + return json.MarshalIndent(info, "", " ") +} + type QueuedAction struct { Timeout time.Time `json:"time"` Action string `json:"action"` } +type ThreadType string +func (thread ThreadType) Hash() uint64 { + hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread)))) + return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) +} + +type ThreadInfo struct { + Actions ThreadActions + Handlers ThreadHandlers +} + +type InfoLoadFunc func([]byte)(Info, error) type ThreadExtContext struct { - Loads map[InfoType]func([]byte)ThreadInfo + Types map[ThreadType]ThreadInfo + Loads map[InfoType]InfoLoadFunc } +const BaseThreadType = ThreadType("BASE") func NewThreadExtContext() *ThreadExtContext { + return &ThreadExtContext{ + Types: map[ThreadType]ThreadInfo{ + BaseThreadType: ThreadInfo{ + Actions: BaseThreadActions, + Handlers: BaseThreadHandlers, + }, + }, + Loads: map[InfoType]InfoLoadFunc{ + ParentInfoType: func(data []byte) (Info, error) { + var info ParentInfo + err := json.Unmarshal(data, &info) + if err != nil { + return nil, err + } + + return &info, nil + }, + }, + } +} + +func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error { + if actions == nil || handlers == nil { + return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type) + } + + _, exists := ctx.Types[thread_type] + if exists == true { + return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type) + } + ctx.Types[thread_type] = ThreadInfo{ + Actions: actions, + Handlers: handlers, + } + + return nil +} + +func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error { + if load_fn == nil { + return fmt.Errorf("Cannot register %s with nil load_fn", info_type) + } + + _, exists := ctx.Loads[info_type] + if exists == true { + return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type) + } + + ctx.Loads[info_type] = load_fn return nil } @@ -25,6 +120,8 @@ type ThreadExt struct { Actions ThreadActions Handlers ThreadHandlers + ThreadType ThreadType + SignalChan chan GraphSignal TimeoutChan <-chan time.Time @@ -43,6 +140,7 @@ type ThreadExt struct { type ThreadExtJSON struct { State string `json:"state"` + Type string `json:"type"` Parent string `json:"parent"` Children map[string][]byte `json:"children"` ActionQueue []QueuedAction @@ -52,6 +150,39 @@ func (ext *ThreadExt) Serialize() ([]byte, error) { return nil, fmt.Errorf("NOT_IMPLEMENTED") } +func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) { + if children == nil { + children = map[NodeID]ChildInfo{} + } + + if action_queue == nil { + action_queue = []QueuedAction{} + } + + thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx) + if err != nil { + return nil, err + } + type_info, exists := thread_ctx.Types[thread_type] + if exists == false { + return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type) + } + next_action, timeout_chan := SoonestAction(action_queue) + + return &ThreadExt{ + Actions: type_info.Actions, + Handlers: type_info.Handlers, + SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE), + TimeoutChan: timeout_chan, + Active: false, + State: state, + Parent: parent, + Children: children, + ActionQueue: action_queue, + NextAction: next_action, + }, nil +} + const THREAD_BUFFER_SIZE int = 1024 func LoadThreadExt(ctx *Context, data []byte) (Extension, error) { var j ThreadExtJSON @@ -75,26 +206,11 @@ func LoadThreadExt(ctx *Context, data []byte) (Extension, error) { children[child_node.ID] = ChildInfo{ Child: child_node, - Infos: map[InfoType]ThreadInfo{}, + Infos: map[InfoType]Info{}, } } - next_action, timeout_chan := SoonestAction(j.ActionQueue) - - extension := ThreadExt{ - Actions: BaseThreadActions, - Handlers: BaseThreadHandlers, - SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE), - TimeoutChan: timeout_chan, - Active: false, - State: j.State, - Parent: parent, - Children: children, - ActionQueue: j.ActionQueue, - NextAction: next_action, - } - - return &extension, nil + return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue) } const ThreadExtType = ExtType("THREAD") @@ -281,44 +397,14 @@ func LinkThreads(context *StateContext, principal *Node, thread *Node, info Chil }) } -type ThreadAction func(*Context, *Node, *ThreadExt)(string, error) -type ThreadActions map[string]ThreadAction -type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error) -type ThreadHandlers map[string]ThreadHandler - -type InfoType string -func (t InfoType) String() string { - return string(t) -} - -type ThreadInfo interface { - Serializable[InfoType] -} - -// Data required by a parent thread to restore it's children -type ParentThreadInfo struct { - Start bool `json:"start"` - StartAction string `json:"start_action"` - RestoreAction string `json:"restore_action"` -} - -const ParentThreadInfoType = InfoType("PARENT") -func (info *ParentThreadInfo) Type() InfoType { - return ParentThreadInfoType -} - -func (info *ParentThreadInfo) Serialize() ([]byte, error) { - return json.MarshalIndent(info, "", " ") -} - type ChildInfo struct { Child *Node - Infos map[InfoType]ThreadInfo + Infos map[InfoType]Info } -func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo { +func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo { if infos == nil { - infos = map[InfoType]ThreadInfo{} + infos = map[InfoType]Info{} } return ChildInfo{ @@ -327,48 +413,6 @@ func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo { } } -var deserializers = map[InfoType]func(interface{})(interface{}, error) { - "parent": func(raw interface{})(interface{}, error) { - m, ok := raw.(map[string]interface{}) - if ok == false { - return nil, fmt.Errorf("Failed to cast parent info to map") - } - start, ok := m["start"].(bool) - if ok == false { - return nil, fmt.Errorf("Failed to get start from parent info") - } - start_action, ok := m["start_action"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to get start_action from parent info") - } - restore_action, ok := m["restore_action"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to get restore_action from parent info") - } - - return &ParentThreadInfo{ - Start: start, - StartAction: start_action, - RestoreAction: restore_action, - }, nil - }, -} - -func NewThreadExt(buffer int, name string, state string, actions ThreadActions, handlers ThreadHandlers) ThreadExt { - return ThreadExt{ - Actions: actions, - Handlers: handlers, - SignalChan: make(chan GraphSignal, buffer), - TimeoutChan: nil, - Active: false, - State: state, - Parent: nil, - Children: map[NodeID]ChildInfo{}, - ActionQueue: []QueuedAction{}, - NextAction: nil, - } -} - func (ext *ThreadExt) SetActive(active bool) error { ext.ActiveLock.Lock() defer ext.ActiveLock.Unlock() @@ -485,7 +529,7 @@ func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID) return nil } - parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) + parent_info, exists := info.Infos["parent"].(*ParentInfo) if exists == false { panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul") } @@ -520,7 +564,7 @@ func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal } return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error { - parent_info, exists := info.Infos["parent"].(*ParentThreadInfo) + parent_info, exists := info.Infos["parent"].(*ParentInfo) if exists == false { return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") } @@ -544,7 +588,7 @@ func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start boo return err } - parent_info := info.Infos["parent"].(*ParentThreadInfo) + parent_info := info.Infos["parent"].(*ParentInfo) if parent_info.Start == true && child_ext.State != "finished" { ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID) if start == true { diff --git a/user.go b/user.go index 67ce10a..930fbd9 100644 --- a/user.go +++ b/user.go @@ -88,6 +88,15 @@ func (ext *GroupExt) Serialize() ([]byte, error) { }, "", " ") } +func NewGroupExt(members NodeMap) *GroupExt { + if members == nil { + members = NodeMap{} + } + return &GroupExt{ + Members: members, + } +} + func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { var j struct { Members []string `json:"members"` @@ -103,10 +112,7 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - extension := GroupExt{ - Members: members, - } - return &extension, nil + return NewGroupExt(members), nil } func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error {