diff --git a/context.go b/context.go index b4d6d8a..21ad0c2 100644 --- a/context.go +++ b/context.go @@ -5,13 +5,20 @@ import ( "fmt" ) +//Function to load an extension from bytes type ExtensionLoadFunc func(*Context, []byte) (Extension, error) +// Information about a loaded extension type ExtensionInfo struct { Load ExtensionLoadFunc Type ExtType Data interface{} } +// Information about a loaded node type +type NodeInfo struct { + Type NodeType +} + // A Context is all the data needed to run a graphvent type Context struct { // DB is the database connection used to load and write nodes @@ -20,6 +27,8 @@ type Context struct { Log Logger // A mapping between type hashes and their corresponding extension definitions Extensions map[uint64]ExtensionInfo + // A mapping between type hashes and their corresponding node definitions + Types map[uint64]NodeInfo // All loaded Nodes Nodes map[NodeID]*Node } @@ -30,8 +39,21 @@ func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo { return ext } +func (ctx *Context) RegisterNodeType(node_type NodeType) error { + type_hash := node_type.Hash() + _, exists := ctx.Types[type_hash] + if exists == true { + return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type) + } + + ctx.Types[type_hash] = NodeInfo{ + Type: node_type, + } + return nil +} + // Add a node to a context, returns an error if the def is invalid or already exists in the context -func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error { +func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error { if load_fn == nil { return fmt.Errorf("def has no load function") } @@ -45,6 +67,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun ctx.Extensions[type_hash] = ExtensionInfo{ Load: load_fn, Type: ext_type, + Data: data, } return nil } @@ -55,15 +78,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { DB: db, Log: log, Extensions: map[uint64]ExtensionInfo{}, + Types: map[uint64]NodeInfo{}, Nodes: map[NodeID]*Node{}, } - err := ctx.RegisterExtension(ACLExtType, LoadACLExtension) + err := ctx.RegisterExtension(ACLExtType, LoadACLExt, nil) + if err != nil { + return nil, err + } + + err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExt, NewACLPolicyExtContext()) + if err != nil { + return nil, err + } + + err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil) + if err != nil { + return nil, err + } + + err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, NewThreadExtContext()) + if err != nil { + return nil, err + } + + err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil) + if err != nil { + return nil, err + } + + err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension) + err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, NewGQLExtContext()) if err != nil { return nil, err } diff --git a/gql.go b/gql.go index 50c338b..3987748 100644 --- a/gql.go +++ b/gql.go @@ -70,7 +70,7 @@ type AuthRespJSON struct { Signature []byte `json:"signature"` } -func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { +func NewAuthRespJSON(gql_ext *GQLExt, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { // Check if req.Time is within +- 1 second of now now := time.Now() earliest := now.Add(-1 * time.Second) @@ -82,12 +82,12 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) } - x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey) + x, y := elliptic.Unmarshal(gql_ext.Key.Curve, req.Pubkey) if x == nil { return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) } - remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey) + remote, err := gql_ext.ECDH.NewPublicKey(req.ECDHPubkey) if err != nil { return AuthRespJSON{}, nil, nil, err } @@ -98,7 +98,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P sig_hash := sha512.Sum512(sig_data) remote_key := &ecdsa.PublicKey{ - Curve: thread.Key.Curve, + Curve: gql_ext.Key.Curve, X: x, Y: y, } @@ -113,7 +113,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) } - ec_key, err := thread.ECDH.GenerateKey(rand.Reader) + ec_key, err := gql_ext.ECDH.GenerateKey(rand.Reader) if err != nil { return AuthRespJSON{}, nil, nil, err } @@ -125,7 +125,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P resp_sig_data := append(ec_key_pub, time_ser...) resp_sig_hash := sha512.Sum512(resp_sig_data) - resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:]) + resp_sig, err := ecdsa.SignASN1(rand.Reader, gql_ext.Key, resp_sig_hash[:]) if err != nil { return AuthRespJSON{}, nil, nil, err } @@ -156,7 +156,7 @@ func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve return shared_secret, nil } -func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { +func AuthHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr) enableCORS(&w) @@ -174,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt return } - resp, remote_id, shared, err := NewAuthRespJSON(server, req) + resp, _, _, err := NewAuthRespJSON(gql_ext, req) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err) return @@ -195,34 +195,31 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt return } - key_id := KeyID(remote_id) - - _, exists := server.UserMap[key_id] - if exists { + /*if exists { ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s", key_id) } else { ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared) context := NewWriteContext(ctx) - err := UpdateStates(context, server, NewLockMap(LockMap{ - server.ID(): LockInfo{ + err := UpdateStates(context, server, ACLMap{ + server.ID: ACLInfo{ Node: server, Resources: []string{"users"}, }, - new_user.ID(): LockInfo{ + new_user.ID: ACLInfo{ Node: &new_user, Resources: nil, }, - }), func(context *StateContext) error { - server.UserMap[key_id] = &new_user + }, func(context *StateContext) error { + server.Users[key_id] = &new_user return nil }) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_UPDATE_ERR: %s", err) return } - } + }*/ } } @@ -363,11 +360,13 @@ func checkForAuthHeader(header http.Header) (string, bool) { type ResolveContext struct { Context *Context - Server *GQLThread - User *User + GQLContext *GQLExtContext + Server *Node + Ext *GQLExt + User *Node } -func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*ResolveContext, error) { +func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { username, password, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") @@ -378,25 +377,29 @@ func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*Resol return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) } - user, exists := server.UserMap[auth_id] + user, exists := gql_ext.Users[auth_id] if exists == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) } - if base64.StdEncoding.EncodeToString(user.Shared) != password { + user_ext, err := GetExt[*ECDHExt](user) + if err != nil { + return nil, err + } + + if base64.StdEncoding.EncodeToString(user_ext.Shared) != password { return nil, fmt.Errorf("GQL_AUTH_FAIL") } return &ResolveContext{ Context: ctx, + GQLContext: ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext), Server: server, User: user, }, nil } -func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { - gql_ctx := context.Background() - +func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r * http.Request) { ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) enableCORS(&w) @@ -406,7 +409,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - resolve_context, err := NewResolveContext(ctx, server, r) + resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) @@ -414,7 +417,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht } req_ctx := context.Background() - req_ctx = context.WithValue(gql_ctx, "resolve", resolve_context) + req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) str, err := io.ReadAll(r.Body) if err != nil { @@ -425,8 +428,10 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht query := GQLPayload{} json.Unmarshal(str, &query) + gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext) + params := graphql.Params{ - Schema: ctx.GQL.Schema, + Schema: gql_context.Schema, Context: req_ctx, RequestString: query.Query, } @@ -494,11 +499,7 @@ func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result { return sendOneResultAndClose(res) } -func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { - gql_ctx := context.Background() - gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) - gql_ctx = context.WithValue(gql_ctx, "gql_server", server) - +func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r * http.Request) { ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) enableCORS(&w) @@ -508,7 +509,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - resolve_context, err := NewResolveContext(ctx, server, r) + resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) return @@ -557,8 +558,9 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } } else if msg.Type == "subscribe" { ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) + gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext) params := graphql.Params{ - Schema: ctx.GQL.Schema, + Schema: gql_context.Schema, Context: req_ctx, RequestString: msg.Payload.Query, } @@ -628,35 +630,94 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } } -type GQLThread struct { - Thread +// Map of go types to graphql types +type ObjTypeMap map[reflect.Type]*graphql.Object + +// GQL Specific Context information +type GQLExtContext struct { + // Generated GQL schema + Schema graphql.Schema + + // List of GQL types + TypeList []graphql.Type + + // Interface type maps to map go types of specific interfaces to gql types + ValidNodes ObjTypeMap + ValidLockables ObjTypeMap + ValidThreads ObjTypeMap + + BaseNodeType *graphql.Object + BaseLockableType *graphql.Object + BaseThreadType *graphql.Object + + Query *graphql.Object + Mutation *graphql.Object + Subscription *graphql.Object +} + +func NewGQLExtContext() *GQLExtContext { + query := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{}, + }) + + mutation := graphql.NewObject(graphql.ObjectConfig{ + Name: "Mutation", + Fields: graphql.Fields{}, + }) + + subscription := graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{}, + }) + + context := GQLExtContext{ + Schema: graphql.Schema{}, + TypeList: []graphql.Type{}, + ValidNodes: ObjTypeMap{}, + ValidThreads: ObjTypeMap{}, + ValidLockables: ObjTypeMap{}, + Query: query, + Mutation: mutation, + Subscription: subscription, + BaseNodeType: GQLTypeBaseNode.Type, + BaseLockableType: GQLTypeBaseLockable.Type, + BaseThreadType: GQLTypeBaseThread.Type, + } + + return &context +} + +type GQLExt struct { tcp_listener net.Listener http_server *http.Server - http_done *sync.WaitGroup + http_done sync.WaitGroup tls_key []byte tls_cert []byte Listen string - UserMap map[NodeID]*User + Users NodeMap Key *ecdsa.PrivateKey ECDH ecdh.Curve SubscribeLock sync.Mutex SubscribeListeners []chan GraphSignal } -func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal { - thread.SubscribeLock.Lock() - defer thread.SubscribeLock.Unlock() +func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan GraphSignal { + ext.SubscribeLock.Lock() + defer ext.SubscribeLock.Unlock() new_listener := make(chan GraphSignal, buffer) - thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener) + ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener) return new_listener } -func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error { +func (ext *GQLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + ext.SubscribeLock.Lock() + defer ext.SubscribeLock.Unlock() + active_listeners := []chan GraphSignal{} - thread.SubscribeLock.Lock() - for _, listener := range(thread.SubscribeListeners) { + for _, listener := range(ext.SubscribeListeners) { select { case listener <- signal: active_listeners = append(active_listeners, listener) @@ -667,34 +728,38 @@ func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) erro }(listener) } } - thread.SubscribeListeners = active_listeners - thread.SubscribeLock.Unlock() - return thread.Thread.Process(context, signal) -} - -func (thread * GQLThread) Type() NodeType { - return NodeType("gql_thread") -} - -func (thread * GQLThread) Serialize() ([]byte, error) { - thread_json := NewGQLThreadJSON(thread) - return json.MarshalIndent(&thread_json, "", " ") + ext.SubscribeListeners = active_listeners + return nil } -func (thread * GQLThread) Users() map[NodeID]*User { - return thread.UserMap +const GQLExtType = ExtType("gql_thread") +func (ext *GQLExt) Type() ExtType { + return GQLExtType } -type GQLThreadJSON struct { - ThreadJSON +type GQLExtJSON struct { Listen string `json:"listen"` - Users []string `json:"users"` Key []byte `json:"key"` ECDH uint8 `json:"ecdh_curve"` TLSKey []byte `json:"ssl_key"` TLSCert []byte `json:"ssl_cert"` } +func (ext *GQLExt) Serialize() ([]byte, error) { + ser_key, err := x509.MarshalECPrivateKey(ext.Key) + if err != nil { + return nil, err + } + + return json.MarshalIndent(&GQLExtJSON{ + Listen: ext.Listen, + Key: ser_key, + ECDH: ecdh_curve_ids[ext.ECDH], + TLSKey: ext.tls_key, + TLSCert: ext.tls_cert, + }, "", " ") +} + var ecdsa_curves = map[uint8]elliptic.Curve{ 0: elliptic.P256(), } @@ -711,33 +776,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{ ecdh.P256(): 0, } -func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { - thread_json := NewThreadJSON(&thread.Thread) - - ser_key, err := x509.MarshalECPrivateKey(thread.Key) +func LoadGQLExt(ctx *Context, data []byte) (Extension, error) { + var j GQLExtJSON + err := json.Unmarshal(data, &j) if err != nil { - panic(err) - } - - users := make([]string, len(thread.UserMap)) - i := 0 - for id, _ := range(thread.UserMap) { - users[i] = id.String() - i += 1 - } - - return GQLThreadJSON{ - ThreadJSON: thread_json, - Listen: thread.Listen, - Users: users, - Key: ser_key, - ECDH: ecdh_curve_ids[thread.ECDH], - TLSKey: thread.tls_key, - TLSCert: thread.tls_cert, + return nil, err } -} -var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) { ecdh_curve, ok := ecdh_curves[j.ECDH] if ok == false { return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) @@ -748,27 +793,19 @@ var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) return nil, err } - thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) - return &thread, nil -}, func(ctx *Context, thread *GQLThread, j GQLThreadJSON, nodes NodeMap) error { - thread.UserMap = map[NodeID]*User{} - for _, id_str := range(j.Users) { - ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str) - user_id, err := ParseID(id_str) - if err != nil { - return err - } - user, err := LoadNodeRecurse(ctx, user_id, nodes) - if err != nil { - return err - } - thread.UserMap[user_id] = user.(*User) + extension := GQLExt{ + Listen: j.Listen, + Key: key, + ECDH: ecdh_curve, + SubscribeListeners: []chan GraphSignal{}, + tls_key: j.TLSKey, + tls_cert: j.TLSCert, } - return RestoreThread(ctx, thread, j.ThreadJSON, nodes) -}) + return &extension, nil +} -func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread { +func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLExt { if tls_cert == nil || tls_key == nil { ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) if err != nil { @@ -808,12 +845,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh tls_cert = ssl_cert_pem tls_key = ssl_key_pem } - return GQLThread{ - Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers), + return GQLExt{ Listen: listen, SubscribeListeners: []chan GraphSignal{}, - UserMap: map[NodeID]*User{}, - http_done: &sync.WaitGroup{}, Key: key, ECDH: ecdh_curve, tls_cert: tls_cert, @@ -823,23 +857,26 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh var gql_actions ThreadActions = ThreadActions{ "wait": ThreadWait, - "restore": func(ctx *Context, node ThreadNode) (string, error) { - return "start_server", ThreadRestore(ctx, node, false) + "restore": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) { + return "start_server", ThreadRestore(ctx, thread, thread_ext, false) }, - "start": func(ctx * Context, node ThreadNode) (string, error) { - _, err := ThreadStart(ctx, node) + "start": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { + _, err := ThreadStart(ctx, thread, thread_ext) if err != nil { return "", err } - return "start_server", ThreadRestore(ctx, node, true) + return "start_server", ThreadRestore(ctx, thread, thread_ext, true) }, - "start_server": func(ctx * Context, node ThreadNode) (string, error) { - gql_thread := node.(*GQLThread) + "start_server": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { + gql_ext, err := GetExt[*GQLExt](thread) + if err != nil { + return "", err + } mux := http.NewServeMux() - mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread)) - mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread)) - mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread)) + mux.HandleFunc("/auth", AuthHandler(ctx, thread, gql_ext)) + mux.HandleFunc("/gql", GQLHandler(ctx, thread, gql_ext)) + mux.HandleFunc("/gqlws", GQLWSHandler(ctx, thread, gql_ext)) // Server a graphiql interface(TODO make configurable whether to start this) mux.HandleFunc("/graphiql", GraphiQLHandler()) @@ -849,7 +886,7 @@ var gql_actions ThreadActions = ThreadActions{ mux.Handle("/site/", http.StripPrefix("/site", fs)) http_server := &http.Server{ - Addr: gql_thread.Listen, + Addr: gql_ext.Listen, Handler: mux, } @@ -858,7 +895,7 @@ var gql_actions ThreadActions = ThreadActions{ return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) } - cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key) + cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key) if err != nil { return "", err } @@ -870,23 +907,21 @@ var gql_actions ThreadActions = ThreadActions{ listener := tls.NewListener(l, &config) - gql_thread.http_done.Add(1) - go func(gql_thread *GQLThread) { - defer gql_thread.http_done.Done() + gql_ext.http_done.Add(1) + go func(qql_ext *GQLExt) { + defer gql_ext.http_done.Done() err := http_server.Serve(listener) if err != http.ErrServerClosed { panic(fmt.Sprintf("Failed to start gql server: %s", err)) } - }(gql_thread) + }(gql_ext) context := NewWriteContext(ctx) - err = UpdateStates(context, node, NewLockMap( - NewLockInfo(node, []string{"http_server"}), - ), func(context *StateContext) error { - gql_thread.tcp_listener = listener - gql_thread.http_server = http_server + err = UpdateStates(context, thread, NewACLInfo(thread, []string{"http_server"}), func(context *StateContext) error { + gql_ext.tcp_listener = listener + gql_ext.http_server = http_server return nil }) @@ -895,18 +930,22 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID())) + err = Signal(context, thread, thread, NewStatusSignal("server_started", thread.ID)) if err != nil { return "", err } return "wait", nil }, - "finish": func(ctx *Context, node ThreadNode) (string, error) { - gql_thread := node.(*GQLThread) - gql_thread.http_server.Shutdown(context.TODO()) - gql_thread.http_done.Wait() - return ThreadFinish(ctx, node) + "finish": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) { + gql_ext, err := GetExt[*GQLExt](thread) + if err != nil { + return "", err + } + + gql_ext.http_server.Shutdown(context.TODO()) + gql_ext.http_done.Wait() + return ThreadFinish(ctx, thread, thread_ext) }, } diff --git a/gql_interfaces.go b/gql_interfaces.go index 613df72..473df7c 100644 --- a/gql_interfaces.go +++ b/gql_interfaces.go @@ -3,6 +3,7 @@ package graphvent import ( "github.com/graphql-go/graphql" "reflect" + "fmt" ) func NewField(init func()*graphql.Field) *graphql.Field { @@ -26,22 +27,34 @@ func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List } } -func addNodeInterfaceFields(i *graphql.Interface) { +func AddNodeInterfaceFields(i *graphql.Interface) { i.AddFieldConfig("ID", &graphql.Field{ Type: graphql.String, }) + + i.AddFieldConfig("TypeHash", &graphql.Field{ + Type: graphql.String, + }) +} + +func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) { + resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) + if ok == false { + return nil, fmt.Errorf("Bad resolve in params context") + } + return resolve_context, nil } var GQLInterfaceNode = NewSingleton(func() *graphql.Interface { i := graphql.NewInterface(graphql.InterfaceConfig{ Name: "Node", ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - ctx, ok := p.Context.Value("graph_context").(*Context) - if ok == false { + ctx, err := PrepTypeResolve(p) + if err != nil { return nil } - valid_nodes := ctx.GQL.ValidNodes + valid_nodes := ctx.GQLContext.ValidNodes p_type := reflect.TypeOf(p.Value) for key, value := range(valid_nodes) { @@ -50,9 +63,9 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface { } } - _, ok = p.Value.(Node) + _, ok := p.Value.(Node) if ok == true { - return ctx.GQL.BaseNodeType + return ctx.GQLContext.BaseNodeType } return nil @@ -60,41 +73,21 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface { Fields: graphql.Fields{}, }) - addNodeInterfaceFields(i) + AddNodeInterfaceFields(i) return i }, nil) -func addLockableInterfaceFields(i *graphql.Interface, lockable *graphql.Interface, list *graphql.List) { - addNodeInterfaceFields(i) - - i.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - }) - - i.AddFieldConfig("Requirements", &graphql.Field{ - Type: list, - }) - - i.AddFieldConfig("Dependencies", &graphql.Field{ - Type: list, - }) - - i.AddFieldConfig("Owner", &graphql.Field{ - Type: lockable, - }) -} - var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface { gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{ Name: "Lockable", ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - ctx, ok := p.Context.Value("graph_context").(*Context) - if ok == false { + ctx, err := PrepTypeResolve(p) + if err != nil { return nil } - valid_lockables := ctx.GQL.ValidLockables + valid_lockables := ctx.GQLContext.ValidLockables p_type := reflect.TypeOf(p.Value) for key, value := range(valid_lockables) { @@ -103,9 +96,9 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface { } } - _, ok = p.Value.(Lockable) - if ok == true { - return ctx.GQL.BaseLockableType + _, ok := p.Value.(*Node) + if ok == false { + return ctx.GQLContext.BaseLockableType } return nil }, @@ -114,31 +107,30 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface { return gql_interface_lockable }, func(lockable *graphql.Interface, lockable_list *graphql.List) { - addLockableInterfaceFields(lockable, lockable, lockable_list) -}) - -func addThreadInterfaceFields(i *graphql.Interface, thread *graphql.Interface, list *graphql.List) { - addLockableInterfaceFields(i, GQLInterfaceLockable.Type, GQLInterfaceLockable.List) + lockable.AddFieldConfig("Requirements", &graphql.Field{ + Type: lockable_list, + }) - i.AddFieldConfig("Children", &graphql.Field{ - Type: list, + lockable.AddFieldConfig("Dependencies", &graphql.Field{ + Type: lockable_list, }) - i.AddFieldConfig("Parent", &graphql.Field{ - Type: thread, + lockable.AddFieldConfig("Owner", &graphql.Field{ + Type: lockable, }) -} + AddNodeInterfaceFields(lockable) +}) var GQLInterfaceThread = NewSingleton(func() *graphql.Interface { gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{ Name: "Thread", ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - ctx, ok := p.Context.Value("graph_context").(*Context) - if ok == false { + ctx, err := PrepTypeResolve(p) + if err != nil { return nil } - valid_threads := ctx.GQL.ValidThreads + valid_threads := ctx.GQLContext.ValidThreads p_type := reflect.TypeOf(p.Value) for key, value := range(valid_threads) { @@ -147,9 +139,14 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface { } } - _, ok = p.Value.(Thread) - if ok == true { - return ctx.GQL.BaseThreadType + node, ok := p.Value.(*Node) + if ok == false { + return nil + } + + _, err = GetExt[*ThreadExt](node) + if err == nil { + return ctx.GQLContext.BaseThreadType } return nil @@ -159,5 +156,17 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface { return gql_interface_thread }, func(thread *graphql.Interface, thread_list *graphql.List) { - addThreadInterfaceFields(thread, thread, thread_list) + thread.AddFieldConfig("Children", &graphql.Field{ + Type: thread_list, + }) + + thread.AddFieldConfig("Parent", &graphql.Field{ + Type: thread, + }) + + thread.AddFieldConfig("State", &graphql.Field{ + Type: graphql.String, + }) + + AddNodeInterfaceFields(thread) }) diff --git a/gql_resolvers.go b/gql_resolvers.go index fb2200e..e5023cc 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -5,12 +5,18 @@ import ( "github.com/graphql-go/graphql" ) -func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) { +func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) { resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) if ok == false { - return nil, fmt.Errorf("Bad resolve in params context") + return nil, nil, fmt.Errorf("Bad resolve in params context") } - return resolve_context, nil + + node, ok := p.Source.(*Node) + if ok == false { + return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve") + } + + return node, resolve_context, nil } // TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does @@ -65,30 +71,38 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { // TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { - node, ok := p.Source.(Node) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Node") + node, _, err := PrepResolve(p) + if err != nil { + return nil, err } - return node.ID(), nil + return node.ID, nil } -func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) +func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { + node, _, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(*GQLThread) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to GQLThread") + return string(node.Type), nil +} + +func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { + node, ctx, err := PrepResolve(p) + if err != nil { + return nil, err } + gql_ext, err := GetExt[*GQLExt](node) + if err != nil { + return nil, err + } listen := "" context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error { - listen = node.Listen + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error { + listen = gql_ext.Listen return nil }) @@ -100,20 +114,20 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { } func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(*Thread) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Thread") + thread_ext, err := GetExt[*ThreadExt](node) + if err != nil { + return nil, err } - var parent ThreadNode = nil + var parent *Node = nil context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error { - parent = node.ThreadHandle().Parent + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { + parent = thread_ext.Parent return nil }) @@ -125,20 +139,20 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { } func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(ThreadNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Thread") + thread_ext, err := GetExt[*ThreadExt](node) + if err != nil { + return nil, err } var state string context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { - state = node.ThreadHandle().StateName + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error { + state = thread_ext.State return nil }) @@ -150,50 +164,20 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { } func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - node, ok := p.Source.(ThreadNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Thread") - } - - var children []ThreadNode = nil - context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error { - children = make([]ThreadNode, len(node.ThreadHandle().Children)) - i := 0 - for _, info := range(node.ThreadHandle().Children) { - children[i] = info.Child - i += 1 - } - return nil - }) - + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - return children, nil -} - -func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + thread_ext, err := GetExt[*ThreadExt](node) if err != nil { return nil, err } - node, ok := p.Source.(LockableNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Lockable") - } - - name := "" + var children []*Node = nil context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error { - name = node.LockableHandle().Name + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { + children = thread_ext.ChildList() return nil }) @@ -201,26 +185,26 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) { return nil, err } - return name, nil + return children, nil } func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(LockableNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Lockable") + lockable_ext, err := GetExt[*LockableExt](node) + if err != nil { + return nil, err } - var requirements []LockableNode = nil + var requirements []*Node = nil context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error { - requirements = make([]LockableNode, len(node.LockableHandle().Requirements)) + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { + requirements = make([]*Node, len(lockable_ext.Requirements)) i := 0 - for _, req := range(node.LockableHandle().Requirements) { + for _, req := range(lockable_ext.Requirements) { requirements[i] = req i += 1 } @@ -235,22 +219,22 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { } func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(LockableNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Lockable") + lockable_ext, err := GetExt[*LockableExt](node) + if err != nil { + return nil, err } - var dependencies []LockableNode = nil + var dependencies []*Node = nil context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error { - dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies)) + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error { + dependencies = make([]*Node, len(lockable_ext.Dependencies)) i := 0 - for _, dep := range(node.LockableHandle().Dependencies) { + for _, dep := range(lockable_ext.Dependencies) { dependencies[i] = dep i += 1 } @@ -265,20 +249,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { } func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(LockableNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to Lockable") + lockable_ext, err := GetExt[*LockableExt](node) + if err != nil { + return nil, err } - var owner Node = nil + var owner *Node = nil context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error { - owner = node.LockableHandle().Owner + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error { + owner = lockable_ext.Owner return nil }) @@ -289,24 +273,24 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { return owner, nil } -func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) +func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) { + node, ctx, err := PrepResolve(p) if err != nil { return nil, err } - node, ok := p.Source.(GroupNode) - if ok == false || node == nil { - return nil, fmt.Errorf("Failed to cast source to GQLThread") + group_ext, err := GetExt[*GroupExt](node) + if err != nil { + return nil, err } - var users []*User + var members []*Node context := NewReadContext(ctx.Context) - err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error { - users = make([]*User, len(node.Users())) + err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error { + members = make([]*Node, len(group_ext.Members)) i := 0 - for _, user := range(node.Users()) { - users[i] = user + for _, member := range(group_ext.Members) { + members[i] = member i += 1 } return nil @@ -316,7 +300,7 @@ func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) { return nil, err } - return users, nil + return members, nil } func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { diff --git a/gql_types.go b/gql_types.go index 53fad30..b130118 100644 --- a/gql_types.go +++ b/gql_types.go @@ -9,16 +9,16 @@ func AddNodeFields(obj *graphql.Object) { Type: graphql.String, Resolve: GQLNodeID, }) + + obj.AddFieldConfig("TypeHash", &graphql.Field{ + Type: graphql.String, + Resolve: GQLNodeTypeHash, + }) } func AddLockableFields(obj *graphql.Object) { AddNodeFields(obj) - obj.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLLockableName, - }) - obj.AddFieldConfig("Requirements", &graphql.Field{ Type: GQLInterfaceLockable.List, Resolve: GQLLockableRequirements, @@ -36,7 +36,7 @@ func AddLockableFields(obj *graphql.Object) { } func AddThreadFields(obj *graphql.Object) { - AddLockableFields(obj) + AddNodeFields(obj) obj.AddFieldConfig("State", &graphql.Field{ Type: graphql.String, @@ -54,56 +54,7 @@ func AddThreadFields(obj *graphql.Object) { }) } -var GQLTypeUser = NewSingleton(func() *graphql.Object { - gql_type_user := graphql.NewObject(graphql.ObjectConfig{ - Name: "User", - Interfaces: []*graphql.Interface{ - GQLInterfaceNode.Type, - GQLInterfaceLockable.Type, - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*User) - return ok - }, - Fields: graphql.Fields{}, - }) - - AddLockableFields(gql_type_user) - - return gql_type_user -}, nil) - -var GQLTypeGQLThread = NewSingleton(func() *graphql.Object { - gql_type_gql_thread := graphql.NewObject(graphql.ObjectConfig{ - Name: "GQLThread", - Interfaces: []*graphql.Interface{ - GQLInterfaceNode.Type, - GQLInterfaceThread.Type, - GQLInterfaceLockable.Type, - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*GQLThread) - return ok - }, - Fields: graphql.Fields{}, - }) - - AddThreadFields(gql_type_gql_thread) - - gql_type_gql_thread.AddFieldConfig("Users", &graphql.Field{ - Type: GQLTypeUser.List, - Resolve: GQLGroupNodeUsers, - }) - - gql_type_gql_thread.AddFieldConfig("Listen", &graphql.Field{ - Type: graphql.String, - Resolve: GQLThreadListen, - }) - - return gql_type_gql_thread -}, nil) - -var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object { +var GQLTypeBaseThread = NewSingleton(func() *graphql.Object { gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{ Name: "SimpleThread", Interfaces: []*graphql.Interface{ @@ -112,8 +63,13 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object { GQLInterfaceLockable.Type, }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(Thread) - return ok + node, ok := p.Value.(*Node) + if ok == false { + return false + } + + _, err := GetExt[*ThreadExt](node) + return err == nil }, Fields: graphql.Fields{}, }) @@ -123,7 +79,7 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object { return gql_type_simple_thread }, nil) -var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object { +var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object { gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{ Name: "SimpleLockable", Interfaces: []*graphql.Interface{ @@ -131,8 +87,13 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object { GQLInterfaceLockable.Type, }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(Lockable) - return ok + node, ok := p.Value.(*Node) + if ok == false { + return false + } + + _, err := GetExt[*LockableExt](node) + return err == nil }, Fields: graphql.Fields{}, }) @@ -142,14 +103,14 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object { return gql_type_simple_lockable }, nil) -var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object { +var GQLTypeBaseNode = NewSingleton(func() *graphql.Object { object := graphql.NewObject(graphql.ObjectConfig{ Name: "SimpleNode", Interfaces: []*graphql.Interface{ GQLInterfaceNode.Type, }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(Node) + _, ok := p.Value.(*Node) return ok }, Fields: graphql.Fields{}, diff --git a/lockable.go b/lockable.go index 6a412c8..08a761d 100644 --- a/lockable.go +++ b/lockable.go @@ -45,46 +45,57 @@ func (ext *LockableExt) Type() ExtType { return LockableExtType } +type LockableExtJSON struct { + Owner string `json:"owner"` + Requirements []string `json:"requirements"` + Dependencies []string `json:"dependencies"` + LocksHeld map[string]string `json:"locks_held"` +} + func (ext *LockableExt) Serialize() ([]byte, error) { - requirements := make([]string, len(ext.Requirements)) - req_n := 0 - for id, _ := range(ext.Requirements) { - requirements[req_n] = id.String() - req_n++ + return json.MarshalIndent(&LockableExtJSON{ + Owner: SaveNode(ext.Owner), + Requirements: SaveNodeList(ext.Requirements), + Dependencies: SaveNodeList(ext.Dependencies), + LocksHeld: SaveNodeMap(ext.LocksHeld), + }, "", " ") +} + +func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { + var j LockableExtJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err } - dependencies := make([]string, len(ext.Dependencies)) - dep_n := 0 - for id, _ := range(ext.Dependencies) { - dependencies[dep_n] = id.String() - dep_n++ + owner, err := RestoreNode(ctx, j.Owner) + if err != nil { + return nil, err } - owner := "" - if ext.Owner != nil { - owner = ext.Owner.ID.String() + requirements, err := RestoreNodeList(ctx, j.Requirements) + if err != nil { + return nil, err } - locks_held := map[string]string{} - for lockable_id, node := range(ext.LocksHeld) { - if node == nil { - locks_held[lockable_id.String()] = "" - } else { - locks_held[lockable_id.String()] = node.ID.String() - } + dependencies, err := RestoreNodeList(ctx, j.Dependencies) + if err != nil { + return nil, err } - return json.MarshalIndent(&struct{ - Owner string `json:"owner"` - Requirements []string `json:"requirements"` - Dependencies []string `json:"dependencies"` - LocksHeld map[string]string `json:"locks_held"` - }{ + locks_held, err := RestoreNodeMap(ctx, j.LocksHeld) + if err != nil { + return nil, err + } + + extension := LockableExt{ Owner: owner, Requirements: requirements, Dependencies: dependencies, LocksHeld: locks_held, - }, "", " ") + } + + return &extension, nil } func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { @@ -469,6 +480,14 @@ func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node) }) } +func SaveNode(node *Node) string { + str := "" + if node != nil { + str = node.ID.String() + } + return str +} + func RestoreNode(ctx *Context, id_str string) (*Node, error) { id, err := ParseID(id_str) if err != nil { @@ -478,6 +497,14 @@ func RestoreNode(ctx *Context, id_str string) (*Node, error) { return LoadNode(ctx, id) } +func SaveNodeMap(nodes NodeMap) map[string]string { + m := map[string]string{} + for id, node := range(nodes) { + m[id.String()] = SaveNode(node) + } + return m +} + func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) { nodes := NodeMap{} for id_str_1, id_str_2 := range(ids) { @@ -507,6 +534,17 @@ func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) { return nodes, nil } +func SaveNodeList(nodes NodeMap) []string { + ids := make([]string, len(nodes)) + i := 0 + for id, _ := range(nodes) { + ids[i] = id.String() + i += 1 + } + + return ids +} + func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) { nodes := NodeMap{} diff --git a/node.go b/node.go index 9ed503e..21a58bc 100644 --- a/node.go +++ b/node.go @@ -73,6 +73,7 @@ type Extension interface { // Nodes represent an addressible group of extensions type Node struct { ID NodeID + Type NodeType Lock sync.RWMutex ExtensionMap map[ExtType]Extension } @@ -93,65 +94,6 @@ func GetExt[T Extension](node *Node) (T, error) { return ret, nil } -// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies -type ACLExtension struct { - Delegations NodeMap -} - -func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error { - return nil -} - -func LoadACLExtension(ctx *Context, data []byte) (Extension, error) { - var j struct { - Delegations []string `json:"delegation"` - } - - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - delegations := NodeMap{} - for _, str := range(j.Delegations) { - id, err := ParseID(str) - if err != nil { - return nil, err - } - - node, err := LoadNode(ctx, id) - if err != nil { - return nil, err - } - - delegations[id] = node - } - - return ACLExtension{ - Delegations: delegations, - }, nil -} - -func (ext ACLExtension) Serialize() ([]byte, error) { - delegations := make([]string, len(ext.Delegations)) - i := 0 - for id, _ := range(ext.Delegations) { - delegations[i] = id.String() - i += 1 - } - - return json.MarshalIndent(&struct{ - Delegations []string `json:"delegations"` - }{ - Delegations: delegations, - }, "", " ") -} - -const ACLExtType = ExtType("ACL") -func (extension ACLExtension) Type() ExtType { - return ACLExtType -} - func (node *Node) Serialize() ([]byte, error) { extensions := make([]ExtensionDB, len(node.ExtensionMap)) node_db := NodeDB{ @@ -181,9 +123,10 @@ func (node *Node) Serialize() ([]byte, error) { return node_db.Serialize(), nil } -func NewNode(id NodeID) Node { +func NewNode(id NodeID, node_type NodeType) Node { return Node{ ID: id, + Type: node_type, ExtensionMap: map[ExtType]Extension{}, } } @@ -198,15 +141,15 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node) if exists == false { return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID) } - acl_ext := ext.(ACLExtension) + acl_ext := ext.(ACLExt) for _, policy_node := range(acl_ext.Delegations) { ext, exists := policy_node.ExtensionMap[ACLPolicyExtType] if exists == false { - context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension") + context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt") continue } - policy_ext := ext.(ACLPolicyExtension) + policy_ext := ext.(ACLPolicyExt) if policy_ext.Allows(context, principal, action, node) == true { context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action) return nil @@ -238,11 +181,12 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) // Magic first four bytes of serialized DB content, stored big endian const NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function -const NODE_DB_HEADER_LEN = 8 +const NODE_DB_HEADER_LEN = 16 // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node type NodeDBHeader struct { Magic uint32 NumExtensions uint32 + TypeHash uint64 } type NodeDB struct { @@ -258,6 +202,7 @@ func NewNodeDB(data []byte) (NodeDB, error) { magic := binary.BigEndian.Uint32(data[0:4]) num_extensions := binary.BigEndian.Uint32(data[4:8]) + node_type_hash := binary.BigEndian.Uint64(data[8:16]) ptr += NODE_DB_HEADER_LEN @@ -290,6 +235,7 @@ func NewNodeDB(data []byte) (NodeDB, error) { return NodeDB{ Header: NodeDBHeader{ Magic: magic, + TypeHash: node_type_hash, NumExtensions: num_extensions, }, Extensions: extensions, @@ -304,6 +250,7 @@ func (header NodeDBHeader) Serialize() []byte { ret := make([]byte, NODE_DB_HEADER_LEN) binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) + binary.BigEndian.PutUint64(ret[8:16], header.TypeHash) return ret } @@ -411,8 +358,13 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } + node_type, known := ctx.Types[node_db.Header.TypeHash] + if known == false { + return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) + } + // Create the blank node with the ID, and add it to the context - new_node := NewNode(id) + new_node := NewNode(id, node_type.Type) node = &new_node ctx.Nodes[id] = node @@ -476,6 +428,12 @@ func ACLList(list []*Node, resources []string) ACLMap { return reqs } +type NodeType string +func (node NodeType) Hash() uint64 { + hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", string(node)))) + return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)]) +} + type PolicyType string func (policy PolicyType) Hash() uint64 { hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy)))) diff --git a/policy.go b/policy.go index 8360e79..54aa110 100644 --- a/policy.go +++ b/policy.go @@ -32,10 +32,58 @@ func (policy AllNodesPolicy) Serialize() ([]byte, error) { } // Extension to allow a node to hold ACL policies -type ACLPolicyExtension struct { +type ACLPolicyExt struct { Policies map[PolicyType]Policy } +// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies +type ACLExt struct { + Delegations NodeMap +} + +func (ext ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + return nil +} + +func LoadACLExt(ctx *Context, data []byte) (Extension, error) { + var j struct { + Delegations []string `json:"delegation"` + } + + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + delegations, err := RestoreNodeList(ctx, j.Delegations) + if err != nil { + return nil, err + } + + return ACLExt{ + Delegations: delegations, + }, nil +} + +func (ext ACLExt) Serialize() ([]byte, error) { + delegations := make([]string, len(ext.Delegations)) + i := 0 + for id, _ := range(ext.Delegations) { + delegations[i] = id.String() + i += 1 + } + + return json.MarshalIndent(&struct{ + Delegations []string `json:"delegations"` + }{ + Delegations: delegations, + }, "", " ") +} + +const ACLExtType = ExtType("ACL") +func (extension ACLExt) Type() ExtType { + return ACLExtType +} type PolicyLoadFunc func(*Context, []byte) (Policy, error) type PolicyInfo struct { @@ -43,11 +91,15 @@ type PolicyInfo struct { Type PolicyType } -type ACLPolicyExtensionContext struct { +type ACLPolicyExtContext struct { Types map[PolicyType]PolicyInfo } -func (ext ACLPolicyExtension) Serialize() ([]byte, error) { +func NewACLPolicyExtContext() *ACLPolicyExtContext { + return nil +} + +func (ext ACLPolicyExt) Serialize() ([]byte, error) { policies := map[string][]byte{} for name, policy := range(ext.Policies) { ser, err := policy.Serialize() @@ -64,11 +116,11 @@ func (ext ACLPolicyExtension) Serialize() ([]byte, error) { }, "", " ") } -func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error { return nil } -func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) { +func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { var j struct { Policies map[string][]byte `json:"policies"` } @@ -78,7 +130,7 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) { } policies := map[PolicyType]Policy{} - acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext) + acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtContext) for name, ser := range(j.Policies) { policy_def, exists := acl_ctx.Types[PolicyType(name)] if exists == false { @@ -92,18 +144,18 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) { policies[PolicyType(name)] = policy } - return ACLPolicyExtension{ + return ACLPolicyExt{ Policies: policies, }, nil } const ACLPolicyExtType = ExtType("ACL_POLICIES") -func (ext ACLPolicyExtension) Type() ExtType { +func (ext ACLPolicyExt) Type() ExtType { return ACLPolicyExtType } // Check if the extension allows the principal to perform action on node -func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool { +func (ext ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool { for _, policy := range(ext.Policies) { if policy.Allows(context, principal, action, node) == true { return true diff --git a/thread.go b/thread.go index 9a7d672..67b00df 100644 --- a/thread.go +++ b/thread.go @@ -8,6 +8,19 @@ import ( "encoding/json" ) +type QueuedAction struct { + Timeout time.Time `json:"time"` + Action string `json:"action"` +} + +type ThreadExtContext struct { + Loads map[InfoType]func([]byte)ThreadInfo +} + +func NewThreadExtContext() *ThreadExtContext { + return nil +} + type ThreadExt struct { Actions ThreadActions Handlers ThreadHandlers @@ -19,7 +32,7 @@ type ThreadExt struct { ActiveLock sync.Mutex Active bool - StateName string + State string Parent *Node Children map[NodeID]ChildInfo @@ -28,15 +41,94 @@ type ThreadExt struct { NextAction *QueuedAction } +type ThreadExtJSON struct { + State string `json:"state"` + Parent string `json:"parent"` + Children map[string][]byte `json:"children"` + ActionQueue []QueuedAction +} + func (ext *ThreadExt) Serialize() ([]byte, error) { return nil, fmt.Errorf("NOT_IMPLEMENTED") } +const THREAD_BUFFER_SIZE int = 1024 +func LoadThreadExt(ctx *Context, data []byte) (Extension, error) { + var j ThreadExtJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + parent, err := RestoreNode(ctx, j.Parent) + if err != nil { + return nil, err + } + + children := map[NodeID]ChildInfo{} + for id_str, _ := range(j.Children) { + child_node, err := RestoreNode(ctx, id_str) + if err != nil { + return nil, err + } + //TODO: Restore child info based off context + + children[child_node.ID] = ChildInfo{ + Child: child_node, + Infos: map[InfoType]ThreadInfo{}, + } + } + + next_action, timeout_chan := SoonestAction(j.ActionQueue) + + extension := ThreadExt{ + Actions: BaseThreadActions, + Handlers: BaseThreadHandlers, + SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE), + TimeoutChan: timeout_chan, + Active: false, + State: j.State, + Parent: parent, + Children: children, + ActionQueue: j.ActionQueue, + NextAction: next_action, + } + + return &extension, nil +} + const ThreadExtType = ExtType("THREAD") func (ext *ThreadExt) Type() ExtType { return ThreadExtType } +func (ext *ThreadExt) QueueAction(end time.Time, action string) { + ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action}) + ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue) +} + +func (ext *ThreadExt) ClearActionQueue() { + ext.ActionQueue = []QueuedAction{} + ext.NextAction = nil + ext.TimeoutChan = nil +} + +func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) { + var soonest_action *QueuedAction + var soonest_time time.Time + for _, action := range(actions) { + if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil { + soonest_action = &action + soonest_time = action.Timeout + } + } + if soonest_action != nil { + return soonest_action, time.After(time.Until(soonest_action.Timeout)) + } else { + return nil, nil + } +} + func (ext *ThreadExt) ChildList() []*Node { ret := make([]*Node, len(ext.Children)) i := 0 @@ -235,38 +327,6 @@ func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo { } } -type QueuedAction struct { - Timeout time.Time `json:"time"` - Action string `json:"action"` -} - -func (ext *ThreadExt) QueueAction(end time.Time, action string) { - ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action}) - ext.NextAction, ext.TimeoutChan = ext.SoonestAction() -} - -func (ext *ThreadExt) ClearActionQueue() { - ext.ActionQueue = []QueuedAction{} - ext.NextAction = nil - ext.TimeoutChan = nil -} - -func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) { - var soonest_action *QueuedAction - var soonest_time time.Time - for _, action := range(ext.ActionQueue) { - if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil { - soonest_action = &action - soonest_time = action.Timeout - } - } - if soonest_action != nil { - return soonest_action, time.After(time.Until(soonest_action.Timeout)) - } else { - return nil, nil - } -} - var deserializers = map[InfoType]func(interface{})(interface{}, error) { "parent": func(raw interface{})(interface{}, error) { m, ok := raw.(map[string]interface{}) @@ -294,14 +354,14 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) { }, } -func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt { +func NewThreadExt(buffer int, name string, state string, actions ThreadActions, handlers ThreadHandlers) ThreadExt { return ThreadExt{ Actions: actions, Handlers: handlers, SignalChan: make(chan GraphSignal, buffer), TimeoutChan: nil, Active: false, - StateName: state_name, + State: state, Parent: nil, Children: map[NodeID]ChildInfo{}, ActionQueue: []QueuedAction{}, @@ -322,7 +382,7 @@ func (ext *ThreadExt) SetActive(active bool) error { } func (ext *ThreadExt) SetState(state string) error { - ext.StateName = state + ext.State = state return nil } @@ -485,7 +545,7 @@ func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start boo } parent_info := info.Infos["parent"].(*ParentThreadInfo) - if parent_info.Start == true && child_ext.StateName != "finished" { + if parent_info.Start == true && child_ext.State != "finished" { ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID) if start == true { ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction) @@ -537,7 +597,7 @@ func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, err context := NewWriteContext(ctx) err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error { timeout_action = thread_ext.NextAction.Action - thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction() + thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue) return nil }) if err != nil { diff --git a/user.go b/user.go index 80b85d1..67ce10a 100644 --- a/user.go +++ b/user.go @@ -8,47 +8,47 @@ import ( "crypto/x509" ) -type GroupNode interface { - Node - Users() map[NodeID]*User -} - -type User struct { - Lockable - +type ECDHExt struct { Granted time.Time Pubkey *ecdsa.PublicKey Shared []byte - Tags []string } -type UserJSON struct { - LockableJSON +type ECDHExtJSON struct { Granted time.Time `json:"granted"` Pubkey []byte `json:"pubkey"` Shared []byte `json:"shared"` } -func (user *User) Type() NodeType { - return NodeType("user") +func (ext *ECDHExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + return nil } -func (user *User) Serialize() ([]byte, error) { - lockable_json := NewLockableJSON(&user.Lockable) - pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey) +const ECDHExtType = ExtType("ECDH") +func (ext *ECDHExt) Type() ExtType { + return ECDHExtType +} + +func (ext *ECDHExt) Serialize() ([]byte, error) { + pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey) if err != nil { return nil, err } - return json.MarshalIndent(&UserJSON{ - LockableJSON: lockable_json, - Granted: user.Granted, - Shared: user.Shared, + return json.MarshalIndent(&ECDHExtJSON{ + Granted: ext.Granted, Pubkey: pubkey, + Shared: ext.Shared, }, "", " ") } -var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) { +func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { + var j ECDHExtJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + pub, err := x509.ParsePKIXPublicKey(j.Pubkey) if err != nil { return nil, err @@ -59,83 +59,56 @@ var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) { case *ecdsa.PublicKey: pubkey = pub.(*ecdsa.PublicKey) default: - return nil, fmt.Errorf("Invalid key type") + return nil, fmt.Errorf("Invalid key type: %+v", pub) } - user := NewUser(j.Name, j.Granted, pubkey, j.Shared) - return &user, nil -}, func(ctx *Context, user *User, j UserJSON, nodes NodeMap) error { - return RestoreLockable(ctx, user, j.LockableJSON, nodes) -}) - -func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte) User { - id := KeyID(pubkey) - return User{ - Lockable: NewLockable(id, name), - Granted: granted, + extension := ECDHExt{ + Granted: j.Granted, Pubkey: pubkey, - Shared: shared, + Shared: j.Shared, } -} - -type Group struct { - Lockable - UserMap map[NodeID]*User + return &extension, nil } -func NewGroup(id NodeID, name string) Group { - return Group{ - Lockable: NewLockable(id, name), - UserMap: map[NodeID]*User{}, - } +type GroupExt struct { + Members NodeMap } -type GroupJSON struct { - LockableJSON - Users []string `json:"users"` +const GroupExtType = ExtType("GROUP") +func (ext *GroupExt) Type() ExtType { + return GroupExtType } -func (group *Group) Type() NodeType { - return NodeType("group") +func (ext *GroupExt) Serialize() ([]byte, error) { + return json.MarshalIndent(&struct{ + Members []string `json:"members"` + }{ + Members: SaveNodeList(ext.Members), + }, "", " ") } -func (group *Group) Serialize() ([]byte, error) { - users := make([]string, len(group.UserMap)) - i := 0 - for id, _ := range(group.UserMap) { - users[i] = id.String() - i += 1 +func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { + var j struct { + Members []string `json:"members"` } - return json.MarshalIndent(&GroupJSON{ - LockableJSON: NewLockableJSON(&group.Lockable), - Users: users, - }, "", " ") -} + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } -var LoadGroup = LoadJSONNode(func(id NodeID, j GroupJSON) (Node, error) { - group := NewGroup(id, j.Name) - return &group, nil -}, func(ctx *Context, group *Group, j GroupJSON, nodes NodeMap) error { - for _, id_str := range(j.Users) { - id, err := ParseID(id_str) - if err != nil { - return err - } - - user_node, err := LoadNodeRecurse(ctx, id, nodes) - if err != nil { - return err - } - - user, ok := user_node.(*User) - if ok == false { - return fmt.Errorf("%s is not a *User", id_str) - } - - group.UserMap[id] = user + members, err := RestoreNodeList(ctx, j.Members) + if err != nil { + return nil, err + } + + extension := GroupExt{ + Members: members, } + return &extension, nil +} - return RestoreLockable(ctx, group, j.LockableJSON, nodes) -}) +func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error { + return nil +}