diff --git a/gql.go b/gql.go index 5be7c83..497acd1 100644 --- a/gql.go +++ b/gql.go @@ -23,6 +23,7 @@ import ( "crypto/sha512" "crypto/rand" "crypto/x509" + "github.com/google/uuid" ) type AuthReqJSON struct { @@ -65,26 +66,26 @@ type AuthRespJSON struct { Signature []byte `json:"signature"` } -func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, error) { +func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { // Check if req.Time is within +- 1 second of now now := time.Now() earliest := now.Add(-1 * time.Second) latest := now.Add(1 * time.Second) // If req.Time is before the earliest acceptable time, or after the latest acceptible time if req.Time.Compare(earliest) == -1 { - return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time) + return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time) } else if req.Time.Compare(latest) == 1 { - return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) + return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) } x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey) if x == nil { - return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) + return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) } remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey) if err != nil { - return AuthRespJSON{}, nil, err + return AuthRespJSON{}, nil, nil, err } // Verify the signature @@ -92,23 +93,25 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, sig_data := append(req.ECDHPubkey, time_bytes...) sig_hash := sha512.Sum512(sig_data) + remote_key := &ecdsa.PublicKey{ + Curve: thread.Key.Curve, + X: x, + Y: y, + } + verified := ecdsa.VerifyASN1( - &ecdsa.PublicKey{ - Curve: thread.Key.Curve, - X: x, - Y: y, - }, + remote_key, sig_hash[:], req.Signature, ) if verified == false { - return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) + return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) } ec_key, err := thread.ECDH.GenerateKey(rand.Reader) if err != nil { - return AuthRespJSON{}, nil, err + return AuthRespJSON{}, nil, nil, err } ec_key_pub := ec_key.PublicKey().Bytes() @@ -120,35 +123,37 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:]) if err != nil { - return AuthRespJSON{}, nil, err + return AuthRespJSON{}, nil, nil, err } shared_secret, err := ec_key.ECDH(remote) if err != nil { - return AuthRespJSON{}, nil, err + return AuthRespJSON{}, nil, nil, err } return AuthRespJSON{ Granted: granted, ECDHPubkey: ec_key_pub, Signature: resp_sig, - }, shared_secret, nil + }, remote_key, shared_secret, nil } type AuthData struct { Granted time.Time - Pubkey ecdh.PublicKey - ECDHClient ecdh.PublicKey + Pubkey *ecdsa.PublicKey + Shared []byte } type AuthDataJSON struct { Granted time.Time `json:"granted"` - Pubkey []byte `json:"pbkey"` - ECDHClient []byte `json:"ecdh_client"` + Pubkey []byte `json:"pubkey"` + Shared []byte `json:"shared"` } -func HashKey(pub []byte) uint64 { - return 0 +func KeyID(pub *ecdsa.PublicKey) NodeID { + ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y) + str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3) + return NodeID(str) } func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { @@ -169,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt return } - resp, _, err := NewAuthRespJSON(server, req) + resp, remote_id, _, err := NewAuthRespJSON(server, req) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err) return @@ -192,13 +197,13 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str) - key_hash := HashKey(req.Pubkey) + key_hash := KeyID(remote_id) _, exists := server.AuthMap[key_hash] if exists { - // New user + ctx.Log.Logf("gql", "REFRESHING AUTH FOR %+s", req.Pubkey) } else { - // Existing user + ctx.Log.Logf("gql", "AUTHORIZING NEW USER %+s", req.Pubkey) } } @@ -578,7 +583,7 @@ type GQLThread struct { http_server *http.Server http_done *sync.WaitGroup Listen string - AuthMap map[uint64]AuthData + AuthMap map[NodeID]AuthData Key *ecdsa.PrivateKey ECDH ecdh.Curve } @@ -604,7 +609,7 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo type GQLThreadJSON struct { SimpleThreadJSON Listen string `json:"listen"` - AuthMap map[uint64]AuthData `json:"auth_map"` + AuthMap map[string]AuthDataJSON `json:"auth_map"` Key []byte `json:"key"` ECDH uint8 `json:"ecdh_curve"` } @@ -633,10 +638,19 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { panic(err) } + auth_map := map[string]AuthDataJSON{} + for id, data := range(thread.AuthMap) { + auth_map[id.String()] = AuthDataJSON{ + Granted: data.Granted, + Pubkey: elliptic.Marshal(data.Pubkey.Curve, data.Pubkey.X, data.Pubkey.Y), + Shared: thread.AuthMap[id].Shared, + } + } + return GQLThreadJSON{ SimpleThreadJSON: thread_json, Listen: thread.Listen, - AuthMap: thread.AuthMap, + AuthMap: auth_map, Key: ser_key, ECDH: ecdh_curve_ids[thread.ECDH], } @@ -660,7 +674,26 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e } thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key) - thread.AuthMap = j.AuthMap + thread.AuthMap = map[NodeID]AuthData{} + for id_str, auth_json := range(j.AuthMap) { + id, err := ParseID(id_str) + if err != nil { + return nil, err + } + x, y := elliptic.Unmarshal(key.Curve, auth_json.Pubkey) + if x == nil { + return nil, fmt.Errorf("Failed to load public key for curve %+v from %+v", key.Curve, auth_json.Pubkey) + } + thread.AuthMap[id] = AuthData{ + Granted: auth_json.Granted, + Pubkey: &ecdsa.PublicKey{ + Curve: key.Curve, + X: x, + Y: y, + }, + Shared: auth_json.Shared, + } + } nodes[id] = &thread err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) @@ -675,7 +708,7 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh return GQLThread{ SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers), Listen: listen, - AuthMap: map[uint64]AuthData{}, + AuthMap: map[NodeID]AuthData{}, http_done: &sync.WaitGroup{}, Key: key, ECDH: ecdh_curve, diff --git a/gql_graph.go b/gql_graph.go index 7ac78dd..538e71e 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -823,14 +823,19 @@ func GQLMutationSendUpdate() *graphql.Field { return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) } - id , ok := p.Args["id"].(string) + id_str, ok := p.Args["id"].(string) if ok == false { return nil, fmt.Errorf("Failed to cast arg id to string") } + id, err := ParseID(id_str) + if err != nil { + return nil, err + } + var node Node = nil - err := UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ - node = FindChild(ctx, server, NodeID(id), nodes) + err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ + node = FindChild(ctx, server, id, nodes) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } diff --git a/gql_test.go b/gql_test.go index a1e997c..84c3fd8 100644 --- a/gql_test.go +++ b/gql_test.go @@ -53,13 +53,13 @@ func TestGQLThread(t * testing.T) { } func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{}) + ctx := logTestContext(t, []string{"test"}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") l1 := &l1_r t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) t1 := &t1_r - update_channel := UpdateChannel(t1, 10, "test") + update_channel := UpdateChannel(t1, 10, NodeID{}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) @@ -86,7 +86,7 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") - if errors.Is(err, NewThreadAbortedError("")) { + if errors.Is(err, NewThreadAbortedError(NodeID{})) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) } else { fatalErr(t, err) @@ -97,8 +97,8 @@ func TestGQLDBLoad(t * testing.T) { err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error { ser1, err := gql.Serialize() ser2, err := t1.Serialize() - ctx.Log.Logf("thread", "\n%s\n\n", ser1) - ctx.Log.Logf("thread", "\n%s\n\n", ser2) + ctx.Log.Logf("test", "\n%s\n\n", ser1) + ctx.Log.Logf("test", "\n%s\n\n", ser2) return err }) @@ -112,7 +112,7 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "\n%s\n\n", ser) child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child - update_channel_2 = UpdateChannel(t1_loaded, 10, "test") + update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { ser, err := child.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) @@ -123,7 +123,7 @@ func TestGQLDBLoad(t * testing.T) { }) err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") - if errors.Is(err, NewThreadAbortedError("")) { + if errors.Is(err, NewThreadAbortedError(NodeID{})) { ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) } else { fatalErr(t, err) @@ -143,7 +143,7 @@ func TestGQLAuth(t * testing.T) { var update_channel chan GraphSignal err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error { - update_channel = UpdateChannel(gql_t, 10, "test") + update_channel = UpdateChannel(gql_t, 10, NodeID{}) return nil }) fatalErr(t, err) diff --git a/graph_test.go b/graph_test.go index 54b76dd..6c68a31 100644 --- a/graph_test.go +++ b/graph_test.go @@ -24,7 +24,7 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si if signal.Type() == signal_type { ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) if source == nil { - if signal.Source() == "" { + if signal.Source() == ZeroID { return signal } } else { diff --git a/lockable.go b/lockable.go index d1f156d..642aa6e 100644 --- a/lockable.go +++ b/lockable.go @@ -64,13 +64,12 @@ type SimpleLockableJSON struct { Owner *NodeID `json:"owner"` Dependencies []NodeID `json:"dependencies"` Requirements []NodeID `json:"requirements"` - LocksHeld map[NodeID]*NodeID `json:"locks_held"` + LocksHeld map[string]*NodeID `json:"locks_held"` } func (lockable * SimpleLockable) Serialize() ([]byte, error) { lockable_json := NewSimpleLockableJSON(lockable) return json.MarshalIndent(&lockable_json, "", " ") - } func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { @@ -90,13 +89,13 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { owner_id = &new_str } - locks_held := map[NodeID]*NodeID{} + locks_held := map[string]*NodeID{} for lockable_id, node := range(lockable.locks_held) { if node == nil { - locks_held[lockable_id] = nil + locks_held[lockable_id.String()] = nil } else { str := node.ID() - locks_held[lockable_id] = &str + locks_held[lockable_id.String()] = &str } } return SimpleLockableJSON{ @@ -590,7 +589,8 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO lockable.AddRequirement(req_l) } - for l_id, h_id := range(j.LocksHeld) { + for l_id_str, h_id := range(j.LocksHeld) { + l_id, err := ParseID(l_id_str) l, err := LoadNodeRecurse(ctx, l_id, nodes) if err != nil { return err diff --git a/lockable_test.go b/lockable_test.go index 1fb6726..2be1608 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -74,7 +74,7 @@ func TestLockableSelfLock(t * testing.T) { fatalErr(t, err) err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID("") + owner_id := NodeID{} if l1.owner != nil { owner_id = l1.owner.ID() } @@ -120,7 +120,7 @@ func TestLockableSelfLockTiered(t * testing.T) { fatalErr(t, err) err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) { - owner_1 := NodeID("") + owner_1 := NodeID{} if l1.owner != nil { owner_1 = l1.owner.ID() } @@ -128,7 +128,7 @@ func TestLockableSelfLockTiered(t * testing.T) { return fmt.Errorf("l1 is owned by %s instead of l3", owner_1) } - owner_2 := NodeID("") + owner_2 := NodeID{} if l2.owner != nil { owner_2 = l2.owner.ID() } @@ -181,7 +181,7 @@ func TestLockableLockOther(t * testing.T) { fatalErr(t, err) err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID("") + owner_id := NodeID{} if l1.owner != nil { owner_id = l1.owner.ID() } @@ -236,7 +236,7 @@ func TestLockableLockSimpleConflict(t * testing.T) { fatalErr(t, err) err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { - owner_id := NodeID("") + owner_id := NodeID{} if l1.owner != nil { owner_id = l1.owner.ID() } @@ -304,7 +304,7 @@ func TestLockableSimpleUpdate(t * testing.T) { l1 := &l1_r - update_channel := UpdateChannel(l1, 1, "test") + update_channel := UpdateChannel(l1, 1, NodeID{}) go func() { UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { @@ -333,7 +333,7 @@ func TestLockableDownUpdate(t * testing.T) { }) fatalErr(t, err) - update_channel := UpdateChannel(l1, 1, "test") + update_channel := UpdateChannel(l1, 1, NodeID{}) go func() { UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { @@ -362,7 +362,7 @@ func TestLockableUpUpdate(t * testing.T) { }) fatalErr(t, err) - update_channel := UpdateChannel(l3, 1, "test") + update_channel := UpdateChannel(l3, 1, NodeID{}) go func() { UseStates(ctx, []Node{l2}, func(nodes NodeMap) error { @@ -390,7 +390,7 @@ func TestOwnerNotUpdatedTwice(t * testing.T) { }) fatalErr(t, err) - update_channel := UpdateChannel(l2, 1, "test") + update_channel := UpdateChannel(l2, 1, NodeID{}) go func() { err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error { diff --git a/node.go b/node.go index 023acb8..af7f9ee 100644 --- a/node.go +++ b/node.go @@ -10,9 +10,26 @@ import ( ) // IDs are how nodes are uniquely identified, and can be serialized for the database -type NodeID string +type NodeID uuid.UUID + +var ZeroUUID = uuid.UUID{} +var ZeroID = NodeID(ZeroUUID) + func (id NodeID) Serialize() []byte { - return []byte(id) + ser, _ := (uuid.UUID)(id).MarshalBinary() + return ser +} + +func (id NodeID) String() string { + return (uuid.UUID)(id).String() +} + +func ParseID(str string) (NodeID, error) { + id_uuid, err := uuid.Parse(str) + if err != nil { + return NodeID{}, err + } + return NodeID(id_uuid), nil } // Types are how nodes are associated with structs at runtime(and from the DB) @@ -27,8 +44,7 @@ func (node_type NodeType) Hash() uint64 { // Generate a random NodeID func RandID() NodeID { - uuid_str := uuid.New().String() - return NodeID(uuid_str) + return NodeID(uuid.New()) } // A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes) @@ -169,7 +185,7 @@ func getNodeBytes(node Node) ([]byte, error) { } ser, err := node.Serialize() if err != nil { - return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %e", err) + return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %s", err) } header := NewDBHeader(node.Type()) diff --git a/signal.go b/signal.go index 73a0b45..811121c 100644 --- a/signal.go +++ b/signal.go @@ -48,7 +48,7 @@ func (signal BaseSignal) Type() string { } func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal { - var source_id NodeID = "nil" + var source_id NodeID = NodeID{} if source != nil { source_id = source.ID() } diff --git a/thread.go b/thread.go index d369686..b2a207b 100644 --- a/thread.go +++ b/thread.go @@ -7,6 +7,7 @@ import ( "errors" "reflect" "encoding/json" + "github.com/google/uuid" ) // SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel @@ -305,7 +306,7 @@ func (thread * SimpleThread) SignalChannel() <-chan GraphSignal { type SimpleThreadJSON struct { Parent *NodeID `json:"parent"` - Children map[NodeID]interface{} `json:"children"` + Children map[string]interface{} `json:"children"` Timeout time.Time `json:"timeout"` TimeoutAction string `json:"timeout_action"` StateName string `json:"state_name"` @@ -313,9 +314,9 @@ type SimpleThreadJSON struct { } func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON { - children := map[NodeID]interface{}{} + children := map[string]interface{}{} for _, child := range(thread.children) { - children[child.ID()] = thread.child_info[child.ID()] + children[child.ID().String()] = thread.child_info[child.ID()] } var parent_id *NodeID = nil @@ -379,7 +380,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes thread.SetParent(p_t) } - for id, info_raw := range(j.Children) { + for id_str, info_raw := range(j.Children) { + id, err := ParseID(id_str) + if err != nil { + return err + } child_node, err := LoadNodeRecurse(ctx, id, nodes) if err != nil { return err @@ -572,7 +577,7 @@ var ThreadRestore = func(ctx * Context, thread Thread) { var ThreadStart = func(ctx * Context, thread Thread) error { return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { - owner_id := NodeID("") + owner_id := NodeID{} if thread.Owner() != nil { owner_id = thread.Owner().ID() } @@ -636,12 +641,12 @@ var ThreadWait = func(ctx * Context, thread Thread) (string, error) { type ThreadAbortedError NodeID func (e ThreadAbortedError) Is(target error) bool { - error_type := reflect.TypeOf(ThreadAbortedError("")) + error_type := reflect.TypeOf(ThreadAbortedError(NodeID{})) target_type := reflect.TypeOf(target) return error_type == target_type } func (e ThreadAbortedError) Error() string { - return fmt.Sprintf("Aborted by %s", string(e)) + return fmt.Sprintf("Aborted by %s", (uuid.UUID)(e).String()) } func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { return ThreadAbortedError(aborter)