Started adding back gql tests

graph-rework-2
noah metz 2023-07-26 11:56:10 -06:00
parent c4156ee146
commit fa6142d880
9 changed files with 354 additions and 320 deletions

@ -34,10 +34,14 @@ type Context struct {
Nodes map[NodeID]*Node
}
func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
func (ctx *Context) ExtByType(ext_type ExtType) *ExtensionInfo {
type_hash := ext_type.Hash()
ext, _ := ctx.Extensions[type_hash]
return ext
ext, ok := ctx.Extensions[type_hash]
if ok == true {
return &ext
} else {
return nil
}
}
func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) error {
@ -114,6 +118,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = ctx.RegisterExtension(ListenerExtType, LoadListenerExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, NewThreadExtContext())
if err != nil {
return nil, err
@ -134,5 +143,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterGQLThread(ctx)
if err != nil {
return nil, err
}
return ctx, nil
}

@ -30,6 +30,21 @@ import (
"encoding/pem"
)
const GQLThreadType = ThreadType("GQL")
func RegisterGQLThread(ctx *Context) error {
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
if err != nil {
return err
}
err = thread_ctx.RegisterThreadType(GQLThreadType, gql_actions, gql_handlers)
if err != nil {
return err
}
return nil
}
type AuthReqJSON struct {
Time time.Time `json:"time"`
Pubkey []byte `json:"pubkey"`
@ -793,19 +808,10 @@ func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
return nil, err
}
extension := GQLExt{
Listen: j.Listen,
Key: key,
ECDH: ecdh_curve,
SubscribeListeners: []chan GraphSignal{},
tls_key: j.TLSKey,
tls_cert: j.TLSCert,
}
return &extension, nil
return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil
}
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLExt {
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt {
if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader)
if err != nil {
@ -845,7 +851,7 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_
tls_cert = ssl_cert_pem
tls_key = ssl_key_pem
}
return GQLExt{
return &GQLExt{
Listen: listen,
SubscribeListeners: []chan GraphSignal{},
Key: key,

@ -3,89 +3,81 @@ package graphvent
import (
"testing"
"time"
"net/http"
"net"
"errors"
"io"
"fmt"
"encoding/json"
"bytes"
"crypto/rand"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/tls"
"encoding/base64"
)
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"test"})
l1 := NewListener(RandID(), "Test Listener 1")
ctx.Log.Logf("test", "L1_ID: %s", l1.ID().String())
ctx := logTestContext(t, []string{"test", "db"})
t1 := NewThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
ctx.Log.Logf("test", "T1_ID: %s", t1.ID().String())
listen_id := RandID()
ctx.Log.Logf("test", "LISTENER_ID: %s", listen_id.String())
u1_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
ListenerNodeType := NodeType("LISTENER")
err := ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType})
fatalErr(t, err)
u1 := NewUser("Test User", time.Now(), &u1_key.PublicKey, []byte{})
ctx.Log.Logf("test", "U1_ID: %s", u1.ID().String())
l1 := NewNode(RandID(), ListenerNodeType)
l1.Extensions[ACLExtType] = NewACLExt(nil)
listener_ext := NewListenerExt(10)
l1.Extensions[ListenerExtType] = listener_ext
l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID().String())
ctx.Log.Logf("test", "L1_ID: %s", l1.ID)
// Policy to allow gql to perform all action on all resources
p1 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
gql.ID(): NewNodeActions(nil, []string{"*"}),
})
p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
"signal": []string{"status"},
}, nil))
TestThreadNodeType := NodeType("TEST_THREAD")
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType})
fatalErr(t, err)
context := NewWriteContext(ctx)
err = UpdateStates(context, &gql, LockMap{
p1.ID(): LockInfo{&p1, nil},
p2.ID(): LockInfo{&p2, nil},
}, func(context *StateContext) error {
return nil
})
t1 := NewNode(RandID(), TestThreadNodeType)
t1.Extensions[ACLExtType] = NewACLExt(nil)
t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil)
fatalErr(t, err)
t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
ctx.Log.Logf("test", "P1_ID: %s", p1.ID().String())
ctx.Log.Logf("test", "P2_ID: %s", p2.ID().String())
err = AttachPolicies(ctx, &gql, &p1, &p2)
TestUserNodeType := NodeType("TEST_USER")
err = ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType})
fatalErr(t, err)
err = AttachPolicies(ctx, &l1, &p1, &p2)
u1 := NewNode(RandID(), TestUserNodeType)
u1.Extensions[ACLExtType] = NewACLExt(nil)
ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
TestGQLNodeType := NodeType("TEST_GQL")
err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType})
fatalErr(t, err)
err = AttachPolicies(ctx, &t1, &p1, &p2)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
err = AttachPolicies(ctx, &u1, &p1, &p2)
gql := NewNode(RandID(), TestGQLNodeType)
gql.Extensions[ACLExtType] = NewACLExt(nil)
gql.Extensions[GroupExtType] = NewGroupExt(nil)
gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil)
gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil)
fatalErr(t, err)
gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
info := NewParentThreadInfo(true, "start", "restore")
context = NewWriteContext(ctx)
err = UpdateStates(context, &gql, NewLockMap(
NewLockInfo(&gql, []string{"users"}),
), func(context *StateContext) error {
gql.UserMap[u1.ID()] = &u1
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]interface{}{
"parent": &info,
info := ParentInfo{true, "start", "restore"}
context := NewWriteContext(ctx)
err = UpdateStates(context, &gql, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error {
err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]Info{
ParentInfoType: &info,
}})
if err != nil {
return err
}
return LinkLockables(context, &gql, &l1, []LockableNode{&gql})
return LinkLockables(context, &gql, &l1, []*Node{&gql})
})
fatalErr(t, err)
context = NewReadContext(ctx)
err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID()))
err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID))
fatalErr(t, err)
context = NewReadContext(ctx)
err = Signal(context, &gql, &gql, AbortSignal)
@ -96,10 +88,10 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err)
}
(*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
context = NewReadContext(ctx)
err = UseStates(context, &gql, LockList([]Node{&gql, &u1}, nil), func(context *StateContext) error {
err = UseStates(context, &gql, ACLList([]*Node{&gql, &u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize()
ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -107,150 +99,30 @@ func TestGQLDBLoad(t * testing.T) {
return err
})
gql_loaded, err := LoadNode(ctx, gql.ID())
// Clear all loaded nodes from the context so it loads them from the database
ctx.Nodes = NodeMap{}
gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err)
var l1_loaded *Listener = nil
context = NewReadContext(ctx)
err = UseStates(context, gql_loaded, NewLockInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
err = UseStates(context, gql_loaded, NewACLInfo(gql_loaded, []string{"users", "children", "requirements"}), func(context *StateContext) error {
ser, err := gql_loaded.Serialize()
lockable_ext, err := GetExt[*LockableExt](gql_loaded)
if err != nil {
return err
}
ctx.Log.Logf("test", "\n%s\n\n", ser)
dependency := gql_loaded.(*GQLThread).Thread.Dependencies[l1.ID()].(*Listener)
l1_loaded = dependency
u_loaded := gql_loaded.(*GQLThread).UserMap[u1.ID()]
err = UseStates(context, gql_loaded, NewLockInfo(u_loaded, nil), func(context *StateContext) error {
ser, err := u_loaded.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
dependency := lockable_ext.Dependencies[l1.ID]
listener_ext, err = GetExt[*ListenerExt](dependency)
if err != nil {
return err
})
}
Signal(context, gql_loaded, gql_loaded, StopSignal)
return err
})
err = ThreadLoop(ctx, gql_loaded.(ThreadNode), "start")
err = ThreadLoop(ctx, gql_loaded, "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForStatus(ctx, l1_loaded.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "stopped", 100*time.Millisecond, "Didn't receive stopped on update_channel_2")
}
func TestGQLAuth(t * testing.T) {
ctx := logTestContext(t, []string{"test", "gql", "policy"})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
p2 := NewSimplePolicy(RandID(), NewNodeActions(NodeActions{
"signal": []string{"status"},
}, nil))
l1 := NewListener(RandID(), "GQL Thread")
err = AttachPolicies(ctx, &l1, &p2)
fatalErr(t, err)
p3 := NewPerNodePolicy(RandID(), map[NodeID]NodeActions{
l1.ID(): NewNodeActions(nil, []string{"*"}),
})
gql := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil)
err = AttachPolicies(ctx, &gql, &p2, &p3)
context := NewWriteContext(ctx)
err = LinkLockables(context, &l1, &l1, []LockableNode{&gql})
fatalErr(t, err)
done := make(chan error, 1)
go func(done chan error, thread ThreadNode) {
timeout := time.After(2*time.Second)
select {
case <-timeout:
ctx.Log.Logf("test", "TIMEOUT")
case <-done:
ctx.Log.Logf("test", "DONE")
}
context := NewReadContext(ctx)
err := Signal(context, thread, thread, StopSignal)
fatalErr(t, err)
}(done, &gql)
go func(thread ThreadNode){
(*GraphTester)(t).WaitForStatus(ctx, l1.Chan, "server_started", 100*time.Millisecond, "Server didn't start")
port := gql.tcp_listener.Addr().(*net.TCPAddr).Port
ctx.Log.Logf("test", "GQL_PORT: %d", port)
customTransport := &http.Transport{
Proxy: http.DefaultTransport.(*http.Transport).Proxy,
DialContext: http.DefaultTransport.(*http.Transport).DialContext,
MaxIdleConns: http.DefaultTransport.(*http.Transport).MaxIdleConns,
IdleConnTimeout: http.DefaultTransport.(*http.Transport).IdleConnTimeout,
ExpectContinueTimeout: http.DefaultTransport.(*http.Transport).ExpectContinueTimeout,
TLSHandshakeTimeout: http.DefaultTransport.(*http.Transport).TLSHandshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: customTransport}
url := fmt.Sprintf("https://localhost:%d/auth", port)
id, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
auth_req, ec_key, err := NewAuthReqJSON(ecdh.P256(), id)
fatalErr(t, err)
str, err := json.Marshal(auth_req)
fatalErr(t, err)
b := bytes.NewBuffer(str)
req, err := http.NewRequest("PUT", url, b)
fatalErr(t, err)
resp, err := client.Do(req)
fatalErr(t, err)
body, err := io.ReadAll(resp.Body)
fatalErr(t, err)
resp.Body.Close()
var j AuthRespJSON
err = json.Unmarshal(body, &j)
fatalErr(t, err)
shared, err := ParseAuthRespJSON(j, elliptic.P256(), ecdh.P256(), ec_key)
fatalErr(t, err)
url = fmt.Sprintf("https://localhost:%d/gql", port)
ser, err := json.MarshalIndent(&GQLPayload{
Query: "query { Self { Users { ID, Name } } }",
}, "", " ")
fatalErr(t, err)
b = bytes.NewBuffer(ser)
req, err = http.NewRequest("GET", url, b)
fatalErr(t, err)
req.SetBasicAuth(KeyID(&id.PublicKey).String(), base64.StdEncoding.EncodeToString(shared))
resp, err = client.Do(req)
fatalErr(t, err)
body, err = io.ReadAll(resp.Body)
fatalErr(t, err)
resp.Body.Close()
ctx.Log.Logf("test", "TEST_RESP: %s", body)
req.SetBasicAuth(KeyID(&id.PublicKey).String(), "BAD_PASSWORD")
resp, err = client.Do(req)
fatalErr(t, err)
body, err = io.ReadAll(resp.Body)
fatalErr(t, err)
resp.Body.Close()
ctx.Log.Logf("test", "TEST_RESP: %s", body)
done <- nil
}(&gql)
err = ThreadLoop(ctx, &gql, "start")
fatalErr(t, err)
}

@ -6,21 +6,33 @@ import (
)
type ListenerExt struct {
Buffer int
Chan chan GraphSignal
}
func NewListenerExt(buffer int) ListenerExt {
return ListenerExt{
func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{
Buffer: buffer,
Chan: make(chan GraphSignal, buffer),
}
}
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
var j int
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
return NewListenerExt(j), nil
}
const ListenerExtType = ExtType("LISTENER")
func (listener ListenerExt) Type() ExtType {
return ListenerExtType
}
func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error {
func (ext ListenerExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
select {
case ext.Chan <- signal:
default:
@ -29,8 +41,8 @@ func (ext ListenerExt) Process(context *StateContext, signal GraphSignal) error
return nil
}
func (node ListenerExt) Serialize() ([]byte, error) {
return []byte{}, nil
func (ext ListenerExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext.Buffer, "", " ")
}
type LockableExt struct {
@ -61,6 +73,27 @@ func (ext *LockableExt) Serialize() ([]byte, error) {
}, "", " ")
}
func NewLockableExt(owner *Node, requirements NodeMap, dependencies NodeMap, locks_held NodeMap) *LockableExt {
if requirements == nil {
requirements = NodeMap{}
}
if dependencies == nil {
dependencies = NodeMap{}
}
if locks_held == nil {
locks_held = NodeMap{}
}
return &LockableExt{
Owner: owner,
Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held,
}
}
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var j LockableExtJSON
err := json.Unmarshal(data, &j)
@ -88,14 +121,8 @@ func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
return nil, err
}
extension := LockableExt{
Owner: owner,
Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held,
}
return &extension, nil
return NewLockableExt(owner, requirements, dependencies, locks_held), nil
}
func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error {

@ -78,12 +78,29 @@ type Node struct {
Extensions map[ExtType]Extension
}
func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T
var zero_ctx C
ext_type := zero.Type()
ext_info := ctx.ExtByType(ext_type)
if ext_info == nil {
return zero_ctx, fmt.Errorf("%s is not an extension in ctx", ext_type)
}
ext_ctx, ok := ext_info.Data.(C)
if ok == false {
return zero_ctx, fmt.Errorf("context for %s is %+v, not %+v", ext_type, reflect.TypeOf(ext_info.Data), reflect.TypeOf(zero))
}
return ext_ctx, nil
}
func GetExt[T Extension](node *Node) (T, error) {
var zero T
ext_type := zero.Type()
ext, exists := node.Extensions[ext_type]
if exists == false {
return zero, fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
return zero, fmt.Errorf("%s does not have %s extension - %+v", node.ID, ext_type, node.Extensions)
}
ret, ok := ext.(T)
@ -373,6 +390,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
node = &new_node
ctx.Nodes[id] = node
found_extensions := []ExtType{}
// Parse each of the extensions from the db
for _, ext_db := range(node_db.Extensions) {
type_hash := ext_db.Header.TypeHash
@ -385,7 +403,44 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err
}
node.Extensions[def.Type] = extension
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x", id, type_hash)
found_extensions = append(found_extensions, def.Type)
ctx.Log.Logf("db", "DB_EXTENSION_LOADED: %s - 0x%x - %+v", id, type_hash, def.Type)
}
missing_extensions := []ExtType{}
for _, ext := range(node_type.Extensions) {
found := false
for _, found_ext := range(found_extensions) {
if found_ext == ext {
found = true
break
}
}
if found == false {
missing_extensions = append(missing_extensions, ext)
}
}
if len(missing_extensions) > 0 {
return nil, fmt.Errorf("DB_LOAD_MISSING_EXTENSIONS: %s - %+v - %+v", id, node_type, missing_extensions)
}
extra_extensions := []ExtType{}
for _, found_ext := range(found_extensions) {
found := false
for _, ext := range(node_type.Extensions) {
if ext == found_ext {
found = true
break
}
}
if found == false {
extra_extensions = append(extra_extensions, found_ext)
}
}
if len(extra_extensions) > 0 {
return nil, fmt.Errorf("DB_LOAD_EXTRA_EXTENSIONS: %s - %+v - %+v", id, node_type, extra_extensions)
}
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)

@ -7,7 +7,7 @@ import (
func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "db", "node", "policy"})
node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{})
err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"})
fatalErr(t, err)
node := NewNode(RandID(), node_type)
node.Extensions[ACLExtType] = &ACLExt{
@ -18,7 +18,7 @@ func TestNodeDB(t *testing.T) {
context := NewWriteContext(ctx)
err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error {
ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %s", ser)
ctx.Log.Logf("test", "NODE_SER: %+v", ser)
return err
})
fatalErr(t, err)

@ -65,6 +65,16 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
}, nil
}
func NewACLExt(delegations NodeMap) *ACLExt {
if delegations == nil {
delegations = NodeMap{}
}
return &ACLExt{
Delegations: delegations,
}
}
func (ext *ACLExt) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0

@ -6,18 +6,113 @@ import (
"sync"
"errors"
"encoding/json"
"crypto/sha512"
"encoding/binary"
)
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
type InfoType string
func (t InfoType) String() string {
return string(t)
}
type Info interface {
Serializable[InfoType]
}
// Data required by a parent thread to restore it's children
type ParentInfo struct {
Start bool `json:"start"`
StartAction string `json:"start_action"`
RestoreAction string `json:"restore_action"`
}
const ParentInfoType = InfoType("PARENT")
func (info *ParentInfo) Type() InfoType {
return ParentInfoType
}
func (info *ParentInfo) Serialize() ([]byte, error) {
return json.MarshalIndent(info, "", " ")
}
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
type ThreadType string
func (thread ThreadType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("THREAD: %s", string(thread))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type ThreadInfo struct {
Actions ThreadActions
Handlers ThreadHandlers
}
type InfoLoadFunc func([]byte)(Info, error)
type ThreadExtContext struct {
Loads map[InfoType]func([]byte)ThreadInfo
Types map[ThreadType]ThreadInfo
Loads map[InfoType]InfoLoadFunc
}
const BaseThreadType = ThreadType("BASE")
func NewThreadExtContext() *ThreadExtContext {
return &ThreadExtContext{
Types: map[ThreadType]ThreadInfo{
BaseThreadType: ThreadInfo{
Actions: BaseThreadActions,
Handlers: BaseThreadHandlers,
},
},
Loads: map[InfoType]InfoLoadFunc{
ParentInfoType: func(data []byte) (Info, error) {
var info ParentInfo
err := json.Unmarshal(data, &info)
if err != nil {
return nil, err
}
return &info, nil
},
},
}
}
func (ctx *ThreadExtContext) RegisterThreadType(thread_type ThreadType, actions ThreadActions, handlers ThreadHandlers) error {
if actions == nil || handlers == nil {
return fmt.Errorf("Cannot register ThreadType %s with nil actions or handlers", thread_type)
}
_, exists := ctx.Types[thread_type]
if exists == true {
return fmt.Errorf("ThreadType %s already registered in ThreadExtContext, cannot register again", thread_type)
}
ctx.Types[thread_type] = ThreadInfo{
Actions: actions,
Handlers: handlers,
}
return nil
}
func (ctx *ThreadExtContext) RegisterInfoType(info_type InfoType, load_fn InfoLoadFunc) error {
if load_fn == nil {
return fmt.Errorf("Cannot register %s with nil load_fn", info_type)
}
_, exists := ctx.Loads[info_type]
if exists == true {
return fmt.Errorf("InfoType %s is already registered in ThreadExtContext, cannot register again", info_type)
}
ctx.Loads[info_type] = load_fn
return nil
}
@ -25,6 +120,8 @@ type ThreadExt struct {
Actions ThreadActions
Handlers ThreadHandlers
ThreadType ThreadType
SignalChan chan GraphSignal
TimeoutChan <-chan time.Time
@ -43,6 +140,7 @@ type ThreadExt struct {
type ThreadExtJSON struct {
State string `json:"state"`
Type string `json:"type"`
Parent string `json:"parent"`
Children map[string][]byte `json:"children"`
ActionQueue []QueuedAction
@ -52,6 +150,39 @@ func (ext *ThreadExt) Serialize() ([]byte, error) {
return nil, fmt.Errorf("NOT_IMPLEMENTED")
}
func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) {
if children == nil {
children = map[NodeID]ChildInfo{}
}
if action_queue == nil {
action_queue = []QueuedAction{}
}
thread_ctx, err := GetCtx[*ThreadExt, *ThreadExtContext](ctx)
if err != nil {
return nil, err
}
type_info, exists := thread_ctx.Types[thread_type]
if exists == false {
return nil, fmt.Errorf("Tried to load thread type %s which is not in context", thread_type)
}
next_action, timeout_chan := SoonestAction(action_queue)
return &ThreadExt{
Actions: type_info.Actions,
Handlers: type_info.Handlers,
SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan,
Active: false,
State: state,
Parent: parent,
Children: children,
ActionQueue: action_queue,
NextAction: next_action,
}, nil
}
const THREAD_BUFFER_SIZE int = 1024
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
var j ThreadExtJSON
@ -75,26 +206,11 @@ func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
children[child_node.ID] = ChildInfo{
Child: child_node,
Infos: map[InfoType]ThreadInfo{},
}
Infos: map[InfoType]Info{},
}
next_action, timeout_chan := SoonestAction(j.ActionQueue)
extension := ThreadExt{
Actions: BaseThreadActions,
Handlers: BaseThreadHandlers,
SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan,
Active: false,
State: j.State,
Parent: parent,
Children: children,
ActionQueue: j.ActionQueue,
NextAction: next_action,
}
return &extension, nil
return NewThreadExt(ctx, ThreadType(j.Type), parent, children, j.State, j.ActionQueue)
}
const ThreadExtType = ExtType("THREAD")
@ -281,44 +397,14 @@ func LinkThreads(context *StateContext, principal *Node, thread *Node, info Chil
})
}
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
type InfoType string
func (t InfoType) String() string {
return string(t)
}
type ThreadInfo interface {
Serializable[InfoType]
}
// Data required by a parent thread to restore it's children
type ParentThreadInfo struct {
Start bool `json:"start"`
StartAction string `json:"start_action"`
RestoreAction string `json:"restore_action"`
}
const ParentThreadInfoType = InfoType("PARENT")
func (info *ParentThreadInfo) Type() InfoType {
return ParentThreadInfoType
}
func (info *ParentThreadInfo) Serialize() ([]byte, error) {
return json.MarshalIndent(info, "", " ")
}
type ChildInfo struct {
Child *Node
Infos map[InfoType]ThreadInfo
Infos map[InfoType]Info
}
func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
func NewChildInfo(child *Node, infos map[InfoType]Info) ChildInfo {
if infos == nil {
infos = map[InfoType]ThreadInfo{}
infos = map[InfoType]Info{}
}
return ChildInfo{
@ -327,48 +413,6 @@ func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
}
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{})
if ok == false {
return nil, fmt.Errorf("Failed to cast parent info to map")
}
start, ok := m["start"].(bool)
if ok == false {
return nil, fmt.Errorf("Failed to get start from parent info")
}
start_action, ok := m["start_action"].(string)
if ok == false {
return nil, fmt.Errorf("Failed to get start_action from parent info")
}
restore_action, ok := m["restore_action"].(string)
if ok == false {
return nil, fmt.Errorf("Failed to get restore_action from parent info")
}
return &ParentThreadInfo{
Start: start,
StartAction: start_action,
RestoreAction: restore_action,
}, nil
},
}
func NewThreadExt(buffer int, name string, state string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
return ThreadExt{
Actions: actions,
Handlers: handlers,
SignalChan: make(chan GraphSignal, buffer),
TimeoutChan: nil,
Active: false,
State: state,
Parent: nil,
Children: map[NodeID]ChildInfo{},
ActionQueue: []QueuedAction{},
NextAction: nil,
}
}
func (ext *ThreadExt) SetActive(active bool) error {
ext.ActiveLock.Lock()
defer ext.ActiveLock.Unlock()
@ -485,7 +529,7 @@ func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal
ctx.Log.Logf("thread", "THREAD_NODE_LINKED: %s is not a child of %s", sig.ID)
return nil
}
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
parent_info, exists := info.Infos["parent"].(*ParentInfo)
if exists == false {
panic("ran ThreadChildLinked from a thread that doesn't require 'parent' child info. library party foul")
}
@ -520,7 +564,7 @@ func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal
}
return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error {
parent_info, exists := info.Infos["parent"].(*ParentThreadInfo)
parent_info, exists := info.Infos["parent"].(*ParentInfo)
if exists == false {
return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info")
}
@ -544,7 +588,7 @@ func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start boo
return err
}
parent_info := info.Infos["parent"].(*ParentThreadInfo)
parent_info := info.Infos["parent"].(*ParentInfo)
if parent_info.Start == true && child_ext.State != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true {

@ -88,6 +88,15 @@ func (ext *GroupExt) Serialize() ([]byte, error) {
}, "", " ")
}
func NewGroupExt(members NodeMap) *GroupExt {
if members == nil {
members = NodeMap{}
}
return &GroupExt{
Members: members,
}
}
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Members []string `json:"members"`
@ -103,10 +112,7 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
return nil, err
}
extension := GroupExt{
Members: members,
}
return &extension, nil
return NewGroupExt(members), nil
}
func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error {