Added gql to the rework

graph-rework-2
noah metz 2023-07-26 00:18:11 -06:00
parent ff813d6c2b
commit f1c0f1e7de
10 changed files with 695 additions and 572 deletions

@ -5,13 +5,20 @@ import (
"fmt" "fmt"
) )
//Function to load an extension from bytes
type ExtensionLoadFunc func(*Context, []byte) (Extension, error) type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
// Information about a loaded extension
type ExtensionInfo struct { type ExtensionInfo struct {
Load ExtensionLoadFunc Load ExtensionLoadFunc
Type ExtType Type ExtType
Data interface{} Data interface{}
} }
// Information about a loaded node type
type NodeInfo struct {
Type NodeType
}
// A Context is all the data needed to run a graphvent // A Context is all the data needed to run a graphvent
type Context struct { type Context struct {
// DB is the database connection used to load and write nodes // DB is the database connection used to load and write nodes
@ -20,6 +27,8 @@ type Context struct {
Log Logger Log Logger
// A mapping between type hashes and their corresponding extension definitions // A mapping between type hashes and their corresponding extension definitions
Extensions map[uint64]ExtensionInfo Extensions map[uint64]ExtensionInfo
// A mapping between type hashes and their corresponding node definitions
Types map[uint64]NodeInfo
// All loaded Nodes // All loaded Nodes
Nodes map[NodeID]*Node Nodes map[NodeID]*Node
} }
@ -30,8 +39,21 @@ func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
return ext return ext
} }
func (ctx *Context) RegisterNodeType(node_type NodeType) error {
type_hash := node_type.Hash()
_, exists := ctx.Types[type_hash]
if exists == true {
return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type)
}
ctx.Types[type_hash] = NodeInfo{
Type: node_type,
}
return nil
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context // Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error { func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error {
if load_fn == nil { if load_fn == nil {
return fmt.Errorf("def has no load function") return fmt.Errorf("def has no load function")
} }
@ -45,6 +67,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
ctx.Extensions[type_hash] = ExtensionInfo{ ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn, Load: load_fn,
Type: ext_type, Type: ext_type,
Data: data,
} }
return nil return nil
} }
@ -55,15 +78,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
DB: db, DB: db,
Log: log, Log: log,
Extensions: map[uint64]ExtensionInfo{}, Extensions: map[uint64]ExtensionInfo{},
Types: map[uint64]NodeInfo{},
Nodes: map[NodeID]*Node{}, Nodes: map[NodeID]*Node{},
} }
err := ctx.RegisterExtension(ACLExtType, LoadACLExtension) err := ctx.RegisterExtension(ACLExtType, LoadACLExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExt, NewACLPolicyExtContext())
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, NewThreadExtContext())
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension) err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, NewGQLExtContext())
if err != nil { if err != nil {
return nil, err return nil, err
} }

313
gql.go

@ -70,7 +70,7 @@ type AuthRespJSON struct {
Signature []byte `json:"signature"` Signature []byte `json:"signature"`
} }
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) { func NewAuthRespJSON(gql_ext *GQLExt, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
// Check if req.Time is within +- 1 second of now // Check if req.Time is within +- 1 second of now
now := time.Now() now := time.Now()
earliest := now.Add(-1 * time.Second) earliest := now.Add(-1 * time.Second)
@ -82,12 +82,12 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
} }
x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey) x, y := elliptic.Unmarshal(gql_ext.Key.Curve, req.Pubkey)
if x == nil { if x == nil {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
} }
remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey) remote, err := gql_ext.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil { if err != nil {
return AuthRespJSON{}, nil, nil, err return AuthRespJSON{}, nil, nil, err
} }
@ -98,7 +98,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
sig_hash := sha512.Sum512(sig_data) sig_hash := sha512.Sum512(sig_data)
remote_key := &ecdsa.PublicKey{ remote_key := &ecdsa.PublicKey{
Curve: thread.Key.Curve, Curve: gql_ext.Key.Curve,
X: x, X: x,
Y: y, Y: y,
} }
@ -113,7 +113,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req) return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
} }
ec_key, err := thread.ECDH.GenerateKey(rand.Reader) ec_key, err := gql_ext.ECDH.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return AuthRespJSON{}, nil, nil, err return AuthRespJSON{}, nil, nil, err
} }
@ -125,7 +125,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
resp_sig_data := append(ec_key_pub, time_ser...) resp_sig_data := append(ec_key_pub, time_ser...)
resp_sig_hash := sha512.Sum512(resp_sig_data) resp_sig_hash := sha512.Sum512(resp_sig_data)
resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:]) resp_sig, err := ecdsa.SignASN1(rand.Reader, gql_ext.Key, resp_sig_hash[:])
if err != nil { if err != nil {
return AuthRespJSON{}, nil, nil, err return AuthRespJSON{}, nil, nil, err
} }
@ -156,7 +156,7 @@ func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve
return shared_secret, nil return shared_secret, nil
} }
func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) { func AuthHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr) ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr)
enableCORS(&w) enableCORS(&w)
@ -174,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return return
} }
resp, remote_id, shared, err := NewAuthRespJSON(server, req) resp, _, _, err := NewAuthRespJSON(gql_ext, req)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err)
return return
@ -195,34 +195,31 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return return
} }
key_id := KeyID(remote_id) /*if exists {
_, exists := server.UserMap[key_id]
if exists {
ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s", key_id) ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s", key_id)
} else { } else {
ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared) ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared)
new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared) new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, server, NewLockMap(LockMap{ err := UpdateStates(context, server, ACLMap{
server.ID(): LockInfo{ server.ID: ACLInfo{
Node: server, Node: server,
Resources: []string{"users"}, Resources: []string{"users"},
}, },
new_user.ID(): LockInfo{ new_user.ID: ACLInfo{
Node: &new_user, Node: &new_user,
Resources: nil, Resources: nil,
}, },
}), func(context *StateContext) error { }, func(context *StateContext) error {
server.UserMap[key_id] = &new_user server.Users[key_id] = &new_user
return nil return nil
}) })
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_UPDATE_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_UPDATE_ERR: %s", err)
return return
} }
} }*/
} }
} }
@ -363,11 +360,13 @@ func checkForAuthHeader(header http.Header) (string, bool) {
type ResolveContext struct { type ResolveContext struct {
Context *Context Context *Context
Server *GQLThread GQLContext *GQLExtContext
User *User Server *Node
Ext *GQLExt
User *Node
} }
func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if ok == false { if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
@ -378,25 +377,29 @@ func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*Resol
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
} }
user, exists := server.UserMap[auth_id] user, exists := gql_ext.Users[auth_id]
if exists == false { if exists == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
} }
if base64.StdEncoding.EncodeToString(user.Shared) != password { user_ext, err := GetExt[*ECDHExt](user)
if err != nil {
return nil, err
}
if base64.StdEncoding.EncodeToString(user_ext.Shared) != password {
return nil, fmt.Errorf("GQL_AUTH_FAIL") return nil, fmt.Errorf("GQL_AUTH_FAIL")
} }
return &ResolveContext{ return &ResolveContext{
Context: ctx, Context: ctx,
GQLContext: ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext),
Server: server, Server: server,
User: user, User: user,
}, nil }, nil
} }
func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
enableCORS(&w) enableCORS(&w)
@ -406,7 +409,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, r) resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err)))
@ -414,7 +417,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
} }
req_ctx := context.Background() req_ctx := context.Background()
req_ctx = context.WithValue(gql_ctx, "resolve", resolve_context) req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
str, err := io.ReadAll(r.Body) str, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -425,8 +428,10 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
query := GQLPayload{} query := GQLPayload{}
json.Unmarshal(str, &query) json.Unmarshal(str, &query)
gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext)
params := graphql.Params{ params := graphql.Params{
Schema: ctx.GQL.Schema, Schema: gql_context.Schema,
Context: req_ctx, Context: req_ctx,
RequestString: query.Query, RequestString: query.Query,
} }
@ -494,11 +499,7 @@ func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result {
return sendOneResultAndClose(res) return sendOneResultAndClose(res)
} }
func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
enableCORS(&w) enableCORS(&w)
@ -508,7 +509,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, r) resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return return
@ -557,8 +558,9 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
} }
} else if msg.Type == "subscribe" { } else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext)
params := graphql.Params{ params := graphql.Params{
Schema: ctx.GQL.Schema, Schema: gql_context.Schema,
Context: req_ctx, Context: req_ctx,
RequestString: msg.Payload.Query, RequestString: msg.Payload.Query,
} }
@ -628,35 +630,94 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
} }
} }
type GQLThread struct { // Map of go types to graphql types
Thread type ObjTypeMap map[reflect.Type]*graphql.Object
// GQL Specific Context information
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// List of GQL types
TypeList []graphql.Type
// Interface type maps to map go types of specific interfaces to gql types
ValidNodes ObjTypeMap
ValidLockables ObjTypeMap
ValidThreads ObjTypeMap
BaseNodeType *graphql.Object
BaseLockableType *graphql.Object
BaseThreadType *graphql.Object
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
}
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
context := GQLExtContext{
Schema: graphql.Schema{},
TypeList: []graphql.Type{},
ValidNodes: ObjTypeMap{},
ValidThreads: ObjTypeMap{},
ValidLockables: ObjTypeMap{},
Query: query,
Mutation: mutation,
Subscription: subscription,
BaseNodeType: GQLTypeBaseNode.Type,
BaseLockableType: GQLTypeBaseLockable.Type,
BaseThreadType: GQLTypeBaseThread.Type,
}
return &context
}
type GQLExt struct {
tcp_listener net.Listener tcp_listener net.Listener
http_server *http.Server http_server *http.Server
http_done *sync.WaitGroup http_done sync.WaitGroup
tls_key []byte tls_key []byte
tls_cert []byte tls_cert []byte
Listen string Listen string
UserMap map[NodeID]*User Users NodeMap
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
ECDH ecdh.Curve ECDH ecdh.Curve
SubscribeLock sync.Mutex SubscribeLock sync.Mutex
SubscribeListeners []chan GraphSignal SubscribeListeners []chan GraphSignal
} }
func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal { func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan GraphSignal {
thread.SubscribeLock.Lock() ext.SubscribeLock.Lock()
defer thread.SubscribeLock.Unlock() defer ext.SubscribeLock.Unlock()
new_listener := make(chan GraphSignal, buffer) new_listener := make(chan GraphSignal, buffer)
thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener) ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener)
return new_listener return new_listener
} }
func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error { func (ext *GQLExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock()
active_listeners := []chan GraphSignal{} active_listeners := []chan GraphSignal{}
thread.SubscribeLock.Lock() for _, listener := range(ext.SubscribeListeners) {
for _, listener := range(thread.SubscribeListeners) {
select { select {
case listener <- signal: case listener <- signal:
active_listeners = append(active_listeners, listener) active_listeners = append(active_listeners, listener)
@ -667,34 +728,38 @@ func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) erro
}(listener) }(listener)
} }
} }
thread.SubscribeListeners = active_listeners ext.SubscribeListeners = active_listeners
thread.SubscribeLock.Unlock() return nil
return thread.Thread.Process(context, signal)
}
func (thread * GQLThread) Type() NodeType {
return NodeType("gql_thread")
}
func (thread * GQLThread) Serialize() ([]byte, error) {
thread_json := NewGQLThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ")
} }
func (thread * GQLThread) Users() map[NodeID]*User { const GQLExtType = ExtType("gql_thread")
return thread.UserMap func (ext *GQLExt) Type() ExtType {
return GQLExtType
} }
type GQLThreadJSON struct { type GQLExtJSON struct {
ThreadJSON
Listen string `json:"listen"` Listen string `json:"listen"`
Users []string `json:"users"`
Key []byte `json:"key"` Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"` ECDH uint8 `json:"ecdh_curve"`
TLSKey []byte `json:"ssl_key"` TLSKey []byte `json:"ssl_key"`
TLSCert []byte `json:"ssl_cert"` TLSCert []byte `json:"ssl_cert"`
} }
func (ext *GQLExt) Serialize() ([]byte, error) {
ser_key, err := x509.MarshalECPrivateKey(ext.Key)
if err != nil {
return nil, err
}
return json.MarshalIndent(&GQLExtJSON{
Listen: ext.Listen,
Key: ser_key,
ECDH: ecdh_curve_ids[ext.ECDH],
TLSKey: ext.tls_key,
TLSCert: ext.tls_cert,
}, "", " ")
}
var ecdsa_curves = map[uint8]elliptic.Curve{ var ecdsa_curves = map[uint8]elliptic.Curve{
0: elliptic.P256(), 0: elliptic.P256(),
} }
@ -711,33 +776,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0, ecdh.P256(): 0,
} }
func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
thread_json := NewThreadJSON(&thread.Thread) var j GQLExtJSON
err := json.Unmarshal(data, &j)
ser_key, err := x509.MarshalECPrivateKey(thread.Key)
if err != nil { if err != nil {
panic(err) return nil, err
}
users := make([]string, len(thread.UserMap))
i := 0
for id, _ := range(thread.UserMap) {
users[i] = id.String()
i += 1
}
return GQLThreadJSON{
ThreadJSON: thread_json,
Listen: thread.Listen,
Users: users,
Key: ser_key,
ECDH: ecdh_curve_ids[thread.ECDH],
TLSKey: thread.tls_key,
TLSCert: thread.tls_cert,
}
} }
var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) {
ecdh_curve, ok := ecdh_curves[j.ECDH] ecdh_curve, ok := ecdh_curves[j.ECDH]
if ok == false { if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
@ -748,27 +793,19 @@ var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error)
return nil, err return nil, err
} }
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) extension := GQLExt{
return &thread, nil Listen: j.Listen,
}, func(ctx *Context, thread *GQLThread, j GQLThreadJSON, nodes NodeMap) error { Key: key,
thread.UserMap = map[NodeID]*User{} ECDH: ecdh_curve,
for _, id_str := range(j.Users) { SubscribeListeners: []chan GraphSignal{},
ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str) tls_key: j.TLSKey,
user_id, err := ParseID(id_str) tls_cert: j.TLSCert,
if err != nil {
return err
}
user, err := LoadNodeRecurse(ctx, user_id, nodes)
if err != nil {
return err
}
thread.UserMap[user_id] = user.(*User)
} }
return RestoreThread(ctx, thread, j.ThreadJSON, nodes) return &extension, nil
}) }
func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread { func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLExt {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader)
if err != nil { if err != nil {
@ -808,12 +845,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
tls_cert = ssl_cert_pem tls_cert = ssl_cert_pem
tls_key = ssl_key_pem tls_key = ssl_key_pem
} }
return GQLThread{ return GQLExt{
Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers),
Listen: listen, Listen: listen,
SubscribeListeners: []chan GraphSignal{}, SubscribeListeners: []chan GraphSignal{},
UserMap: map[NodeID]*User{},
http_done: &sync.WaitGroup{},
Key: key, Key: key,
ECDH: ecdh_curve, ECDH: ecdh_curve,
tls_cert: tls_cert, tls_cert: tls_cert,
@ -823,23 +857,26 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
var gql_actions ThreadActions = ThreadActions{ var gql_actions ThreadActions = ThreadActions{
"wait": ThreadWait, "wait": ThreadWait,
"restore": func(ctx *Context, node ThreadNode) (string, error) { "restore": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
return "start_server", ThreadRestore(ctx, node, false) return "start_server", ThreadRestore(ctx, thread, thread_ext, false)
}, },
"start": func(ctx * Context, node ThreadNode) (string, error) { "start": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
_, err := ThreadStart(ctx, node) _, err := ThreadStart(ctx, thread, thread_ext)
if err != nil { if err != nil {
return "", err return "", err
} }
return "start_server", ThreadRestore(ctx, node, true) return "start_server", ThreadRestore(ctx, thread, thread_ext, true)
}, },
"start_server": func(ctx * Context, node ThreadNode) (string, error) { "start_server": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_thread := node.(*GQLThread) gql_ext, err := GetExt[*GQLExt](thread)
if err != nil {
return "", err
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread)) mux.HandleFunc("/auth", AuthHandler(ctx, thread, gql_ext))
mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread)) mux.HandleFunc("/gql", GQLHandler(ctx, thread, gql_ext))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread)) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, thread, gql_ext))
// Server a graphiql interface(TODO make configurable whether to start this) // Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler()) mux.HandleFunc("/graphiql", GraphiQLHandler())
@ -849,7 +886,7 @@ var gql_actions ThreadActions = ThreadActions{
mux.Handle("/site/", http.StripPrefix("/site", fs)) mux.Handle("/site/", http.StripPrefix("/site", fs))
http_server := &http.Server{ http_server := &http.Server{
Addr: gql_thread.Listen, Addr: gql_ext.Listen,
Handler: mux, Handler: mux,
} }
@ -858,7 +895,7 @@ var gql_actions ThreadActions = ThreadActions{
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
} }
cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key) cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -870,23 +907,21 @@ var gql_actions ThreadActions = ThreadActions{
listener := tls.NewListener(l, &config) listener := tls.NewListener(l, &config)
gql_thread.http_done.Add(1) gql_ext.http_done.Add(1)
go func(gql_thread *GQLThread) { go func(qql_ext *GQLExt) {
defer gql_thread.http_done.Done() defer gql_ext.http_done.Done()
err := http_server.Serve(listener) err := http_server.Serve(listener)
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err)) panic(fmt.Sprintf("Failed to start gql server: %s", err))
} }
}(gql_thread) }(gql_ext)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, node, NewLockMap( err = UpdateStates(context, thread, NewACLInfo(thread, []string{"http_server"}), func(context *StateContext) error {
NewLockInfo(node, []string{"http_server"}), gql_ext.tcp_listener = listener
), func(context *StateContext) error { gql_ext.http_server = http_server
gql_thread.tcp_listener = listener
gql_thread.http_server = http_server
return nil return nil
}) })
@ -895,18 +930,22 @@ var gql_actions ThreadActions = ThreadActions{
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID())) err = Signal(context, thread, thread, NewStatusSignal("server_started", thread.ID))
if err != nil { if err != nil {
return "", err return "", err
} }
return "wait", nil return "wait", nil
}, },
"finish": func(ctx *Context, node ThreadNode) (string, error) { "finish": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_thread := node.(*GQLThread) gql_ext, err := GetExt[*GQLExt](thread)
gql_thread.http_server.Shutdown(context.TODO()) if err != nil {
gql_thread.http_done.Wait() return "", err
return ThreadFinish(ctx, node) }
gql_ext.http_server.Shutdown(context.TODO())
gql_ext.http_done.Wait()
return ThreadFinish(ctx, thread, thread_ext)
}, },
} }

@ -3,6 +3,7 @@ package graphvent
import ( import (
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"reflect" "reflect"
"fmt"
) )
func NewField(init func()*graphql.Field) *graphql.Field { func NewField(init func()*graphql.Field) *graphql.Field {
@ -26,22 +27,34 @@ func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List
} }
} }
func addNodeInterfaceFields(i *graphql.Interface) { func AddNodeInterfaceFields(i *graphql.Interface) {
i.AddFieldConfig("ID", &graphql.Field{ i.AddFieldConfig("ID", &graphql.Field{
Type: graphql.String, Type: graphql.String,
}) })
i.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
})
}
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil, fmt.Errorf("Bad resolve in params context")
}
return resolve_context, nil
} }
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface { var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
i := graphql.NewInterface(graphql.InterfaceConfig{ i := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Node", Name: "Node",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context) ctx, err := PrepTypeResolve(p)
if ok == false { if err != nil {
return nil return nil
} }
valid_nodes := ctx.GQL.ValidNodes valid_nodes := ctx.GQLContext.ValidNodes
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_nodes) { for key, value := range(valid_nodes) {
@ -50,9 +63,9 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
} }
} }
_, ok = p.Value.(Node) _, ok := p.Value.(Node)
if ok == true { if ok == true {
return ctx.GQL.BaseNodeType return ctx.GQLContext.BaseNodeType
} }
return nil return nil
@ -60,41 +73,21 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
Fields: graphql.Fields{}, Fields: graphql.Fields{},
}) })
addNodeInterfaceFields(i) AddNodeInterfaceFields(i)
return i return i
}, nil) }, nil)
func addLockableInterfaceFields(i *graphql.Interface, lockable *graphql.Interface, list *graphql.List) {
addNodeInterfaceFields(i)
i.AddFieldConfig("Name", &graphql.Field{
Type: graphql.String,
})
i.AddFieldConfig("Requirements", &graphql.Field{
Type: list,
})
i.AddFieldConfig("Dependencies", &graphql.Field{
Type: list,
})
i.AddFieldConfig("Owner", &graphql.Field{
Type: lockable,
})
}
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface { var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{ gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Lockable", Name: "Lockable",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context) ctx, err := PrepTypeResolve(p)
if ok == false { if err != nil {
return nil return nil
} }
valid_lockables := ctx.GQL.ValidLockables valid_lockables := ctx.GQLContext.ValidLockables
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_lockables) { for key, value := range(valid_lockables) {
@ -103,9 +96,9 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
} }
} }
_, ok = p.Value.(Lockable) _, ok := p.Value.(*Node)
if ok == true { if ok == false {
return ctx.GQL.BaseLockableType return ctx.GQLContext.BaseLockableType
} }
return nil return nil
}, },
@ -114,31 +107,30 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
return gql_interface_lockable return gql_interface_lockable
}, func(lockable *graphql.Interface, lockable_list *graphql.List) { }, func(lockable *graphql.Interface, lockable_list *graphql.List) {
addLockableInterfaceFields(lockable, lockable, lockable_list) lockable.AddFieldConfig("Requirements", &graphql.Field{
Type: lockable_list,
}) })
func addThreadInterfaceFields(i *graphql.Interface, thread *graphql.Interface, list *graphql.List) { lockable.AddFieldConfig("Dependencies", &graphql.Field{
addLockableInterfaceFields(i, GQLInterfaceLockable.Type, GQLInterfaceLockable.List) Type: lockable_list,
i.AddFieldConfig("Children", &graphql.Field{
Type: list,
}) })
i.AddFieldConfig("Parent", &graphql.Field{ lockable.AddFieldConfig("Owner", &graphql.Field{
Type: thread, Type: lockable,
})
AddNodeInterfaceFields(lockable)
}) })
}
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface { var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{ gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Thread", Name: "Thread",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context) ctx, err := PrepTypeResolve(p)
if ok == false { if err != nil {
return nil return nil
} }
valid_threads := ctx.GQL.ValidThreads valid_threads := ctx.GQLContext.ValidThreads
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_threads) { for key, value := range(valid_threads) {
@ -147,9 +139,14 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
} }
} }
_, ok = p.Value.(Thread) node, ok := p.Value.(*Node)
if ok == true { if ok == false {
return ctx.GQL.BaseThreadType return nil
}
_, err = GetExt[*ThreadExt](node)
if err == nil {
return ctx.GQLContext.BaseThreadType
} }
return nil return nil
@ -159,5 +156,17 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
return gql_interface_thread return gql_interface_thread
}, func(thread *graphql.Interface, thread_list *graphql.List) { }, func(thread *graphql.Interface, thread_list *graphql.List) {
addThreadInterfaceFields(thread, thread, thread_list) thread.AddFieldConfig("Children", &graphql.Field{
Type: thread_list,
})
thread.AddFieldConfig("Parent", &graphql.Field{
Type: thread,
})
thread.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
})
AddNodeInterfaceFields(thread)
}) })

@ -5,12 +5,18 @@ import (
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
) )
func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) { func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false { if ok == false {
return nil, fmt.Errorf("Bad resolve in params context") return nil, nil, fmt.Errorf("Bad resolve in params context")
} }
return resolve_context, nil
node, ok := p.Source.(*Node)
if ok == false {
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
}
return node, resolve_context, nil
} }
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does // TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
@ -65,30 +71,38 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted // TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
node, ok := p.Source.(Node) node, _, err := PrepResolve(p)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Node") return nil, err
} }
return node.ID(), nil return node.ID, nil
} }
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, _, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(*GQLThread) return string(node.Type), nil
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to GQLThread")
} }
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
gql_ext, err := GetExt[*GQLExt](node)
if err != nil {
return nil, err
}
listen := "" listen := ""
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = node.Listen listen = gql_ext.Listen
return nil return nil
}) })
@ -100,20 +114,20 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(*Thread) thread_ext, err := GetExt[*ThreadExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Thread") return nil, err
} }
var parent ThreadNode = nil var parent *Node = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.ThreadHandle().Parent parent = thread_ext.Parent
return nil return nil
}) })
@ -125,20 +139,20 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(ThreadNode) thread_ext, err := GetExt[*ThreadExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Thread") return nil, err
} }
var state string var state string
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.ThreadHandle().StateName state = thread_ext.State
return nil return nil
}) })
@ -150,50 +164,20 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(ThreadNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
}
var children []ThreadNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = make([]ThreadNode, len(node.ThreadHandle().Children))
i := 0
for _, info := range(node.ThreadHandle().Children) {
children[i] = info.Child
i += 1
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return children, nil thread_ext, err := GetExt[*ThreadExt](node)
}
func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(LockableNode) var children []*Node = nil
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
name := ""
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
name = node.LockableHandle().Name children = thread_ext.ChildList()
return nil return nil
}) })
@ -201,26 +185,26 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
return name, nil return children, nil
} }
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(LockableNode) lockable_ext, err := GetExt[*LockableExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, err
} }
var requirements []LockableNode = nil var requirements []*Node = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = make([]LockableNode, len(node.LockableHandle().Requirements)) requirements = make([]*Node, len(lockable_ext.Requirements))
i := 0 i := 0
for _, req := range(node.LockableHandle().Requirements) { for _, req := range(lockable_ext.Requirements) {
requirements[i] = req requirements[i] = req
i += 1 i += 1
} }
@ -235,22 +219,22 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(LockableNode) lockable_ext, err := GetExt[*LockableExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, err
} }
var dependencies []LockableNode = nil var dependencies []*Node = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies)) dependencies = make([]*Node, len(lockable_ext.Dependencies))
i := 0 i := 0
for _, dep := range(node.LockableHandle().Dependencies) { for _, dep := range(lockable_ext.Dependencies) {
dependencies[i] = dep dependencies[i] = dep
i += 1 i += 1
} }
@ -265,20 +249,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(LockableNode) lockable_ext, err := GetExt[*LockableExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to Lockable") return nil, err
} }
var owner Node = nil var owner *Node = nil
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.LockableHandle().Owner owner = lockable_ext.Owner
return nil return nil
}) })
@ -289,24 +273,24 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
return owner, nil return owner, nil
} }
func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) { func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) node, ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, ok := p.Source.(GroupNode) group_ext, err := GetExt[*GroupExt](node)
if ok == false || node == nil { if err != nil {
return nil, fmt.Errorf("Failed to cast source to GQLThread") return nil, err
} }
var users []*User var members []*Node
context := NewReadContext(ctx.Context) context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error { err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
users = make([]*User, len(node.Users())) members = make([]*Node, len(group_ext.Members))
i := 0 i := 0
for _, user := range(node.Users()) { for _, member := range(group_ext.Members) {
users[i] = user members[i] = member
i += 1 i += 1
} }
return nil return nil
@ -316,7 +300,7 @@ func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) {
return nil, err return nil, err
} }
return users, nil return members, nil
} }
func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {

@ -9,16 +9,16 @@ func AddNodeFields(obj *graphql.Object) {
Type: graphql.String, Type: graphql.String,
Resolve: GQLNodeID, Resolve: GQLNodeID,
}) })
obj.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
Resolve: GQLNodeTypeHash,
})
} }
func AddLockableFields(obj *graphql.Object) { func AddLockableFields(obj *graphql.Object) {
AddNodeFields(obj) AddNodeFields(obj)
obj.AddFieldConfig("Name", &graphql.Field{
Type: graphql.String,
Resolve: GQLLockableName,
})
obj.AddFieldConfig("Requirements", &graphql.Field{ obj.AddFieldConfig("Requirements", &graphql.Field{
Type: GQLInterfaceLockable.List, Type: GQLInterfaceLockable.List,
Resolve: GQLLockableRequirements, Resolve: GQLLockableRequirements,
@ -36,7 +36,7 @@ func AddLockableFields(obj *graphql.Object) {
} }
func AddThreadFields(obj *graphql.Object) { func AddThreadFields(obj *graphql.Object) {
AddLockableFields(obj) AddNodeFields(obj)
obj.AddFieldConfig("State", &graphql.Field{ obj.AddFieldConfig("State", &graphql.Field{
Type: graphql.String, Type: graphql.String,
@ -54,56 +54,7 @@ func AddThreadFields(obj *graphql.Object) {
}) })
} }
var GQLTypeUser = NewSingleton(func() *graphql.Object { var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
gql_type_user := graphql.NewObject(graphql.ObjectConfig{
Name: "User",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(*User)
return ok
},
Fields: graphql.Fields{},
})
AddLockableFields(gql_type_user)
return gql_type_user
}, nil)
var GQLTypeGQLThread = NewSingleton(func() *graphql.Object {
gql_type_gql_thread := graphql.NewObject(graphql.ObjectConfig{
Name: "GQLThread",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceThread.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(*GQLThread)
return ok
},
Fields: graphql.Fields{},
})
AddThreadFields(gql_type_gql_thread)
gql_type_gql_thread.AddFieldConfig("Users", &graphql.Field{
Type: GQLTypeUser.List,
Resolve: GQLGroupNodeUsers,
})
gql_type_gql_thread.AddFieldConfig("Listen", &graphql.Field{
Type: graphql.String,
Resolve: GQLThreadListen,
})
return gql_type_gql_thread
}, nil)
var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{ gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleThread", Name: "SimpleThread",
Interfaces: []*graphql.Interface{ Interfaces: []*graphql.Interface{
@ -112,8 +63,13 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
GQLInterfaceLockable.Type, GQLInterfaceLockable.Type,
}, },
IsTypeOf: func(p graphql.IsTypeOfParams) bool { IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Thread) node, ok := p.Value.(*Node)
return ok if ok == false {
return false
}
_, err := GetExt[*ThreadExt](node)
return err == nil
}, },
Fields: graphql.Fields{}, Fields: graphql.Fields{},
}) })
@ -123,7 +79,7 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
return gql_type_simple_thread return gql_type_simple_thread
}, nil) }, nil)
var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object { var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{ gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleLockable", Name: "SimpleLockable",
Interfaces: []*graphql.Interface{ Interfaces: []*graphql.Interface{
@ -131,8 +87,13 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
GQLInterfaceLockable.Type, GQLInterfaceLockable.Type,
}, },
IsTypeOf: func(p graphql.IsTypeOfParams) bool { IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Lockable) node, ok := p.Value.(*Node)
return ok if ok == false {
return false
}
_, err := GetExt[*LockableExt](node)
return err == nil
}, },
Fields: graphql.Fields{}, Fields: graphql.Fields{},
}) })
@ -142,14 +103,14 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
return gql_type_simple_lockable return gql_type_simple_lockable
}, nil) }, nil)
var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object { var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
object := graphql.NewObject(graphql.ObjectConfig{ object := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleNode", Name: "SimpleNode",
Interfaces: []*graphql.Interface{ Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type, GQLInterfaceNode.Type,
}, },
IsTypeOf: func(p graphql.IsTypeOfParams) bool { IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Node) _, ok := p.Value.(*Node)
return ok return ok
}, },
Fields: graphql.Fields{}, Fields: graphql.Fields{},

@ -45,46 +45,57 @@ func (ext *LockableExt) Type() ExtType {
return LockableExtType return LockableExtType
} }
type LockableExtJSON struct {
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}
func (ext *LockableExt) Serialize() ([]byte, error) { func (ext *LockableExt) Serialize() ([]byte, error) {
requirements := make([]string, len(ext.Requirements)) return json.MarshalIndent(&LockableExtJSON{
req_n := 0 Owner: SaveNode(ext.Owner),
for id, _ := range(ext.Requirements) { Requirements: SaveNodeList(ext.Requirements),
requirements[req_n] = id.String() Dependencies: SaveNodeList(ext.Dependencies),
req_n++ LocksHeld: SaveNodeMap(ext.LocksHeld),
}, "", " ")
} }
dependencies := make([]string, len(ext.Dependencies)) func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
dep_n := 0 var j LockableExtJSON
for id, _ := range(ext.Dependencies) { err := json.Unmarshal(data, &j)
dependencies[dep_n] = id.String() if err != nil {
dep_n++ return nil, err
} }
owner := "" owner, err := RestoreNode(ctx, j.Owner)
if ext.Owner != nil { if err != nil {
owner = ext.Owner.ID.String() return nil, err
} }
locks_held := map[string]string{} requirements, err := RestoreNodeList(ctx, j.Requirements)
for lockable_id, node := range(ext.LocksHeld) { if err != nil {
if node == nil { return nil, err
locks_held[lockable_id.String()] = ""
} else {
locks_held[lockable_id.String()] = node.ID.String()
} }
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
if err != nil {
return nil, err
} }
return json.MarshalIndent(&struct{ locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
Owner string `json:"owner"` if err != nil {
Requirements []string `json:"requirements"` return nil, err
Dependencies []string `json:"dependencies"` }
LocksHeld map[string]string `json:"locks_held"`
}{ extension := LockableExt{
Owner: owner, Owner: owner,
Requirements: requirements, Requirements: requirements,
Dependencies: dependencies, Dependencies: dependencies,
LocksHeld: locks_held, LocksHeld: locks_held,
}, "", " ") }
return &extension, nil
} }
func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
@ -469,6 +480,14 @@ func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node)
}) })
} }
func SaveNode(node *Node) string {
str := ""
if node != nil {
str = node.ID.String()
}
return str
}
func RestoreNode(ctx *Context, id_str string) (*Node, error) { func RestoreNode(ctx *Context, id_str string) (*Node, error) {
id, err := ParseID(id_str) id, err := ParseID(id_str)
if err != nil { if err != nil {
@ -478,6 +497,14 @@ func RestoreNode(ctx *Context, id_str string) (*Node, error) {
return LoadNode(ctx, id) return LoadNode(ctx, id)
} }
func SaveNodeMap(nodes NodeMap) map[string]string {
m := map[string]string{}
for id, node := range(nodes) {
m[id.String()] = SaveNode(node)
}
return m
}
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) { func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
nodes := NodeMap{} nodes := NodeMap{}
for id_str_1, id_str_2 := range(ids) { for id_str_1, id_str_2 := range(ids) {
@ -507,6 +534,17 @@ func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
return nodes, nil return nodes, nil
} }
func SaveNodeList(nodes NodeMap) []string {
ids := make([]string, len(nodes))
i := 0
for id, _ := range(nodes) {
ids[i] = id.String()
i += 1
}
return ids
}
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) { func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
nodes := NodeMap{} nodes := NodeMap{}

@ -73,6 +73,7 @@ type Extension interface {
// Nodes represent an addressible group of extensions // Nodes represent an addressible group of extensions
type Node struct { type Node struct {
ID NodeID ID NodeID
Type NodeType
Lock sync.RWMutex Lock sync.RWMutex
ExtensionMap map[ExtType]Extension ExtensionMap map[ExtType]Extension
} }
@ -93,65 +94,6 @@ func GetExt[T Extension](node *Node) (T, error) {
return ret, nil return ret, nil
} }
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExtension struct {
Delegations NodeMap
}
func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func LoadACLExtension(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
delegations := NodeMap{}
for _, str := range(j.Delegations) {
id, err := ParseID(str)
if err != nil {
return nil, err
}
node, err := LoadNode(ctx, id)
if err != nil {
return nil, err
}
delegations[id] = node
}
return ACLExtension{
Delegations: delegations,
}, nil
}
func (ext ACLExtension) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
const ACLExtType = ExtType("ACL")
func (extension ACLExtension) Type() ExtType {
return ACLExtType
}
func (node *Node) Serialize() ([]byte, error) { func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.ExtensionMap)) extensions := make([]ExtensionDB, len(node.ExtensionMap))
node_db := NodeDB{ node_db := NodeDB{
@ -181,9 +123,10 @@ func (node *Node) Serialize() ([]byte, error) {
return node_db.Serialize(), nil return node_db.Serialize(), nil
} }
func NewNode(id NodeID) Node { func NewNode(id NodeID, node_type NodeType) Node {
return Node{ return Node{
ID: id, ID: id,
Type: node_type,
ExtensionMap: map[ExtType]Extension{}, ExtensionMap: map[ExtType]Extension{},
} }
} }
@ -198,15 +141,15 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node)
if exists == false { if exists == false {
return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID) return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID)
} }
acl_ext := ext.(ACLExtension) acl_ext := ext.(ACLExt)
for _, policy_node := range(acl_ext.Delegations) { for _, policy_node := range(acl_ext.Delegations) {
ext, exists := policy_node.ExtensionMap[ACLPolicyExtType] ext, exists := policy_node.ExtensionMap[ACLPolicyExtType]
if exists == false { if exists == false {
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension") context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt")
continue continue
} }
policy_ext := ext.(ACLPolicyExtension) policy_ext := ext.(ACLPolicyExt)
if policy_ext.Allows(context, principal, action, node) == true { if policy_ext.Allows(context, principal, action, node) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action) context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action)
return nil return nil
@ -238,11 +181,12 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal)
// Magic first four bytes of serialized DB content, stored big endian // Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14 const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
const NODE_DB_HEADER_LEN = 8 const NODE_DB_HEADER_LEN = 16
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct { type NodeDBHeader struct {
Magic uint32 Magic uint32
NumExtensions uint32 NumExtensions uint32
TypeHash uint64
} }
type NodeDB struct { type NodeDB struct {
@ -258,6 +202,7 @@ func NewNodeDB(data []byte) (NodeDB, error) {
magic := binary.BigEndian.Uint32(data[0:4]) magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8]) num_extensions := binary.BigEndian.Uint32(data[4:8])
node_type_hash := binary.BigEndian.Uint64(data[8:16])
ptr += NODE_DB_HEADER_LEN ptr += NODE_DB_HEADER_LEN
@ -290,6 +235,7 @@ func NewNodeDB(data []byte) (NodeDB, error) {
return NodeDB{ return NodeDB{
Header: NodeDBHeader{ Header: NodeDBHeader{
Magic: magic, Magic: magic,
TypeHash: node_type_hash,
NumExtensions: num_extensions, NumExtensions: num_extensions,
}, },
Extensions: extensions, Extensions: extensions,
@ -304,6 +250,7 @@ func (header NodeDBHeader) Serialize() []byte {
ret := make([]byte, NODE_DB_HEADER_LEN) ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash)
return ret return ret
} }
@ -411,8 +358,13 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err return nil, err
} }
node_type, known := ctx.Types[node_db.Header.TypeHash]
if known == false {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
}
// Create the blank node with the ID, and add it to the context // Create the blank node with the ID, and add it to the context
new_node := NewNode(id) new_node := NewNode(id, node_type.Type)
node = &new_node node = &new_node
ctx.Nodes[id] = node ctx.Nodes[id] = node
@ -476,6 +428,12 @@ func ACLList(list []*Node, resources []string) ACLMap {
return reqs return reqs
} }
type NodeType string
func (node NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", string(node))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type PolicyType string type PolicyType string
func (policy PolicyType) Hash() uint64 { func (policy PolicyType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy)))) hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy))))

@ -32,10 +32,58 @@ func (policy AllNodesPolicy) Serialize() ([]byte, error) {
} }
// Extension to allow a node to hold ACL policies // Extension to allow a node to hold ACL policies
type ACLPolicyExtension struct { type ACLPolicyExt struct {
Policies map[PolicyType]Policy Policies map[PolicyType]Policy
} }
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExt struct {
Delegations NodeMap
}
func (ext ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
delegations, err := RestoreNodeList(ctx, j.Delegations)
if err != nil {
return nil, err
}
return ACLExt{
Delegations: delegations,
}, nil
}
func (ext ACLExt) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
const ACLExtType = ExtType("ACL")
func (extension ACLExt) Type() ExtType {
return ACLExtType
}
type PolicyLoadFunc func(*Context, []byte) (Policy, error) type PolicyLoadFunc func(*Context, []byte) (Policy, error)
type PolicyInfo struct { type PolicyInfo struct {
@ -43,11 +91,15 @@ type PolicyInfo struct {
Type PolicyType Type PolicyType
} }
type ACLPolicyExtensionContext struct { type ACLPolicyExtContext struct {
Types map[PolicyType]PolicyInfo Types map[PolicyType]PolicyInfo
} }
func (ext ACLPolicyExtension) Serialize() ([]byte, error) { func NewACLPolicyExtContext() *ACLPolicyExtContext {
return nil
}
func (ext ACLPolicyExt) Serialize() ([]byte, error) {
policies := map[string][]byte{} policies := map[string][]byte{}
for name, policy := range(ext.Policies) { for name, policy := range(ext.Policies) {
ser, err := policy.Serialize() ser, err := policy.Serialize()
@ -64,11 +116,11 @@ func (ext ACLPolicyExtension) Serialize() ([]byte, error) {
}, "", " ") }, "", " ")
} }
func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil return nil
} }
func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) { func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
var j struct { var j struct {
Policies map[string][]byte `json:"policies"` Policies map[string][]byte `json:"policies"`
} }
@ -78,7 +130,7 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
} }
policies := map[PolicyType]Policy{} policies := map[PolicyType]Policy{}
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext) acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtContext)
for name, ser := range(j.Policies) { for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)] policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false { if exists == false {
@ -92,18 +144,18 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
policies[PolicyType(name)] = policy policies[PolicyType(name)] = policy
} }
return ACLPolicyExtension{ return ACLPolicyExt{
Policies: policies, Policies: policies,
}, nil }, nil
} }
const ACLPolicyExtType = ExtType("ACL_POLICIES") const ACLPolicyExtType = ExtType("ACL_POLICIES")
func (ext ACLPolicyExtension) Type() ExtType { func (ext ACLPolicyExt) Type() ExtType {
return ACLPolicyExtType return ACLPolicyExtType
} }
// Check if the extension allows the principal to perform action on node // Check if the extension allows the principal to perform action on node
func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool { func (ext ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for _, policy := range(ext.Policies) { for _, policy := range(ext.Policies) {
if policy.Allows(context, principal, action, node) == true { if policy.Allows(context, principal, action, node) == true {
return true return true

@ -8,6 +8,19 @@ import (
"encoding/json" "encoding/json"
) )
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
type ThreadExtContext struct {
Loads map[InfoType]func([]byte)ThreadInfo
}
func NewThreadExtContext() *ThreadExtContext {
return nil
}
type ThreadExt struct { type ThreadExt struct {
Actions ThreadActions Actions ThreadActions
Handlers ThreadHandlers Handlers ThreadHandlers
@ -19,7 +32,7 @@ type ThreadExt struct {
ActiveLock sync.Mutex ActiveLock sync.Mutex
Active bool Active bool
StateName string State string
Parent *Node Parent *Node
Children map[NodeID]ChildInfo Children map[NodeID]ChildInfo
@ -28,15 +41,94 @@ type ThreadExt struct {
NextAction *QueuedAction NextAction *QueuedAction
} }
type ThreadExtJSON struct {
State string `json:"state"`
Parent string `json:"parent"`
Children map[string][]byte `json:"children"`
ActionQueue []QueuedAction
}
func (ext *ThreadExt) Serialize() ([]byte, error) { func (ext *ThreadExt) Serialize() ([]byte, error) {
return nil, fmt.Errorf("NOT_IMPLEMENTED") return nil, fmt.Errorf("NOT_IMPLEMENTED")
} }
const THREAD_BUFFER_SIZE int = 1024
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
var j ThreadExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
parent, err := RestoreNode(ctx, j.Parent)
if err != nil {
return nil, err
}
children := map[NodeID]ChildInfo{}
for id_str, _ := range(j.Children) {
child_node, err := RestoreNode(ctx, id_str)
if err != nil {
return nil, err
}
//TODO: Restore child info based off context
children[child_node.ID] = ChildInfo{
Child: child_node,
Infos: map[InfoType]ThreadInfo{},
}
}
next_action, timeout_chan := SoonestAction(j.ActionQueue)
extension := ThreadExt{
Actions: BaseThreadActions,
Handlers: BaseThreadHandlers,
SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan,
Active: false,
State: j.State,
Parent: parent,
Children: children,
ActionQueue: j.ActionQueue,
NextAction: next_action,
}
return &extension, nil
}
const ThreadExtType = ExtType("THREAD") const ThreadExtType = ExtType("THREAD")
func (ext *ThreadExt) Type() ExtType { func (ext *ThreadExt) Type() ExtType {
return ThreadExtType return ThreadExtType
} }
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
}
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(actions) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
}
}
if soonest_action != nil {
return soonest_action, time.After(time.Until(soonest_action.Timeout))
} else {
return nil, nil
}
}
func (ext *ThreadExt) ChildList() []*Node { func (ext *ThreadExt) ChildList() []*Node {
ret := make([]*Node, len(ext.Children)) ret := make([]*Node, len(ext.Children))
i := 0 i := 0
@ -235,38 +327,6 @@ func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
} }
} }
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = ext.SoonestAction()
}
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(ext.ActionQueue) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
}
}
if soonest_action != nil {
return soonest_action, time.After(time.Until(soonest_action.Timeout))
} else {
return nil, nil
}
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) { var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) { "parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{}) m, ok := raw.(map[string]interface{})
@ -294,14 +354,14 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) {
}, },
} }
func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt { func NewThreadExt(buffer int, name string, state string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
return ThreadExt{ return ThreadExt{
Actions: actions, Actions: actions,
Handlers: handlers, Handlers: handlers,
SignalChan: make(chan GraphSignal, buffer), SignalChan: make(chan GraphSignal, buffer),
TimeoutChan: nil, TimeoutChan: nil,
Active: false, Active: false,
StateName: state_name, State: state,
Parent: nil, Parent: nil,
Children: map[NodeID]ChildInfo{}, Children: map[NodeID]ChildInfo{},
ActionQueue: []QueuedAction{}, ActionQueue: []QueuedAction{},
@ -322,7 +382,7 @@ func (ext *ThreadExt) SetActive(active bool) error {
} }
func (ext *ThreadExt) SetState(state string) error { func (ext *ThreadExt) SetState(state string) error {
ext.StateName = state ext.State = state
return nil return nil
} }
@ -485,7 +545,7 @@ func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start boo
} }
parent_info := info.Infos["parent"].(*ParentThreadInfo) parent_info := info.Infos["parent"].(*ParentThreadInfo)
if parent_info.Start == true && child_ext.StateName != "finished" { if parent_info.Start == true && child_ext.State != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID) ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true { if start == true {
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction) ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
@ -537,7 +597,7 @@ func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, err
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error { err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread_ext.NextAction.Action timeout_action = thread_ext.NextAction.Action
thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction() thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
return nil return nil
}) })
if err != nil { if err != nil {

@ -8,47 +8,47 @@ import (
"crypto/x509" "crypto/x509"
) )
type GroupNode interface { type ECDHExt struct {
Node
Users() map[NodeID]*User
}
type User struct {
Lockable
Granted time.Time Granted time.Time
Pubkey *ecdsa.PublicKey Pubkey *ecdsa.PublicKey
Shared []byte Shared []byte
Tags []string
} }
type UserJSON struct { type ECDHExtJSON struct {
LockableJSON
Granted time.Time `json:"granted"` Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"` Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"` Shared []byte `json:"shared"`
} }
func (user *User) Type() NodeType { func (ext *ECDHExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return NodeType("user") return nil
} }
func (user *User) Serialize() ([]byte, error) { const ECDHExtType = ExtType("ECDH")
lockable_json := NewLockableJSON(&user.Lockable) func (ext *ECDHExt) Type() ExtType {
pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey) return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return json.MarshalIndent(&UserJSON{ return json.MarshalIndent(&ECDHExtJSON{
LockableJSON: lockable_json, Granted: ext.Granted,
Granted: user.Granted,
Shared: user.Shared,
Pubkey: pubkey, Pubkey: pubkey,
Shared: ext.Shared,
}, "", " ") }, "", " ")
} }
var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) { func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var j ECDHExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
pub, err := x509.ParsePKIXPublicKey(j.Pubkey) pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -59,83 +59,56 @@ var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
pubkey = pub.(*ecdsa.PublicKey) pubkey = pub.(*ecdsa.PublicKey)
default: default:
return nil, fmt.Errorf("Invalid key type") return nil, fmt.Errorf("Invalid key type: %+v", pub)
} }
user := NewUser(j.Name, j.Granted, pubkey, j.Shared) extension := ECDHExt{
return &user, nil Granted: j.Granted,
}, func(ctx *Context, user *User, j UserJSON, nodes NodeMap) error {
return RestoreLockable(ctx, user, j.LockableJSON, nodes)
})
func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte) User {
id := KeyID(pubkey)
return User{
Lockable: NewLockable(id, name),
Granted: granted,
Pubkey: pubkey, Pubkey: pubkey,
Shared: shared, Shared: j.Shared,
}
} }
type Group struct { return &extension, nil
Lockable
UserMap map[NodeID]*User
} }
func NewGroup(id NodeID, name string) Group { type GroupExt struct {
return Group{ Members NodeMap
Lockable: NewLockable(id, name),
UserMap: map[NodeID]*User{},
}
} }
type GroupJSON struct { const GroupExtType = ExtType("GROUP")
LockableJSON func (ext *GroupExt) Type() ExtType {
Users []string `json:"users"` return GroupExtType
} }
func (group *Group) Type() NodeType { func (ext *GroupExt) Serialize() ([]byte, error) {
return NodeType("group") return json.MarshalIndent(&struct{
} Members []string `json:"members"`
}{
func (group *Group) Serialize() ([]byte, error) { Members: SaveNodeList(ext.Members),
users := make([]string, len(group.UserMap)) }, "", " ")
i := 0
for id, _ := range(group.UserMap) {
users[i] = id.String()
i += 1
} }
return json.MarshalIndent(&GroupJSON{ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
LockableJSON: NewLockableJSON(&group.Lockable), var j struct {
Users: users, Members []string `json:"members"`
}, "", " ")
} }
var LoadGroup = LoadJSONNode(func(id NodeID, j GroupJSON) (Node, error) { err := json.Unmarshal(data, &j)
group := NewGroup(id, j.Name)
return &group, nil
}, func(ctx *Context, group *Group, j GroupJSON, nodes NodeMap) error {
for _, id_str := range(j.Users) {
id, err := ParseID(id_str)
if err != nil { if err != nil {
return err return nil, err
} }
user_node, err := LoadNodeRecurse(ctx, id, nodes) members, err := RestoreNodeList(ctx, j.Members)
if err != nil { if err != nil {
return err return nil, err
} }
user, ok := user_node.(*User) extension := GroupExt{
if ok == false { Members: members,
return fmt.Errorf("%s is not a *User", id_str)
} }
return &extension, nil
group.UserMap[id] = user
} }
return RestoreLockable(ctx, group, j.LockableJSON, nodes) func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
}) return nil
}