diff --git a/context.go b/context.go index 9b3eeee..141d0df 100644 --- a/context.go +++ b/context.go @@ -7,6 +7,7 @@ import ( "errors" "runtime" "crypto/sha512" + "crypto/elliptic" "encoding/binary" ) @@ -77,6 +78,8 @@ type Context struct { Extensions map[uint64]ExtensionInfo // Map between database type hashes and the registered info Types map[uint64]*NodeInfo + // Curve used for signature operations + ECDSA elliptic.Curve // Routing map to all the nodes local to this context NodesLock sync.RWMutex Nodes map[NodeID]*Node @@ -194,6 +197,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Extensions: map[uint64]ExtensionInfo{}, Types: map[uint64]*NodeInfo{}, Nodes: map[NodeID]*Node{}, + ECDSA: elliptic.P256(), } var err error diff --git a/gql_test.go b/gql_test.go index 6eec99e..a18b1ba 100644 --- a/gql_test.go +++ b/gql_test.go @@ -15,7 +15,7 @@ func TestGQLDB(t * testing.T) { TestUserNodeType := NodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) fatalErr(t, err) - u1 := NewNode(ctx, RandID(), TestUserNodeType, 10, nil) + u1 := NewNode(ctx, nil, TestUserNodeType, 10, nil) ctx.Log.Logf("test", "U1_ID: %s", u1.ID) @@ -24,7 +24,7 @@ func TestGQLDB(t * testing.T) { gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil) listener_ext := NewListenerExt(10) - gql := NewNode(ctx, RandID(), GQLNodeType, 10, nil, + gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql_ext, listener_ext, NewACLExt(), diff --git a/graph_test.go b/graph_test.go index 29d5d35..973a694 100644 --- a/graph_test.go +++ b/graph_test.go @@ -113,7 +113,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { policy := NewAllNodesPolicy(Actions{MakeAction("status")}) listener_extension := NewListenerExt(buffer) listener := NewNode(ctx, - RandID(), + nil, SimpleListenerNodeType, 10, nil, diff --git a/lockable_test.go b/lockable_test.go index b683409..f773d34 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -25,13 +25,13 @@ func TestLink(t *testing.T) { ctx := lockableTestContext(t, []string{}) l1_listener := NewListenerExt(10) - l1 := NewNode(ctx, RandID(), TestLockableType, 10, nil, + l1 := NewNode(ctx, nil, TestLockableType, 10, nil, l1_listener, NewACLExt(link_policy), NewLockableExt(), ) l2_listener := NewListenerExt(10) - l2 := NewNode(ctx, RandID(), TestLockableType, 10, nil, + l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2_listener, NewACLExt(link_policy), NewLockableExt(), @@ -55,7 +55,7 @@ func TestLink10K(t *testing.T) { ctx := lockableTestContext(t, []string{"test"}) NewLockable := func()(*Node) { - l := NewNode(ctx, RandID(), TestLockableType, 10, nil, + l := NewNode(ctx, nil, TestLockableType, 10, nil, NewACLExt(lock_policy, link_policy), NewLockableExt(), ) @@ -64,7 +64,7 @@ func TestLink10K(t *testing.T) { NewListener := func()(*Node, *ListenerExt) { listener := NewListenerExt(100000) - l := NewNode(ctx, RandID(), TestLockableType, 256, nil, + l := NewNode(ctx, nil, TestLockableType, 256, nil, listener, NewACLExt(lock_policy, link_policy), NewLockableExt(), @@ -94,7 +94,7 @@ func TestLock(t *testing.T) { NewLockable := func()(*Node, *ListenerExt) { listener := NewListenerExt(100) - l := NewNode(ctx, RandID(), TestLockableType, 10, nil, + l := NewNode(ctx, nil, TestLockableType, 10, nil, listener, NewACLExt(lock_policy, link_policy), NewLockableExt(), diff --git a/node.go b/node.go index 319b255..10e2246 100644 --- a/node.go +++ b/node.go @@ -10,13 +10,18 @@ import ( "encoding/binary" "encoding/json" "sync/atomic" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/sha512" + "crypto/rand" + "crypto/x509" ) const ( // Magic first four bytes of serialized DB content, stored big endian NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function - NODE_DB_HEADER_LEN = 24 + NODE_DB_HEADER_LEN = 28 EXTENSION_DB_HEADER_LEN = 16 ) @@ -32,6 +37,7 @@ func (id NodeID) MarshalJSON() ([]byte, error) { str := id.String() return json.Marshal(&str) } + func (id *NodeID) UnmarshalJSON(bytes []byte) error { var id_str string err := json.Unmarshal(bytes, &id_str) @@ -95,6 +101,7 @@ type QueuedSignal struct { // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { + Key *ecdsa.PrivateKey ID NodeID Type NodeType Extensions map[ExtType]Extension @@ -268,16 +275,24 @@ func GetExt[T Extension](node *Node) (T, error) { func (node *Node) Serialize() ([]byte, error) { extensions := make([]ExtensionDB, len(node.Extensions)) + + key_bytes, err := x509.MarshalECPrivateKey(node.Key) + if err != nil { + return nil, err + } + node_db := NodeDB{ Header: NodeDBHeader{ Magic: NODE_DB_MAGIC, TypeHash: Hash(node.Type), + KeyLength: uint32(len(key_bytes)), BufferSize: node.BufferSize, NumExtensions: uint32(len(extensions)), NumQueuedSignals: uint32(len(node.SignalQueue)), }, Extensions: extensions, QueuedSignals: node.SignalQueue, + KeyBytes: key_bytes, } i := 0 @@ -299,8 +314,23 @@ func (node *Node) Serialize() ([]byte, error) { return node_db.Serialize(), nil } +func KeyID(pub *ecdsa.PublicKey) NodeID { + ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y) + str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3) + return NodeID(str) +} + // Create a new node in memory and start it's event loop -func NewNode(ctx *Context, id NodeID, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { +// TODO: Change panics to errors +func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { + var err error + if key == nil { + key, err = ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + if err != nil { + panic(err) + } + } + id := KeyID(&key.PublicKey) _, exists := ctx.Node(id) if exists == true { panic("Attempted to create an existing node") @@ -334,6 +364,7 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, buffer_size uint32, qu next_signal, timeout_chan := SoonestSignal(queued_signals) node := &Node{ + Key: key, ID: id, Type: node_type, Extensions: ext_map, @@ -344,7 +375,7 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType, buffer_size uint32, qu NextSignal: next_signal, } ctx.AddNode(id, node) - err := WriteNode(ctx, node) + err = WriteNode(ctx, node) if err != nil { panic(err) } @@ -384,6 +415,7 @@ type NodeDBHeader struct { NumExtensions uint32 NumQueuedSignals uint32 BufferSize uint32 + KeyLength uint32 TypeHash uint64 } @@ -391,6 +423,7 @@ type NodeDB struct { Header NodeDBHeader QueuedSignals []QueuedSignal Extensions []ExtensionDB + KeyBytes []byte } //TODO: add size safety checks @@ -403,7 +436,8 @@ func NewNodeDB(data []byte) (NodeDB, error) { num_extensions := binary.BigEndian.Uint32(data[4:8]) num_queued_signals := binary.BigEndian.Uint32(data[8:12]) buffer_size := binary.BigEndian.Uint32(data[12:16]) - node_type_hash := binary.BigEndian.Uint64(data[16:24]) + key_length := binary.BigEndian.Uint32(data[16:20]) + node_type_hash := binary.BigEndian.Uint64(data[20:28]) ptr += NODE_DB_HEADER_LEN @@ -411,6 +445,14 @@ func NewNodeDB(data []byte) (NodeDB, error) { return zero, fmt.Errorf("header has incorrect magic 0x%x", magic) } + key_bytes := make([]byte, key_length) + n := copy(key_bytes, data[ptr:(ptr+int(key_length))]) + if n != int(key_length) { + return zero, fmt.Errorf("not enough key bytes: %d", n) + } + + ptr += int(key_length) + extensions := make([]ExtensionDB, num_extensions) for i, _ := range(extensions) { cur := data[ptr:] @@ -443,9 +485,11 @@ func NewNodeDB(data []byte) (NodeDB, error) { Magic: magic, TypeHash: node_type_hash, BufferSize: buffer_size, + KeyLength: key_length, NumExtensions: num_extensions, NumQueuedSignals: num_queued_signals, }, + KeyBytes: key_bytes, Extensions: extensions, QueuedSignals: queued_signals, }, nil @@ -461,12 +505,14 @@ func (header NodeDBHeader) Serialize() []byte { binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals) binary.BigEndian.PutUint32(ret[12:16], header.BufferSize) - binary.BigEndian.PutUint64(ret[16:24], header.TypeHash) + binary.BigEndian.PutUint32(ret[16:20], header.KeyLength) + binary.BigEndian.PutUint64(ret[20:28], header.TypeHash) return ret } func (node NodeDB) Serialize() []byte { ser := node.Header.Serialize() + ser = append(ser, node.KeyBytes...) for _, extension := range(node.Extensions) { ser = append(ser, extension.Serialize()...) } @@ -541,6 +587,20 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } + key, err := x509.ParseECPrivateKey(node_db.KeyBytes) + if err != nil { + return nil, err + } + + if key.PublicKey.Curve != ctx.ECDSA { + return nil, fmt.Errorf("%s - wrong ec curve for private key: %+v, expected %+v", id, key.PublicKey.Curve, ctx.ECDSA) + } + + key_id := KeyID(&key.PublicKey) + if key_id != id { + return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id) + } + node_type, known := ctx.Types[node_db.Header.TypeHash] if known == false { return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) @@ -548,7 +608,8 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals) node := &Node{ - ID: id, + Key: key, + ID: key_id, Type: node_type.Type, Extensions: map[ExtType]Extension{}, MsgChan: make(chan Msg, node_db.Header.BufferSize), diff --git a/node_test.go b/node_test.go index 3a5019f..a1f0ada 100644 --- a/node_test.go +++ b/node_test.go @@ -3,6 +3,8 @@ package graphvent import ( "testing" "time" + "crypto/rand" + "crypto/ecdsa" ) func TestNodeDB(t *testing.T) { @@ -11,7 +13,7 @@ func TestNodeDB(t *testing.T) { err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) - node := NewNode(ctx, RandID(), node_type, 10, nil, NewGroupExt(nil)) + node := NewNode(ctx, nil, node_type, 10, nil, NewGroupExt(nil)) ctx.Nodes = NodeMap{} _, err = ctx.GetNode(node.ID) @@ -24,8 +26,13 @@ func TestNodeRead(t *testing.T) { err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType}) fatalErr(t, err) - n1_id := RandID() - n2_id := RandID() + n1_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + fatalErr(t, err) + n2_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + fatalErr(t, err) + + n1_id := KeyID(&n1_key.PublicKey) + n2_id := KeyID(&n2_key.PublicKey) ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N2: %s", n2_id) @@ -34,12 +41,12 @@ func TestNodeRead(t *testing.T) { n1_id: Actions{MakeAction(ReadResultSignalType, "+")}, }) n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_id, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), n2_listener) n1_policy := NewPerNodePolicy(map[NodeID]Actions{ n2_id: Actions{MakeAction(ReadSignalType, "+")}, }) - n1 := NewNode(ctx, n1_id, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil)) + n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil)) ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, diff --git a/signal.go b/signal.go index 3353651..a1ba3fb 100644 --- a/signal.go +++ b/signal.go @@ -29,6 +29,7 @@ func (signal_type SignalType) String() string { type Signal interface { Serializable[SignalType] Direction() SignalDirection + MarshalJSON() ([]byte, error) Permission() Action } @@ -49,8 +50,12 @@ func (signal BaseSignal) Direction() SignalDirection { return signal.SignalDirection } +func (signal BaseSignal) MarshalJSON() ([]byte, error) { + return json.Marshal(signal) +} + func (signal BaseSignal) Serialize() ([]byte, error) { - return json.MarshalIndent(signal, "", " ") + return signal.MarshalJSON() } func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { @@ -183,3 +188,4 @@ func NewReadResultSignal(exts map[ExtType]map[string]interface{}) ReadResultSign } } +