Changed NodeID from string to UUID, filled in more auth

graph-rework-2
noah metz 2023-07-19 20:03:13 -06:00
parent 6d0925f20f
commit 374fd6e487
9 changed files with 129 additions and 70 deletions

@ -23,6 +23,7 @@ import (
"crypto/sha512"
"crypto/rand"
"crypto/x509"
"github.com/google/uuid"
)
type AuthReqJSON struct {
@ -65,26 +66,26 @@ type AuthRespJSON struct {
Signature []byte `json:"signature"`
}
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte, error) {
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
// Check if req.Time is within +- 1 second of now
now := time.Now()
earliest := now.Add(-1 * time.Second)
latest := now.Add(1 * time.Second)
// If req.Time is before the earliest acceptable time, or after the latest acceptible time
if req.Time.Compare(earliest) == -1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time)
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_LATE: %s", req.Time)
} else if req.Time.Compare(latest) == 1 {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
}
x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey)
if x == nil {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
}
remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil {
return AuthRespJSON{}, nil, err
return AuthRespJSON{}, nil, nil, err
}
// Verify the signature
@ -92,23 +93,25 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte,
sig_data := append(req.ECDHPubkey, time_bytes...)
sig_hash := sha512.Sum512(sig_data)
remote_key := &ecdsa.PublicKey{
Curve: thread.Key.Curve,
X: x,
Y: y,
}
verified := ecdsa.VerifyASN1(
&ecdsa.PublicKey{
Curve: thread.Key.Curve,
X: x,
Y: y,
},
remote_key,
sig_hash[:],
req.Signature,
)
if verified == false {
return AuthRespJSON{}, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
}
ec_key, err := thread.ECDH.GenerateKey(rand.Reader)
if err != nil {
return AuthRespJSON{}, nil, err
return AuthRespJSON{}, nil, nil, err
}
ec_key_pub := ec_key.PublicKey().Bytes()
@ -120,35 +123,37 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, []byte,
resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:])
if err != nil {
return AuthRespJSON{}, nil, err
return AuthRespJSON{}, nil, nil, err
}
shared_secret, err := ec_key.ECDH(remote)
if err != nil {
return AuthRespJSON{}, nil, err
return AuthRespJSON{}, nil, nil, err
}
return AuthRespJSON{
Granted: granted,
ECDHPubkey: ec_key_pub,
Signature: resp_sig,
}, shared_secret, nil
}, remote_key, shared_secret, nil
}
type AuthData struct {
Granted time.Time
Pubkey ecdh.PublicKey
ECDHClient ecdh.PublicKey
Pubkey *ecdsa.PublicKey
Shared []byte
}
type AuthDataJSON struct {
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pbkey"`
ECDHClient []byte `json:"ecdh_client"`
Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"`
}
func HashKey(pub []byte) uint64 {
return 0
func KeyID(pub *ecdsa.PublicKey) NodeID {
ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3)
return NodeID(str)
}
func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) {
@ -169,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return
}
resp, _, err := NewAuthRespJSON(server, req)
resp, remote_id, _, err := NewAuthRespJSON(server, req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %e", err)
return
@ -192,13 +197,13 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_SUCCESS: %s", str)
key_hash := HashKey(req.Pubkey)
key_hash := KeyID(remote_id)
_, exists := server.AuthMap[key_hash]
if exists {
// New user
ctx.Log.Logf("gql", "REFRESHING AUTH FOR %+s", req.Pubkey)
} else {
// Existing user
ctx.Log.Logf("gql", "AUTHORIZING NEW USER %+s", req.Pubkey)
}
}
@ -578,7 +583,7 @@ type GQLThread struct {
http_server *http.Server
http_done *sync.WaitGroup
Listen string
AuthMap map[uint64]AuthData
AuthMap map[NodeID]AuthData
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
}
@ -604,7 +609,7 @@ func (thread * GQLThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo
type GQLThreadJSON struct {
SimpleThreadJSON
Listen string `json:"listen"`
AuthMap map[uint64]AuthData `json:"auth_map"`
AuthMap map[string]AuthDataJSON `json:"auth_map"`
Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"`
}
@ -633,10 +638,19 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
panic(err)
}
auth_map := map[string]AuthDataJSON{}
for id, data := range(thread.AuthMap) {
auth_map[id.String()] = AuthDataJSON{
Granted: data.Granted,
Pubkey: elliptic.Marshal(data.Pubkey.Curve, data.Pubkey.X, data.Pubkey.Y),
Shared: thread.AuthMap[id].Shared,
}
}
return GQLThreadJSON{
SimpleThreadJSON: thread_json,
Listen: thread.Listen,
AuthMap: thread.AuthMap,
AuthMap: auth_map,
Key: ser_key,
ECDH: ecdh_curve_ids[thread.ECDH],
}
@ -660,7 +674,26 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
}
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key)
thread.AuthMap = j.AuthMap
thread.AuthMap = map[NodeID]AuthData{}
for id_str, auth_json := range(j.AuthMap) {
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
x, y := elliptic.Unmarshal(key.Curve, auth_json.Pubkey)
if x == nil {
return nil, fmt.Errorf("Failed to load public key for curve %+v from %+v", key.Curve, auth_json.Pubkey)
}
thread.AuthMap[id] = AuthData{
Granted: auth_json.Granted,
Pubkey: &ecdsa.PublicKey{
Curve: key.Curve,
X: x,
Y: y,
},
Shared: auth_json.Shared,
}
}
nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes)
@ -675,7 +708,7 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
return GQLThread{
SimpleThread: NewSimpleThread(id, name, state_name, reflect.TypeOf((*ParentThreadInfo)(nil)), gql_actions, gql_handlers),
Listen: listen,
AuthMap: map[uint64]AuthData{},
AuthMap: map[NodeID]AuthData{},
http_done: &sync.WaitGroup{},
Key: key,
ECDH: ecdh_curve,

@ -823,14 +823,19 @@ func GQLMutationSendUpdate() *graphql.Field {
return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"])
}
id , ok := p.Args["id"].(string)
id_str, ok := p.Args["id"].(string)
if ok == false {
return nil, fmt.Errorf("Failed to cast arg id to string")
}
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
var node Node = nil
err := UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){
node = FindChild(ctx, server, NodeID(id), nodes)
err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){
node = FindChild(ctx, server, id, nodes)
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}

@ -53,13 +53,13 @@ func TestGQLThread(t * testing.T) {
}
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{})
ctx := logTestContext(t, []string{"test"})
l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
l1 := &l1_r
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
update_channel := UpdateChannel(t1, 10, "test")
update_channel := UpdateChannel(t1, 10, NodeID{})
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err)
@ -86,7 +86,7 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err)
err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, NewThreadAbortedError("")) {
if errors.Is(err, NewThreadAbortedError(NodeID{})) {
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err)
} else {
fatalErr(t, err)
@ -97,8 +97,8 @@ func TestGQLDBLoad(t * testing.T) {
err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error {
ser1, err := gql.Serialize()
ser2, err := t1.Serialize()
ctx.Log.Logf("thread", "\n%s\n\n", ser1)
ctx.Log.Logf("thread", "\n%s\n\n", ser2)
ctx.Log.Logf("test", "\n%s\n\n", ser1)
ctx.Log.Logf("test", "\n%s\n\n", ser2)
return err
})
@ -112,7 +112,7 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "\n%s\n\n", ser)
child := gql_loaded.(Thread).Children()[0].(*SimpleThread)
t1_loaded = child
update_channel_2 = UpdateChannel(t1_loaded, 10, "test")
update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{})
err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error {
ser, err := child.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser)
@ -123,7 +123,7 @@ func TestGQLDBLoad(t * testing.T) {
})
err = ThreadLoop(ctx, gql_loaded.(Thread), "restore")
if errors.Is(err, NewThreadAbortedError("")) {
if errors.Is(err, NewThreadAbortedError(NodeID{})) {
ctx.Log.Logf("test", "Main thread aborted by signal: %s", err)
} else {
fatalErr(t, err)
@ -143,7 +143,7 @@ func TestGQLAuth(t * testing.T) {
var update_channel chan GraphSignal
err = UseStates(ctx, []Node{gql_t}, func(nodes NodeMap) error {
update_channel = UpdateChannel(gql_t, 10, "test")
update_channel = UpdateChannel(gql_t, 10, NodeID{})
return nil
})
fatalErr(t, err)

@ -24,7 +24,7 @@ func (t * GraphTester) WaitForValue(ctx * Context, listener chan GraphSignal, si
if signal.Type() == signal_type {
ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener)
if source == nil {
if signal.Source() == "" {
if signal.Source() == ZeroID {
return signal
}
} else {

@ -64,13 +64,12 @@ type SimpleLockableJSON struct {
Owner *NodeID `json:"owner"`
Dependencies []NodeID `json:"dependencies"`
Requirements []NodeID `json:"requirements"`
LocksHeld map[NodeID]*NodeID `json:"locks_held"`
LocksHeld map[string]*NodeID `json:"locks_held"`
}
func (lockable * SimpleLockable) Serialize() ([]byte, error) {
lockable_json := NewSimpleLockableJSON(lockable)
return json.MarshalIndent(&lockable_json, "", " ")
}
func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
@ -90,13 +89,13 @@ func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON {
owner_id = &new_str
}
locks_held := map[NodeID]*NodeID{}
locks_held := map[string]*NodeID{}
for lockable_id, node := range(lockable.locks_held) {
if node == nil {
locks_held[lockable_id] = nil
locks_held[lockable_id.String()] = nil
} else {
str := node.ID()
locks_held[lockable_id] = &str
locks_held[lockable_id.String()] = &str
}
}
return SimpleLockableJSON{
@ -590,7 +589,8 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO
lockable.AddRequirement(req_l)
}
for l_id, h_id := range(j.LocksHeld) {
for l_id_str, h_id := range(j.LocksHeld) {
l_id, err := ParseID(l_id_str)
l, err := LoadNodeRecurse(ctx, l_id, nodes)
if err != nil {
return err

@ -74,7 +74,7 @@ func TestLockableSelfLock(t * testing.T) {
fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("")
owner_id := NodeID{}
if l1.owner != nil {
owner_id = l1.owner.ID()
}
@ -120,7 +120,7 @@ func TestLockableSelfLockTiered(t * testing.T) {
fatalErr(t, err)
err = UseStates(ctx, []Node{l1, l2, l3}, func(nodes NodeMap) (error) {
owner_1 := NodeID("")
owner_1 := NodeID{}
if l1.owner != nil {
owner_1 = l1.owner.ID()
}
@ -128,7 +128,7 @@ func TestLockableSelfLockTiered(t * testing.T) {
return fmt.Errorf("l1 is owned by %s instead of l3", owner_1)
}
owner_2 := NodeID("")
owner_2 := NodeID{}
if l2.owner != nil {
owner_2 = l2.owner.ID()
}
@ -181,7 +181,7 @@ func TestLockableLockOther(t * testing.T) {
fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("")
owner_id := NodeID{}
if l1.owner != nil {
owner_id = l1.owner.ID()
}
@ -236,7 +236,7 @@ func TestLockableLockSimpleConflict(t * testing.T) {
fatalErr(t, err)
err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner_id := NodeID("")
owner_id := NodeID{}
if l1.owner != nil {
owner_id = l1.owner.ID()
}
@ -304,7 +304,7 @@ func TestLockableSimpleUpdate(t * testing.T) {
l1 := &l1_r
update_channel := UpdateChannel(l1, 1, "test")
update_channel := UpdateChannel(l1, 1, NodeID{})
go func() {
UseStates(ctx, []Node{l1}, func(nodes NodeMap) error {
@ -333,7 +333,7 @@ func TestLockableDownUpdate(t * testing.T) {
})
fatalErr(t, err)
update_channel := UpdateChannel(l1, 1, "test")
update_channel := UpdateChannel(l1, 1, NodeID{})
go func() {
UseStates(ctx, []Node{l2}, func(nodes NodeMap) error {
@ -362,7 +362,7 @@ func TestLockableUpUpdate(t * testing.T) {
})
fatalErr(t, err)
update_channel := UpdateChannel(l3, 1, "test")
update_channel := UpdateChannel(l3, 1, NodeID{})
go func() {
UseStates(ctx, []Node{l2}, func(nodes NodeMap) error {
@ -390,7 +390,7 @@ func TestOwnerNotUpdatedTwice(t * testing.T) {
})
fatalErr(t, err)
update_channel := UpdateChannel(l2, 1, "test")
update_channel := UpdateChannel(l2, 1, NodeID{})
go func() {
err := UseStates(ctx, []Node{l1}, func(nodes NodeMap) error {

@ -10,9 +10,26 @@ import (
)
// IDs are how nodes are uniquely identified, and can be serialized for the database
type NodeID string
type NodeID uuid.UUID
var ZeroUUID = uuid.UUID{}
var ZeroID = NodeID(ZeroUUID)
func (id NodeID) Serialize() []byte {
return []byte(id)
ser, _ := (uuid.UUID)(id).MarshalBinary()
return ser
}
func (id NodeID) String() string {
return (uuid.UUID)(id).String()
}
func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str)
if err != nil {
return NodeID{}, err
}
return NodeID(id_uuid), nil
}
// Types are how nodes are associated with structs at runtime(and from the DB)
@ -27,8 +44,7 @@ func (node_type NodeType) Hash() uint64 {
// Generate a random NodeID
func RandID() NodeID {
uuid_str := uuid.New().String()
return NodeID(uuid_str)
return NodeID(uuid.New())
}
// A Node represents data that can be read by multiple goroutines and written to by one, with a unique ID attached, and a method to process updates(including propagating them to connected nodes)
@ -169,7 +185,7 @@ func getNodeBytes(node Node) ([]byte, error) {
}
ser, err := node.Serialize()
if err != nil {
return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %e", err)
return nil, fmt.Errorf("DB_SERIALIZE_ERROR: %s", err)
}
header := NewDBHeader(node.Type())

@ -48,7 +48,7 @@ func (signal BaseSignal) Type() string {
}
func NewBaseSignal(source Node, _type string, direction SignalDirection) BaseSignal {
var source_id NodeID = "nil"
var source_id NodeID = NodeID{}
if source != nil {
source_id = source.ID()
}

@ -7,6 +7,7 @@ import (
"errors"
"reflect"
"encoding/json"
"github.com/google/uuid"
)
// SimpleThread.Signal updates the parent and children, and sends the signal to an internal channel
@ -305,7 +306,7 @@ func (thread * SimpleThread) SignalChannel() <-chan GraphSignal {
type SimpleThreadJSON struct {
Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"`
Children map[string]interface{} `json:"children"`
Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"`
@ -313,9 +314,9 @@ type SimpleThreadJSON struct {
}
func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON {
children := map[NodeID]interface{}{}
children := map[string]interface{}{}
for _, child := range(thread.children) {
children[child.ID()] = thread.child_info[child.ID()]
children[child.ID().String()] = thread.child_info[child.ID()]
}
var parent_id *NodeID = nil
@ -379,7 +380,11 @@ func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes
thread.SetParent(p_t)
}
for id, info_raw := range(j.Children) {
for id_str, info_raw := range(j.Children) {
id, err := ParseID(id_str)
if err != nil {
return err
}
child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
@ -572,7 +577,7 @@ var ThreadRestore = func(ctx * Context, thread Thread) {
var ThreadStart = func(ctx * Context, thread Thread) error {
return UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error {
owner_id := NodeID("")
owner_id := NodeID{}
if thread.Owner() != nil {
owner_id = thread.Owner().ID()
}
@ -636,12 +641,12 @@ var ThreadWait = func(ctx * Context, thread Thread) (string, error) {
type ThreadAbortedError NodeID
func (e ThreadAbortedError) Is(target error) bool {
error_type := reflect.TypeOf(ThreadAbortedError(""))
error_type := reflect.TypeOf(ThreadAbortedError(NodeID{}))
target_type := reflect.TypeOf(target)
return error_type == target_type
}
func (e ThreadAbortedError) Error() string {
return fmt.Sprintf("Aborted by %s", string(e))
return fmt.Sprintf("Aborted by %s", (uuid.UUID)(e).String())
}
func NewThreadAbortedError(aborter NodeID) ThreadAbortedError {
return ThreadAbortedError(aborter)