Changed NodeID from string to UUID, filled in more auth

graph-rework-2
noah metz 2023-07-19 20:03:13 -06:00
parent 6d0925f20f
commit 374fd6e487
9 changed files with 129 additions and 70 deletions

@ -23,6 +23,7 @@ import (
"crypto/sha512" "crypto/sha512"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"github.com/google/uuid"
) )
type AuthReqJSON struct { type AuthReqJSON struct {
@ -65,26 +66,26 @@ type AuthRespJSON struct {
Signature []byte `json:"signature"` Signature []byte `json:"signature"`
} }
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, error) { func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
// Check if req.Time is within +- 1 second of now // Check if req.Time is within +- 1 second of now
now := time.Now() now := time.Now()
earliest := now.Add(-1 * time.Second) earliest := now.Add(-1 * time.Second)
latest := now.Add(1 * time.Second) latest := now.Add(1 * time.Second)
// If req.Time is before the earliest acceptable time, or after the latest acceptible time // If req.Time is before the earliest acceptable time, or after the latest acceptible time
if req.Time.Compare(earliest) == -1 { if req.Time.Compare(earliest) == -1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time)
} else if req.Time.Compare(latest) == 1 { } else if req.Time.Compare(latest) == 1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
} }
x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey) x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey)
if x == nil { if x == nil {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
} }
remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey) remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil { if err != nil {
return AuthRespJSON{}, nil, err return AuthRespJSON{}, nil, nil, err
} }
// Verify the signature // Verify the signature
@ -92,23 +93,25 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte,
sig_data := append(req.ECDHPubkey, time_bytes...) sig_data := append(req.ECDHPubkey, time_bytes...)
sig_hash := sha512.Sum512(sig_data) sig_hash := sha512.Sum512(sig_data)
verified := ecdsa.VerifyASN1( remote_key := &ecdsa.PublicKey{
&ecdsa.PublicKey{
Curve: thread.Key.Curve, Curve: thread.Key.Curve,
X: x, X: x,
Y: y, Y: y,
}, }
verified := ecdsa.VerifyASN1(
remote_key,
sig_hash[:], sig_hash[:],
req.Signature, req.Signature,
) )
if verified == false { if verified == false {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
} }
ec_key, err := thread.ECDH.GenerateKey(rand.Reader) ec_key, err := thread.ECDH.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return AuthRespJSON{}, nil, err return AuthRespJSON{}, nil, nil, err
} }
ec_key_pub := ec_key.PublicKey().Bytes() ec_key_pub := ec_key.PublicKey().Bytes()
@ -120,35 +123,37 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte,
resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:]) resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:])
if err != nil { if err != nil {
return AuthRespJSON{}, nil, err return AuthRespJSON{}, nil, nil, err
} }
shared_secret, err := ec_key.ECDH(remote) shared_secret, err := ec_key.ECDH(remote)
if err != nil { if err != nil {
return AuthRespJSON{}, nil, err return AuthRespJSON{}, nil, nil, err
} }
return AuthRespJSON{ return AuthRespJSON{
Granted: granted, Granted: granted,
ECDHPubkey: ec_key_pub, ECDHPubkey: ec_key_pub,
Signature: resp_sig, Signature: resp_sig,
}, shared_secret, nil }, remote_key, shared_secret, nil
} }
type AuthData struct { type AuthData struct {
Granted time.Time Granted time.Time
Pubkey ecdh.PublicKey Pubkey *ecdsa.PublicKey
ECDHClient ecdh.PublicKey Shared []byte
} }
type AuthDataJSON struct { type AuthDataJSON struct {
Granted time.Time `json:"granted"` Granted time.Time `json:"granted"`
Pubkey []byte `json:"pbkey"` Pubkey []byte `json:"pubkey"`
ECDHClient []byte `json:"ecdh_client"` Shared []byte `json:"shared"`
} }
func HashKey(pub []byte) uint64 { func KeyID(pub *ecdsa.PublicKey) NodeID {
return 0 ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3)
return NodeID(str)
} }
func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) {
@ -169,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return return
} }
resp, _, err := NewAuthRespJSON(server, req) resp, remote_id, _, err := NewAuthRespJSON(server, req)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err) ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err)
return return
@ -192,13 +197,13 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str) ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str)
key_hash := HashKey(req.Pubkey) key_hash := KeyID(remote_id)
_, exists := server.AuthMap[key_hash] _, exists := server.AuthMap[key_hash]
if exists { if exists {
// New user ctx.Log.Logf("gql", "REFRESHING AUTH FOR %+s", req.Pubkey)
} else { } else {
// Existing user ctx.Log.Logf("gql", "AUTHORIZING NEW USER %+s", req.Pubkey)
} }
} }
@ -578,7 +583,7 @@ type GQLThread struct {
http_server *http.Server http_server *http.Server
http_done *sync.WaitGroup http_done *sync.WaitGroup
Listen string Listen string
AuthMap map[uint64]AuthData AuthMap map[NodeID]AuthData
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
ECDH ecdh.Curve ECDH ecdh.Curve
} }
@ -604,7 +609,7 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo
type GQLThreadJSON struct { type GQLThreadJSON struct {
SimpleThreadJSON SimpleThreadJSON
Listen string `json:"listen"` Listen string `json:"listen"`
AuthMap map[uint64]AuthData `json:"auth_map"` AuthMap map[string]AuthDataJSON `json:"auth_map"`
Key []byte `json:"key"` Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"` ECDH uint8 `json:"ecdh_curve"`
} }
@ -633,10 +638,19 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
panic(err) panic(err)
} }
auth_map := map[string]AuthDataJSON{}
for id, data := range(thread.AuthMap) {
auth_map[id.String()] = AuthDataJSON{
Granted: data.Granted,
Pubkey: elliptic.Marshal(data.Pubkey.Curve, data.Pubkey.X, data.Pubkey.Y),
Shared: thread.AuthMap[id].Shared,
}
}
return GQLThreadJSON{ return GQLThreadJSON{
SimpleThreadJSON: thread_json, SimpleThreadJSON: thread_json,
Listen: thread.Listen, Listen: thread.Listen,
AuthMap: thread.AuthMap, AuthMap: auth_map,
Key: ser_key, Key: ser_key,
ECDH: ecdh_curve_ids[thread.ECDH], ECDH: ecdh_curve_ids[thread.ECDH],
} }
@ -660,7 +674,26 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
} }
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key)
thread.AuthMap = j.AuthMap thread.AuthMap = map[NodeID]AuthData{}
for id_str, auth_json := range(j.AuthMap) {
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
x, y := elliptic.Unmarshal(key.Curve, auth_json.Pubkey)
if x == nil {
return nil, fmt.Errorf("Failed to load public key for curve %+v from %+v", key.Curve, auth_json.Pubkey)
}
thread.AuthMap[id] = AuthData{
Granted: auth_json.Granted,
Pubkey: &ecdsa.PublicKey{
Curve: key.Curve,
X: x,
Y: y,
},
Shared: auth_json.Shared,
}
}
nodes[id] = &thread nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes)
@ -675,7 +708,7 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
return GQLThread{ return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers),
Listen: listen, Listen: listen,
AuthMap: map[uint64]AuthData{}, AuthMap: map[NodeID]AuthData{},
http_done: &sync.WaitGroup{}, http_done: &sync.WaitGroup{},
Key: key, Key: key,
ECDH: ecdh_curve, ECDH: ecdh_curve,

@ -823,14 +823,19 @@ func GQLMutationSendUpdate() *graphql.Field {
return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"])
} }
id , ok := p.Args["id"].(string) id_str, ok := p.Args["id"].(string)
if ok == false { if ok == false {
return nil, fmt.Errorf("Failed to cast arg id to string") return nil, fmt.Errorf("Failed to cast arg id to string")
} }
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
var node Node = nil var node Node = nil
err := UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){
node = FindChild(ctx, server, NodeID(id), nodes) node = FindChild(ctx, server, id, nodes)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }

@ -53,13 +53,13 @@ func TestGQLThread(t * testing.T) {
} }
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{"test"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r l1 := &l1_r
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r t1 := &t1_r
update_channel := UpdateChannel(t1, 10, "test") update_channel := UpdateChannel(t1, 10, NodeID{})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
@ -86,7 +86,7 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start") err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, NewThreadAbortedError("")) { if errors.Is(err, NewThreadAbortedError(NodeID{})) {
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) ctx.Log.Logf("test", "Main thread aborted by signal: %s", err)
} else { } else {
fatalErr(t, err) fatalErr(t, err)
@ -97,8 +97,8 @@ func TestGQLDBLoad(t * testing.T) {
err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error { err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error {
ser1, err := gql.Serialize() ser1, err := gql.Serialize()
ser2, err := t1.Serialize() ser2, err := t1.Serialize()
ctx.Log.Logf("thread", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser1)
ctx.Log.Logf("thread", "\n%s\n\n", ser2) ctx.Log.Logf("test", "\n%s\n\n", ser2)
return err return err
}) })
@ -112,7 +112,7 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
child := gql_loaded.(Thread).Children()[0].(*SimpleThread) child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, "test") update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{})
err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error {
ser, err := child.Serialize() ser, err := child.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser) ctx.Log.Logf("test", "\n%s\n\n", ser)
@ -123,7 +123,7 @@ func TestGQLDBLoad(t * testing.T) {
}) })
err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") err = ThreadLoop(ctx, gql_loaded.(Thread), "restore")
if errors.Is(err, NewThreadAbortedError("")) { if errors.Is(err, NewThreadAbortedError(NodeID{})) {
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) ctx.Log.Logf("test", "Main thread aborted by signal: %s", err)
} else { } else {
fatalErr(t, err) fatalErr(t, err)
@ -143,7 +143,7 @@ func TestGQLAuth(t * testing.T) {
var update_channel chan GraphSignal var update_channel chan GraphSignal
err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error {
update_channel = UpdateChannel(gql_t, 10, "test") update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil return nil
}) })
fatalErr(t, err) fatalErr(t, err)

@ -24,7 +24,7 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si
if signal.Type() == signal_type { if signal.Type() == signal_type {
ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener)
if source == nil { if source == nil {
if signal.Source() == "" { if signal.Source() == ZeroID {
return signal return signal
} }
} else { } else {

@ -64,13 +64,12 @@ type SimpleLockableJSON struct {
Owner *NodeID `json:"owner"` Owner *NodeID `json:"owner"`
Dependencies []NodeID `json:"dependencies"` Dependencies []NodeID `json:"dependencies"`
Requirements []NodeID `json:"requirements"` Requirements []NodeID `json:"requirements"`
LocksHeld map[NodeID]*NodeID `json:"locks_held"` LocksHeld map[string]*NodeID `json:"locks_held"`
} }
func (lockable * SimpleLockable) Serialize() ([]byte, error) { func (lockable * SimpleLockable) Serialize() ([]byte, error) {
lockable_json := NewSimpleLockableJSON(lockable) lockable_json := NewSimpleLockableJSON(lockable)
return json.MarshalIndent(&lockable_json, "", " ") return json.MarshalIndent(&lockable_json, "", " ")
} }
func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
@ -90,13 +89,13 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
owner_id = &new_str owner_id = &new_str
} }
locks_held := map[NodeID]*NodeID{} locks_held := map[string]*NodeID{}
for lockable_id, node := range(lockable.locks_held) { for lockable_id, node := range(lockable.locks_held) {
if node == nil { if node == nil {
locks_held[lockable_id] = nil locks_held[lockable_id.String()] = nil
} else { } else {
str := node.ID() str := node.ID()
locks_held[lockable_id] = &str locks_held[lockable_id.String()] = &str
} }
} }
return SimpleLockableJSON{ return SimpleLockableJSON{
@ -590,7 +589,8 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
lockable.AddRequirement(req_l) lockable.AddRequirement(req_l)
} }
for l_id, h_id := range(j.LocksHeld) { for l_id_str, h_id := range(j.LocksHeld) {
l_id, err := ParseID(l_id_str)
l, err := LoadNodeRecurse(ctx, l_id, nodes) l, err := LoadNodeRecurse(ctx, l_id, nodes)
if err != nil { if err != nil {
return err return err

@ -74,7 +74,7 @@ func TestLockableSelfLock(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("") owner_id := NodeID{}
if l1.owner != nil { if l1.owner != nil {
owner_id = l1.owner.ID() owner_id = l1.owner.ID()
} }
@ -120,7 +120,7 @@ func TestLockableSelfLockTiered(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) {
owner_1 := NodeID("") owner_1 := NodeID{}
if l1.owner != nil { if l1.owner != nil {
owner_1 = l1.owner.ID() owner_1 = l1.owner.ID()
} }
@ -128,7 +128,7 @@ func TestLockableSelfLockTiered(t * testing.T) {
return fmt.Errorf("l1 is owned by %s instead of l3", owner_1) return fmt.Errorf("l1 is owned by %s instead of l3", owner_1)
} }
owner_2 := NodeID("") owner_2 := NodeID{}
if l2.owner != nil { if l2.owner != nil {
owner_2 = l2.owner.ID() owner_2 = l2.owner.ID()
} }
@ -181,7 +181,7 @@ func TestLockableLockOther(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("") owner_id := NodeID{}
if l1.owner != nil { if l1.owner != nil {
owner_id = l1.owner.ID() owner_id = l1.owner.ID()
} }
@ -236,7 +236,7 @@ func TestLockableLockSimpleConflict(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("") owner_id := NodeID{}
if l1.owner != nil { if l1.owner != nil {
owner_id = l1.owner.ID() owner_id = l1.owner.ID()
} }
@ -304,7 +304,7 @@ func TestLockableSimpleUpdate(t * testing.T) {
l1 := &l1_r l1 := &l1_r
update_channel := UpdateChannel(l1, 1, "test") update_channel := UpdateChannel(l1, 1, NodeID{})
go func() { go func() {
UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { UseStates(ctx, []Node{l1}, func(nodes NodeMap) error {
@ -333,7 +333,7 @@ func TestLockableDownUpdate(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
update_channel := UpdateChannel(l1, 1, "test") update_channel := UpdateChannel(l1, 1, NodeID{})
go func() { go func() {
UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { UseStates(ctx, []Node{l2}, func(nodes NodeMap) error {
@ -362,7 +362,7 @@ func TestLockableUpUpdate(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
update_channel := UpdateChannel(l3, 1, "test") update_channel := UpdateChannel(l3, 1, NodeID{})
go func() { go func() {
UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { UseStates(ctx, []Node{l2}, func(nodes NodeMap) error {
@ -390,7 +390,7 @@ func TestOwnerNotUpdatedTwice(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
update_channel := UpdateChannel(l2, 1, "test") update_channel := UpdateChannel(l2, 1, NodeID{})
go func() { go func() {
err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error {

@ -10,9 +10,26 @@ import (
) )
// IDs are how nodes are uniquely identified, and can be serialized for the database // IDs are how nodes are uniquely identified, and can be serialized for the database
type NodeID string type NodeID uuid.UUID
var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID)
func (id NodeID) Serialize() []byte { func (id NodeID) Serialize() []byte {
return []byte(id) ser, _ := (uuid.UUID)(id).MarshalBinary()
return ser
}
func (id NodeID) String() string {
return (uuid.UUID)(id).String()
}
func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str)
if err != nil {
return NodeID{}, err
}
return NodeID(id_uuid), nil
} }
// Types are how nodes are associated with structs at runtime(and from the DB) // Types are how nodes are associated with structs at runtime(and from the DB)
@ -27,8 +44,7 @@ func (node_type NodeType) Hash() uint64 {
// Generate a random NodeID // Generate a random NodeID
func RandID() NodeID { func RandID() NodeID {
uuid_str := uuid.New().String() return NodeID(uuid.New())
return NodeID(uuid_str)
} }
// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) // A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes)
@ -169,7 +185,7 @@ func getNodeBytes(node Node) ([]byte, error) {
} }
ser, err := node.Serialize() ser, err := node.Serialize()
if err != nil { if err != nil {
return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %e", err) return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
} }
header := NewDBHeader(node.Type()) header := NewDBHeader(node.Type())

@ -48,7 +48,7 @@ func (signal BaseSignal) Type() string {
} }
func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal { func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal {
var source_id NodeID = "nil" var source_id NodeID = NodeID{}
if source != nil { if source != nil {
source_id = source.ID() source_id = source.ID()
} }

@ -7,6 +7,7 @@ import (
"errors" "errors"
"reflect" "reflect"
"encoding/json" "encoding/json"
"github.com/google/uuid"
) )
// SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel // SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel
@ -305,7 +306,7 @@ func (thread * SimpleThread) SignalChannel() <-chan GraphSignal {
type SimpleThreadJSON struct { type SimpleThreadJSON struct {
Parent *NodeID `json:"parent"` Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"` Children map[string]interface{} `json:"children"`
Timeout time.Time `json:"timeout"` Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"` TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"` StateName string `json:"state_name"`
@ -313,9 +314,9 @@ type SimpleThreadJSON struct {
} }
func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON { func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON {
children := map[NodeID]interface{}{} children := map[string]interface{}{}
for _, child := range(thread.children) { for _, child := range(thread.children) {
children[child.ID()] = thread.child_info[child.ID()] children[child.ID().String()] = thread.child_info[child.ID()]
} }
var parent_id *NodeID = nil var parent_id *NodeID = nil
@ -379,7 +380,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
thread.SetParent(p_t) thread.SetParent(p_t)
} }
for id, info_raw := range(j.Children) { for id_str, info_raw := range(j.Children) {
id, err := ParseID(id_str)
if err != nil {
return err
}
child_node, err := LoadNodeRecurse(ctx, id, nodes) child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil { if err != nil {
return err return err
@ -572,7 +577,7 @@ var ThreadRestore = func(ctx * Context, thread Thread) {
var ThreadStart = func(ctx * Context, thread Thread) error { var ThreadStart = func(ctx * Context, thread Thread) error {
return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error {
owner_id := NodeID("") owner_id := NodeID{}
if thread.Owner() != nil { if thread.Owner() != nil {
owner_id = thread.Owner().ID() owner_id = thread.Owner().ID()
} }
@ -636,12 +641,12 @@ var ThreadWait = func(ctx * Context, thread Thread) (string, error) {
type ThreadAbortedError NodeID type ThreadAbortedError NodeID
func (e ThreadAbortedError) Is(target error) bool { func (e ThreadAbortedError) Is(target error) bool {
error_type := reflect.TypeOf(ThreadAbortedError("")) error_type := reflect.TypeOf(ThreadAbortedError(NodeID{}))
target_type := reflect.TypeOf(target) target_type := reflect.TypeOf(target)
return error_type == target_type return error_type == target_type
} }
func (e ThreadAbortedError) Error() string { func (e ThreadAbortedError) Error() string {
return fmt.Sprintf("Aborted by %s", string(e)) return fmt.Sprintf("Aborted by %s", (uuid.UUID)(e).String())
} }
func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { func NewThreadAbortedError(aborter NodeID) ThreadAbortedError {
return ThreadAbortedError(aborter) return ThreadAbortedError(aborter)