Added gql to the rework

graph-rework-2
noah metz 2023-07-26 00:18:11 -06:00
parent ff813d6c2b
commit f1c0f1e7de
10 changed files with 695 additions and 572 deletions

@ -5,13 +5,20 @@ import (
"fmt"
)
//Function to load an extension from bytes
type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
// Information about a loaded extension
type ExtensionInfo struct {
Load ExtensionLoadFunc
Type ExtType
Data interface{}
}
// Information about a loaded node type
type NodeInfo struct {
Type NodeType
}
// A Context is all the data needed to run a graphvent
type Context struct {
// DB is the database connection used to load and write nodes
@ -20,6 +27,8 @@ type Context struct {
Log Logger
// A mapping between type hashes and their corresponding extension definitions
Extensions map[uint64]ExtensionInfo
// A mapping between type hashes and their corresponding node definitions
Types map[uint64]NodeInfo
// All loaded Nodes
Nodes map[NodeID]*Node
}
@ -30,8 +39,21 @@ func (ctx *Context) ExtByType(ext_type ExtType) ExtensionInfo {
return ext
}
func (ctx *Context) RegisterNodeType(node_type NodeType) error {
type_hash := node_type.Hash()
_, exists := ctx.Types[type_hash]
if exists == true {
return fmt.Errorf("Cannot register node type %s, type already exists in context", node_type)
}
ctx.Types[type_hash] = NodeInfo{
Type: node_type,
}
return nil
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc) error {
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error {
if load_fn == nil {
return fmt.Errorf("def has no load function")
}
@ -45,6 +67,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn,
Type: ext_type,
Data: data,
}
return nil
}
@ -55,15 +78,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
DB: db,
Log: log,
Extensions: map[uint64]ExtensionInfo{},
Types: map[uint64]NodeInfo{},
Nodes: map[NodeID]*Node{},
}
err := ctx.RegisterExtension(ACLExtType, LoadACLExtension)
err := ctx.RegisterExtension(ACLExtType, LoadACLExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExt, NewACLPolicyExtContext())
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ThreadExtType, LoadThreadExt, NewThreadExtContext())
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExtension)
err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, NewGQLExtContext())
if err != nil {
return nil, err
}

313
gql.go

@ -70,7 +70,7 @@ type AuthRespJSON struct {
Signature []byte `json:"signature"`
}
func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
func NewAuthRespJSON(gql_ext *GQLExt, req AuthReqJSON) (AuthRespJSON, *ecdsa.PublicKey, []byte, error) {
// Check if req.Time is within +- 1 second of now
now := time.Now()
earliest := now.Add(-1 * time.Second)
@ -82,12 +82,12 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_TIME_TOO_EARLY: %s", req.Time)
}
x, y := elliptic.Unmarshal(thread.Key.Curve, req.Pubkey)
x, y := elliptic.Unmarshal(gql_ext.Key.Curve, req.Pubkey)
if x == nil {
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_UNMARSHAL_FAIL: %+v", req.Pubkey)
}
remote, err := thread.ECDH.NewPublicKey(req.ECDHPubkey)
remote, err := gql_ext.ECDH.NewPublicKey(req.ECDHPubkey)
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
@ -98,7 +98,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
sig_hash := sha512.Sum512(sig_data)
remote_key := &ecdsa.PublicKey{
Curve: thread.Key.Curve,
Curve: gql_ext.Key.Curve,
X: x,
Y: y,
}
@ -113,7 +113,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
return AuthRespJSON{}, nil, nil, fmt.Errorf("GQL_AUTH_VERIFY_FAIL: %+v", req)
}
ec_key, err := thread.ECDH.GenerateKey(rand.Reader)
ec_key, err := gql_ext.ECDH.GenerateKey(rand.Reader)
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
@ -125,7 +125,7 @@ func NewAuthRespJSON(thread *GQLThread, req AuthReqJSON) (AuthRespJSON, *ecdsa.P
resp_sig_data := append(ec_key_pub, time_ser...)
resp_sig_hash := sha512.Sum512(resp_sig_data)
resp_sig, err := ecdsa.SignASN1(rand.Reader, thread.Key, resp_sig_hash[:])
resp_sig, err := ecdsa.SignASN1(rand.Reader, gql_ext.Key, resp_sig_hash[:])
if err != nil {
return AuthRespJSON{}, nil, nil, err
}
@ -156,7 +156,7 @@ func ParseAuthRespJSON(resp AuthRespJSON, ecdsa_curve elliptic.Curve, ecdh_curve
return shared_secret, nil
}
func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *http.Request) {
func AuthHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx.Log.Logf("gql", "GQL_AUTH_REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
@ -174,7 +174,7 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return
}
resp, remote_id, shared, err := NewAuthRespJSON(server, req)
resp, _, _, err := NewAuthRespJSON(gql_ext, req)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_VERIFY_ERROR: %s", err)
return
@ -195,34 +195,31 @@ func AuthHandler(ctx *Context, server *GQLThread) func(http.ResponseWriter, *htt
return
}
key_id := KeyID(remote_id)
_, exists := server.UserMap[key_id]
if exists {
/*if exists {
ctx.Log.Logf("gql", "REFRESHING AUTH FOR %s", key_id)
} else {
ctx.Log.Logf("gql", "AUTHORIZING NEW USER %s - %s", key_id, shared)
new_user := NewUser(fmt.Sprintf("GQL_USER %s", key_id.String()), time.Now(), remote_id, shared)
context := NewWriteContext(ctx)
err := UpdateStates(context, server, NewLockMap(LockMap{
server.ID(): LockInfo{
err := UpdateStates(context, server, ACLMap{
server.ID: ACLInfo{
Node: server,
Resources: []string{"users"},
},
new_user.ID(): LockInfo{
new_user.ID: ACLInfo{
Node: &new_user,
Resources: nil,
},
}), func(context *StateContext) error {
server.UserMap[key_id] = &new_user
}, func(context *StateContext) error {
server.Users[key_id] = &new_user
return nil
})
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_UPDATE_ERR: %s", err)
return
}
}
}*/
}
}
@ -363,11 +360,13 @@ func checkForAuthHeader(header http.Header) (string, bool) {
type ResolveContext struct {
Context *Context
Server *GQLThread
User *User
GQLContext *GQLExtContext
Server *Node
Ext *GQLExt
User *Node
}
func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*ResolveContext, error) {
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, password, ok := r.BasicAuth()
if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
@ -378,25 +377,29 @@ func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*Resol
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
}
user, exists := server.UserMap[auth_id]
user, exists := gql_ext.Users[auth_id]
if exists == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
}
if base64.StdEncoding.EncodeToString(user.Shared) != password {
user_ext, err := GetExt[*ECDHExt](user)
if err != nil {
return nil, err
}
if base64.StdEncoding.EncodeToString(user_ext.Shared) != password {
return nil, fmt.Errorf("GQL_AUTH_FAIL")
}
return &ResolveContext{
Context: ctx,
GQLContext: ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext),
Server: server,
User: user,
}, nil
}
func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
@ -406,7 +409,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, r)
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err)))
@ -414,7 +417,7 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
}
req_ctx := context.Background()
req_ctx = context.WithValue(gql_ctx, "resolve", resolve_context)
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
str, err := io.ReadAll(r.Body)
if err != nil {
@ -425,8 +428,10 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
query := GQLPayload{}
json.Unmarshal(str, &query)
gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext)
params := graphql.Params{
Schema: ctx.GQL.Schema,
Schema: gql_context.Schema,
Context: req_ctx,
RequestString: query.Query,
}
@ -494,11 +499,7 @@ func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result {
return sendOneResultAndClose(res)
}
func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
enableCORS(&w)
@ -508,7 +509,7 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, r)
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return
@ -557,8 +558,9 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
}
} else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[GQLExtType.Hash()].Data.(*GQLExtContext)
params := graphql.Params{
Schema: ctx.GQL.Schema,
Schema: gql_context.Schema,
Context: req_ctx,
RequestString: msg.Payload.Query,
}
@ -628,35 +630,94 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
}
}
type GQLThread struct {
Thread
// Map of go types to graphql types
type ObjTypeMap map[reflect.Type]*graphql.Object
// GQL Specific Context information
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// List of GQL types
TypeList []graphql.Type
// Interface type maps to map go types of specific interfaces to gql types
ValidNodes ObjTypeMap
ValidLockables ObjTypeMap
ValidThreads ObjTypeMap
BaseNodeType *graphql.Object
BaseLockableType *graphql.Object
BaseThreadType *graphql.Object
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
}
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
context := GQLExtContext{
Schema: graphql.Schema{},
TypeList: []graphql.Type{},
ValidNodes: ObjTypeMap{},
ValidThreads: ObjTypeMap{},
ValidLockables: ObjTypeMap{},
Query: query,
Mutation: mutation,
Subscription: subscription,
BaseNodeType: GQLTypeBaseNode.Type,
BaseLockableType: GQLTypeBaseLockable.Type,
BaseThreadType: GQLTypeBaseThread.Type,
}
return &context
}
type GQLExt struct {
tcp_listener net.Listener
http_server *http.Server
http_done *sync.WaitGroup
http_done sync.WaitGroup
tls_key []byte
tls_cert []byte
Listen string
UserMap map[NodeID]*User
Users NodeMap
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
SubscribeLock sync.Mutex
SubscribeListeners []chan GraphSignal
}
func (thread *GQLThread) NewSubscriptionChannel(buffer int) chan GraphSignal {
thread.SubscribeLock.Lock()
defer thread.SubscribeLock.Unlock()
func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan GraphSignal {
ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock()
new_listener := make(chan GraphSignal, buffer)
thread.SubscribeListeners = append(thread.SubscribeListeners, new_listener)
ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener)
return new_listener
}
func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) error {
func (ext *GQLExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock()
active_listeners := []chan GraphSignal{}
thread.SubscribeLock.Lock()
for _, listener := range(thread.SubscribeListeners) {
for _, listener := range(ext.SubscribeListeners) {
select {
case listener <- signal:
active_listeners = append(active_listeners, listener)
@ -667,34 +728,38 @@ func (thread *GQLThread) Process(context *StateContext, signal GraphSignal) erro
}(listener)
}
}
thread.SubscribeListeners = active_listeners
thread.SubscribeLock.Unlock()
return thread.Thread.Process(context, signal)
}
func (thread * GQLThread) Type() NodeType {
return NodeType("gql_thread")
}
func (thread * GQLThread) Serialize() ([]byte, error) {
thread_json := NewGQLThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ")
ext.SubscribeListeners = active_listeners
return nil
}
func (thread * GQLThread) Users() map[NodeID]*User {
return thread.UserMap
const GQLExtType = ExtType("gql_thread")
func (ext *GQLExt) Type() ExtType {
return GQLExtType
}
type GQLThreadJSON struct {
ThreadJSON
type GQLExtJSON struct {
Listen string `json:"listen"`
Users []string `json:"users"`
Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"`
TLSKey []byte `json:"ssl_key"`
TLSCert []byte `json:"ssl_cert"`
}
func (ext *GQLExt) Serialize() ([]byte, error) {
ser_key, err := x509.MarshalECPrivateKey(ext.Key)
if err != nil {
return nil, err
}
return json.MarshalIndent(&GQLExtJSON{
Listen: ext.Listen,
Key: ser_key,
ECDH: ecdh_curve_ids[ext.ECDH],
TLSKey: ext.tls_key,
TLSCert: ext.tls_cert,
}, "", " ")
}
var ecdsa_curves = map[uint8]elliptic.Curve{
0: elliptic.P256(),
}
@ -711,33 +776,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0,
}
func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
thread_json := NewThreadJSON(&thread.Thread)
ser_key, err := x509.MarshalECPrivateKey(thread.Key)
func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
var j GQLExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
panic(err)
}
users := make([]string, len(thread.UserMap))
i := 0
for id, _ := range(thread.UserMap) {
users[i] = id.String()
i += 1
}
return GQLThreadJSON{
ThreadJSON: thread_json,
Listen: thread.Listen,
Users: users,
Key: ser_key,
ECDH: ecdh_curve_ids[thread.ECDH],
TLSKey: thread.tls_key,
TLSCert: thread.tls_cert,
return nil, err
}
}
var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) {
ecdh_curve, ok := ecdh_curves[j.ECDH]
if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
@ -748,27 +793,19 @@ var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error)
return nil, err
}
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey)
return &thread, nil
}, func(ctx *Context, thread *GQLThread, j GQLThreadJSON, nodes NodeMap) error {
thread.UserMap = map[NodeID]*User{}
for _, id_str := range(j.Users) {
ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str)
user_id, err := ParseID(id_str)
if err != nil {
return err
}
user, err := LoadNodeRecurse(ctx, user_id, nodes)
if err != nil {
return err
}
thread.UserMap[user_id] = user.(*User)
extension := GQLExt{
Listen: j.Listen,
Key: key,
ECDH: ecdh_curve,
SubscribeListeners: []chan GraphSignal{},
tls_key: j.TLSKey,
tls_cert: j.TLSCert,
}
return RestoreThread(ctx, thread, j.ThreadJSON, nodes)
})
return &extension, nil
}
func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread {
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLExt {
if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader)
if err != nil {
@ -808,12 +845,9 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
tls_cert = ssl_cert_pem
tls_key = ssl_key_pem
}
return GQLThread{
Thread: NewThread(id, name, state_name, []InfoType{"parent"}, gql_actions, gql_handlers),
return GQLExt{
Listen: listen,
SubscribeListeners: []chan GraphSignal{},
UserMap: map[NodeID]*User{},
http_done: &sync.WaitGroup{},
Key: key,
ECDH: ecdh_curve,
tls_cert: tls_cert,
@ -823,23 +857,26 @@ func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh
var gql_actions ThreadActions = ThreadActions{
"wait": ThreadWait,
"restore": func(ctx *Context, node ThreadNode) (string, error) {
return "start_server", ThreadRestore(ctx, node, false)
"restore": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
return "start_server", ThreadRestore(ctx, thread, thread_ext, false)
},
"start": func(ctx * Context, node ThreadNode) (string, error) {
_, err := ThreadStart(ctx, node)
"start": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
_, err := ThreadStart(ctx, thread, thread_ext)
if err != nil {
return "", err
}
return "start_server", ThreadRestore(ctx, node, true)
return "start_server", ThreadRestore(ctx, thread, thread_ext, true)
},
"start_server": func(ctx * Context, node ThreadNode) (string, error) {
gql_thread := node.(*GQLThread)
"start_server": func(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_ext, err := GetExt[*GQLExt](thread)
if err != nil {
return "", err
}
mux := http.NewServeMux()
mux.HandleFunc("/auth", AuthHandler(ctx, gql_thread))
mux.HandleFunc("/gql", GQLHandler(ctx, gql_thread))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, gql_thread))
mux.HandleFunc("/auth", AuthHandler(ctx, thread, gql_ext))
mux.HandleFunc("/gql", GQLHandler(ctx, thread, gql_ext))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, thread, gql_ext))
// Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler())
@ -849,7 +886,7 @@ var gql_actions ThreadActions = ThreadActions{
mux.Handle("/site/", http.StripPrefix("/site", fs))
http_server := &http.Server{
Addr: gql_thread.Listen,
Addr: gql_ext.Listen,
Handler: mux,
}
@ -858,7 +895,7 @@ var gql_actions ThreadActions = ThreadActions{
return "", fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
cert, err := tls.X509KeyPair(gql_thread.tls_cert, gql_thread.tls_key)
cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key)
if err != nil {
return "", err
}
@ -870,23 +907,21 @@ var gql_actions ThreadActions = ThreadActions{
listener := tls.NewListener(l, &config)
gql_thread.http_done.Add(1)
go func(gql_thread *GQLThread) {
defer gql_thread.http_done.Done()
gql_ext.http_done.Add(1)
go func(qql_ext *GQLExt) {
defer gql_ext.http_done.Done()
err := http_server.Serve(listener)
if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err))
}
}(gql_thread)
}(gql_ext)
context := NewWriteContext(ctx)
err = UpdateStates(context, node, NewLockMap(
NewLockInfo(node, []string{"http_server"}),
), func(context *StateContext) error {
gql_thread.tcp_listener = listener
gql_thread.http_server = http_server
err = UpdateStates(context, thread, NewACLInfo(thread, []string{"http_server"}), func(context *StateContext) error {
gql_ext.tcp_listener = listener
gql_ext.http_server = http_server
return nil
})
@ -895,18 +930,22 @@ var gql_actions ThreadActions = ThreadActions{
}
context = NewReadContext(ctx)
err = Signal(context, gql_thread, gql_thread, NewStatusSignal("server_started", gql_thread.ID()))
err = Signal(context, thread, thread, NewStatusSignal("server_started", thread.ID))
if err != nil {
return "", err
}
return "wait", nil
},
"finish": func(ctx *Context, node ThreadNode) (string, error) {
gql_thread := node.(*GQLThread)
gql_thread.http_server.Shutdown(context.TODO())
gql_thread.http_done.Wait()
return ThreadFinish(ctx, node)
"finish": func(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, error) {
gql_ext, err := GetExt[*GQLExt](thread)
if err != nil {
return "", err
}
gql_ext.http_server.Shutdown(context.TODO())
gql_ext.http_done.Wait()
return ThreadFinish(ctx, thread, thread_ext)
},
}

@ -3,6 +3,7 @@ package graphvent
import (
"github.com/graphql-go/graphql"
"reflect"
"fmt"
)
func NewField(init func()*graphql.Field) *graphql.Field {
@ -26,22 +27,34 @@ func NewSingleton[K graphql.Type](init func() K, post_init func(K, *graphql.List
}
}
func addNodeInterfaceFields(i *graphql.Interface) {
func AddNodeInterfaceFields(i *graphql.Interface) {
i.AddFieldConfig("ID", &graphql.Field{
Type: graphql.String,
})
i.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
})
}
func PrepTypeResolve(p graphql.ResolveTypeParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil, fmt.Errorf("Bad resolve in params context")
}
return resolve_context, nil
}
var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
i := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Node",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context)
if ok == false {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_nodes := ctx.GQL.ValidNodes
valid_nodes := ctx.GQLContext.ValidNodes
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_nodes) {
@ -50,9 +63,9 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
}
}
_, ok = p.Value.(Node)
_, ok := p.Value.(Node)
if ok == true {
return ctx.GQL.BaseNodeType
return ctx.GQLContext.BaseNodeType
}
return nil
@ -60,41 +73,21 @@ var GQLInterfaceNode = NewSingleton(func() *graphql.Interface {
Fields: graphql.Fields{},
})
addNodeInterfaceFields(i)
AddNodeInterfaceFields(i)
return i
}, nil)
func addLockableInterfaceFields(i *graphql.Interface, lockable *graphql.Interface, list *graphql.List) {
addNodeInterfaceFields(i)
i.AddFieldConfig("Name", &graphql.Field{
Type: graphql.String,
})
i.AddFieldConfig("Requirements", &graphql.Field{
Type: list,
})
i.AddFieldConfig("Dependencies", &graphql.Field{
Type: list,
})
i.AddFieldConfig("Owner", &graphql.Field{
Type: lockable,
})
}
var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
gql_interface_lockable := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Lockable",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context)
if ok == false {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_lockables := ctx.GQL.ValidLockables
valid_lockables := ctx.GQLContext.ValidLockables
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_lockables) {
@ -103,9 +96,9 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
}
}
_, ok = p.Value.(Lockable)
if ok == true {
return ctx.GQL.BaseLockableType
_, ok := p.Value.(*Node)
if ok == false {
return ctx.GQLContext.BaseLockableType
}
return nil
},
@ -114,31 +107,30 @@ var GQLInterfaceLockable = NewSingleton(func() *graphql.Interface {
return gql_interface_lockable
}, func(lockable *graphql.Interface, lockable_list *graphql.List) {
addLockableInterfaceFields(lockable, lockable, lockable_list)
})
func addThreadInterfaceFields(i *graphql.Interface, thread *graphql.Interface, list *graphql.List) {
addLockableInterfaceFields(i, GQLInterfaceLockable.Type, GQLInterfaceLockable.List)
lockable.AddFieldConfig("Requirements", &graphql.Field{
Type: lockable_list,
})
i.AddFieldConfig("Children", &graphql.Field{
Type: list,
lockable.AddFieldConfig("Dependencies", &graphql.Field{
Type: lockable_list,
})
i.AddFieldConfig("Parent", &graphql.Field{
Type: thread,
lockable.AddFieldConfig("Owner", &graphql.Field{
Type: lockable,
})
}
AddNodeInterfaceFields(lockable)
})
var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
gql_interface_thread := graphql.NewInterface(graphql.InterfaceConfig{
Name: "Thread",
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("graph_context").(*Context)
if ok == false {
ctx, err := PrepTypeResolve(p)
if err != nil {
return nil
}
valid_threads := ctx.GQL.ValidThreads
valid_threads := ctx.GQLContext.ValidThreads
p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_threads) {
@ -147,9 +139,14 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
}
}
_, ok = p.Value.(Thread)
if ok == true {
return ctx.GQL.BaseThreadType
node, ok := p.Value.(*Node)
if ok == false {
return nil
}
_, err = GetExt[*ThreadExt](node)
if err == nil {
return ctx.GQLContext.BaseThreadType
}
return nil
@ -159,5 +156,17 @@ var GQLInterfaceThread = NewSingleton(func() *graphql.Interface {
return gql_interface_thread
}, func(thread *graphql.Interface, thread_list *graphql.List) {
addThreadInterfaceFields(thread, thread, thread_list)
thread.AddFieldConfig("Children", &graphql.Field{
Type: thread_list,
})
thread.AddFieldConfig("Parent", &graphql.Field{
Type: thread,
})
thread.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
})
AddNodeInterfaceFields(thread)
})

@ -5,12 +5,18 @@ import (
"github.com/graphql-go/graphql"
)
func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil, fmt.Errorf("Bad resolve in params context")
return nil, nil, fmt.Errorf("Bad resolve in params context")
}
return resolve_context, nil
node, ok := p.Source.(*Node)
if ok == false {
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
}
return node, resolve_context, nil
}
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
@ -65,30 +71,38 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
node, ok := p.Source.(Node)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Node")
node, _, err := PrepResolve(p)
if err != nil {
return nil, err
}
return node.ID(), nil
return node.ID, nil
}
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
node, _, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(*GQLThread)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to GQLThread")
return string(node.Type), nil
}
func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
gql_ext, err := GetExt[*GQLExt](node)
if err != nil {
return nil, err
}
listen := ""
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = node.Listen
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"listen"}), func(context *StateContext) error {
listen = gql_ext.Listen
return nil
})
@ -100,20 +114,20 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) {
}
func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(*Thread)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var parent ThreadNode = nil
var parent *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = node.ThreadHandle().Parent
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
parent = thread_ext.Parent
return nil
})
@ -125,20 +139,20 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) {
}
func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(ThreadNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
var state string
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"state"}), func(context *StateContext) error {
state = node.ThreadHandle().StateName
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"state"}), func(context *StateContext) error {
state = thread_ext.State
return nil
})
@ -150,50 +164,20 @@ func GQLThreadState(p graphql.ResolveParams) (interface{}, error) {
}
func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(ThreadNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Thread")
}
var children []ThreadNode = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"children"}), func(context *StateContext) error {
children = make([]ThreadNode, len(node.ThreadHandle().Children))
i := 0
for _, info := range(node.ThreadHandle().Children) {
children[i] = info.Child
i += 1
}
return nil
})
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
return children, nil
}
func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
thread_ext, err := GetExt[*ThreadExt](node)
if err != nil {
return nil, err
}
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
}
name := ""
var children []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"name"}), func(context *StateContext) error {
name = node.LockableHandle().Name
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
children = thread_ext.ChildList()
return nil
})
@ -201,26 +185,26 @@ func GQLLockableName(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
return name, nil
return children, nil
}
func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var requirements []LockableNode = nil
var requirements []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = make([]LockableNode, len(node.LockableHandle().Requirements))
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
requirements = make([]*Node, len(lockable_ext.Requirements))
i := 0
for _, req := range(node.LockableHandle().Requirements) {
for _, req := range(lockable_ext.Requirements) {
requirements[i] = req
i += 1
}
@ -235,22 +219,22 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) {
}
func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var dependencies []LockableNode = nil
var dependencies []*Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = make([]LockableNode, len(node.LockableHandle().Dependencies))
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"dependencies"}), func(context *StateContext) error {
dependencies = make([]*Node, len(lockable_ext.Dependencies))
i := 0
for _, dep := range(node.LockableHandle().Dependencies) {
for _, dep := range(lockable_ext.Dependencies) {
dependencies[i] = dep
i += 1
}
@ -265,20 +249,20 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) {
}
func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(LockableNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to Lockable")
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return nil, err
}
var owner Node = nil
var owner *Node = nil
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = node.LockableHandle().Owner
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"owner"}), func(context *StateContext) error {
owner = lockable_ext.Owner
return nil
})
@ -289,24 +273,24 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) {
return owner, nil
}
func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
node, ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(GroupNode)
if ok == false || node == nil {
return nil, fmt.Errorf("Failed to cast source to GQLThread")
group_ext, err := GetExt[*GroupExt](node)
if err != nil {
return nil, err
}
var users []*User
var members []*Node
context := NewReadContext(ctx.Context)
err = UseStates(context, ctx.User, NewLockInfo(node, []string{"users"}), func(context *StateContext) error {
users = make([]*User, len(node.Users()))
err = UseStates(context, ctx.User, NewACLInfo(node, []string{"users"}), func(context *StateContext) error {
members = make([]*Node, len(group_ext.Members))
i := 0
for _, user := range(node.Users()) {
users[i] = user
for _, member := range(group_ext.Members) {
members[i] = member
i += 1
}
return nil
@ -316,7 +300,7 @@ func GQLGroupNodeUsers(p graphql.ResolveParams) (interface{}, error) {
return nil, err
}
return users, nil
return members, nil
}
func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {

@ -9,16 +9,16 @@ func AddNodeFields(obj *graphql.Object) {
Type: graphql.String,
Resolve: GQLNodeID,
})
obj.AddFieldConfig("TypeHash", &graphql.Field{
Type: graphql.String,
Resolve: GQLNodeTypeHash,
})
}
func AddLockableFields(obj *graphql.Object) {
AddNodeFields(obj)
obj.AddFieldConfig("Name", &graphql.Field{
Type: graphql.String,
Resolve: GQLLockableName,
})
obj.AddFieldConfig("Requirements", &graphql.Field{
Type: GQLInterfaceLockable.List,
Resolve: GQLLockableRequirements,
@ -36,7 +36,7 @@ func AddLockableFields(obj *graphql.Object) {
}
func AddThreadFields(obj *graphql.Object) {
AddLockableFields(obj)
AddNodeFields(obj)
obj.AddFieldConfig("State", &graphql.Field{
Type: graphql.String,
@ -54,56 +54,7 @@ func AddThreadFields(obj *graphql.Object) {
})
}
var GQLTypeUser = NewSingleton(func() *graphql.Object {
gql_type_user := graphql.NewObject(graphql.ObjectConfig{
Name: "User",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(*User)
return ok
},
Fields: graphql.Fields{},
})
AddLockableFields(gql_type_user)
return gql_type_user
}, nil)
var GQLTypeGQLThread = NewSingleton(func() *graphql.Object {
gql_type_gql_thread := graphql.NewObject(graphql.ObjectConfig{
Name: "GQLThread",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
GQLInterfaceThread.Type,
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(*GQLThread)
return ok
},
Fields: graphql.Fields{},
})
AddThreadFields(gql_type_gql_thread)
gql_type_gql_thread.AddFieldConfig("Users", &graphql.Field{
Type: GQLTypeUser.List,
Resolve: GQLGroupNodeUsers,
})
gql_type_gql_thread.AddFieldConfig("Listen", &graphql.Field{
Type: graphql.String,
Resolve: GQLThreadListen,
})
return gql_type_gql_thread
}, nil)
var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
var GQLTypeBaseThread = NewSingleton(func() *graphql.Object {
gql_type_simple_thread := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleThread",
Interfaces: []*graphql.Interface{
@ -112,8 +63,13 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Thread)
return ok
node, ok := p.Value.(*Node)
if ok == false {
return false
}
_, err := GetExt[*ThreadExt](node)
return err == nil
},
Fields: graphql.Fields{},
})
@ -123,7 +79,7 @@ var GQLTypeSimpleThread = NewSingleton(func() *graphql.Object {
return gql_type_simple_thread
}, nil)
var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
var GQLTypeBaseLockable = NewSingleton(func() *graphql.Object {
gql_type_simple_lockable := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleLockable",
Interfaces: []*graphql.Interface{
@ -131,8 +87,13 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
GQLInterfaceLockable.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Lockable)
return ok
node, ok := p.Value.(*Node)
if ok == false {
return false
}
_, err := GetExt[*LockableExt](node)
return err == nil
},
Fields: graphql.Fields{},
})
@ -142,14 +103,14 @@ var GQLTypeSimpleLockable = NewSingleton(func() *graphql.Object {
return gql_type_simple_lockable
}, nil)
var GQLTypeSimpleNode = NewSingleton(func() *graphql.Object {
var GQLTypeBaseNode = NewSingleton(func() *graphql.Object {
object := graphql.NewObject(graphql.ObjectConfig{
Name: "SimpleNode",
Interfaces: []*graphql.Interface{
GQLInterfaceNode.Type,
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Node)
_, ok := p.Value.(*Node)
return ok
},
Fields: graphql.Fields{},

@ -45,46 +45,57 @@ func (ext *LockableExt) Type() ExtType {
return LockableExtType
}
type LockableExtJSON struct {
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}
func (ext *LockableExt) Serialize() ([]byte, error) {
requirements := make([]string, len(ext.Requirements))
req_n := 0
for id, _ := range(ext.Requirements) {
requirements[req_n] = id.String()
req_n++
return json.MarshalIndent(&LockableExtJSON{
Owner: SaveNode(ext.Owner),
Requirements: SaveNodeList(ext.Requirements),
Dependencies: SaveNodeList(ext.Dependencies),
LocksHeld: SaveNodeMap(ext.LocksHeld),
}, "", " ")
}
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var j LockableExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
dependencies := make([]string, len(ext.Dependencies))
dep_n := 0
for id, _ := range(ext.Dependencies) {
dependencies[dep_n] = id.String()
dep_n++
owner, err := RestoreNode(ctx, j.Owner)
if err != nil {
return nil, err
}
owner := ""
if ext.Owner != nil {
owner = ext.Owner.ID.String()
requirements, err := RestoreNodeList(ctx, j.Requirements)
if err != nil {
return nil, err
}
locks_held := map[string]string{}
for lockable_id, node := range(ext.LocksHeld) {
if node == nil {
locks_held[lockable_id.String()] = ""
} else {
locks_held[lockable_id.String()] = node.ID.String()
}
dependencies, err := RestoreNodeList(ctx, j.Dependencies)
if err != nil {
return nil, err
}
return json.MarshalIndent(&struct{
Owner string `json:"owner"`
Requirements []string `json:"requirements"`
Dependencies []string `json:"dependencies"`
LocksHeld map[string]string `json:"locks_held"`
}{
locks_held, err := RestoreNodeMap(ctx, j.LocksHeld)
if err != nil {
return nil, err
}
extension := LockableExt{
Owner: owner,
Requirements: requirements,
Dependencies: dependencies,
LocksHeld: locks_held,
}, "", " ")
}
return &extension, nil
}
func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
@ -469,6 +480,14 @@ func UnlockLockables(context *StateContext, to_unlock NodeMap, old_owner *Node)
})
}
func SaveNode(node *Node) string {
str := ""
if node != nil {
str = node.ID.String()
}
return str
}
func RestoreNode(ctx *Context, id_str string) (*Node, error) {
id, err := ParseID(id_str)
if err != nil {
@ -478,6 +497,14 @@ func RestoreNode(ctx *Context, id_str string) (*Node, error) {
return LoadNode(ctx, id)
}
func SaveNodeMap(nodes NodeMap) map[string]string {
m := map[string]string{}
for id, node := range(nodes) {
m[id.String()] = SaveNode(node)
}
return m
}
func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
nodes := NodeMap{}
for id_str_1, id_str_2 := range(ids) {
@ -507,6 +534,17 @@ func RestoreNodeMap(ctx *Context, ids map[string]string) (NodeMap, error) {
return nodes, nil
}
func SaveNodeList(nodes NodeMap) []string {
ids := make([]string, len(nodes))
i := 0
for id, _ := range(nodes) {
ids[i] = id.String()
i += 1
}
return ids
}
func RestoreNodeList(ctx *Context, ids []string) (NodeMap, error) {
nodes := NodeMap{}

@ -73,6 +73,7 @@ type Extension interface {
// Nodes represent an addressible group of extensions
type Node struct {
ID NodeID
Type NodeType
Lock sync.RWMutex
ExtensionMap map[ExtType]Extension
}
@ -93,65 +94,6 @@ func GetExt[T Extension](node *Node) (T, error) {
return ret, nil
}
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExtension struct {
Delegations NodeMap
}
func (ext ACLExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func LoadACLExtension(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
delegations := NodeMap{}
for _, str := range(j.Delegations) {
id, err := ParseID(str)
if err != nil {
return nil, err
}
node, err := LoadNode(ctx, id)
if err != nil {
return nil, err
}
delegations[id] = node
}
return ACLExtension{
Delegations: delegations,
}, nil
}
func (ext ACLExtension) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
const ACLExtType = ExtType("ACL")
func (extension ACLExtension) Type() ExtType {
return ACLExtType
}
func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.ExtensionMap))
node_db := NodeDB{
@ -181,9 +123,10 @@ func (node *Node) Serialize() ([]byte, error) {
return node_db.Serialize(), nil
}
func NewNode(id NodeID) Node {
func NewNode(id NodeID, node_type NodeType) Node {
return Node{
ID: id,
Type: node_type,
ExtensionMap: map[ExtType]Extension{},
}
}
@ -198,15 +141,15 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node)
if exists == false {
return fmt.Errorf("%s does not have ACL extension, other nodes cannot perform actions on it", node.ID)
}
acl_ext := ext.(ACLExtension)
acl_ext := ext.(ACLExt)
for _, policy_node := range(acl_ext.Delegations) {
ext, exists := policy_node.ExtensionMap[ACLPolicyExtType]
if exists == false {
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExtension")
context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt")
continue
}
policy_ext := ext.(ACLPolicyExtension)
policy_ext := ext.(ACLPolicyExt)
if policy_ext.Allows(context, principal, action, node) == true {
context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action)
return nil
@ -238,11 +181,12 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal)
// Magic first four bytes of serialized DB content, stored big endian
const NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function
const NODE_DB_HEADER_LEN = 8
const NODE_DB_HEADER_LEN = 16
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct {
Magic uint32
NumExtensions uint32
TypeHash uint64
}
type NodeDB struct {
@ -258,6 +202,7 @@ func NewNodeDB(data []byte) (NodeDB, error) {
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
node_type_hash := binary.BigEndian.Uint64(data[8:16])
ptr += NODE_DB_HEADER_LEN
@ -290,6 +235,7 @@ func NewNodeDB(data []byte) (NodeDB, error) {
return NodeDB{
Header: NodeDBHeader{
Magic: magic,
TypeHash: node_type_hash,
NumExtensions: num_extensions,
},
Extensions: extensions,
@ -304,6 +250,7 @@ func (header NodeDBHeader) Serialize() []byte {
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint64(ret[8:16], header.TypeHash)
return ret
}
@ -411,8 +358,13 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err
}
node_type, known := ctx.Types[node_db.Header.TypeHash]
if known == false {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
}
// Create the blank node with the ID, and add it to the context
new_node := NewNode(id)
new_node := NewNode(id, node_type.Type)
node = &new_node
ctx.Nodes[id] = node
@ -476,6 +428,12 @@ func ACLList(list []*Node, resources []string) ACLMap {
return reqs
}
type NodeType string
func (node NodeType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("NODE: %s", string(node))))
return binary.BigEndian.Uint64(hash[(len(hash)-9):(len(hash)-1)])
}
type PolicyType string
func (policy PolicyType) Hash() uint64 {
hash := sha512.Sum512([]byte(fmt.Sprintf("POLICY: %s", string(policy))))

@ -32,10 +32,58 @@ func (policy AllNodesPolicy) Serialize() ([]byte, error) {
}
// Extension to allow a node to hold ACL policies
type ACLPolicyExtension struct {
type ACLPolicyExt struct {
Policies map[PolicyType]Policy
}
// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies
type ACLExt struct {
Delegations NodeMap
}
func (ext ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Delegations []string `json:"delegation"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
delegations, err := RestoreNodeList(ctx, j.Delegations)
if err != nil {
return nil, err
}
return ACLExt{
Delegations: delegations,
}, nil
}
func (ext ACLExt) Serialize() ([]byte, error) {
delegations := make([]string, len(ext.Delegations))
i := 0
for id, _ := range(ext.Delegations) {
delegations[i] = id.String()
i += 1
}
return json.MarshalIndent(&struct{
Delegations []string `json:"delegations"`
}{
Delegations: delegations,
}, "", " ")
}
const ACLExtType = ExtType("ACL")
func (extension ACLExt) Type() ExtType {
return ACLExtType
}
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
type PolicyInfo struct {
@ -43,11 +91,15 @@ type PolicyInfo struct {
Type PolicyType
}
type ACLPolicyExtensionContext struct {
type ACLPolicyExtContext struct {
Types map[PolicyType]PolicyInfo
}
func (ext ACLPolicyExtension) Serialize() ([]byte, error) {
func NewACLPolicyExtContext() *ACLPolicyExtContext {
return nil
}
func (ext ACLPolicyExt) Serialize() ([]byte, error) {
policies := map[string][]byte{}
for name, policy := range(ext.Policies) {
ser, err := policy.Serialize()
@ -64,11 +116,11 @@ func (ext ACLPolicyExtension) Serialize() ([]byte, error) {
}, "", " ")
}
func (ext ACLPolicyExtension) Process(context *StateContext, node *Node, signal GraphSignal) error {
func (ext ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Policies map[string][]byte `json:"policies"`
}
@ -78,7 +130,7 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
}
policies := map[PolicyType]Policy{}
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtensionContext)
acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(ACLPolicyExtContext)
for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)]
if exists == false {
@ -92,18 +144,18 @@ func LoadACLPolicyExtension(ctx *Context, data []byte) (Extension, error) {
policies[PolicyType(name)] = policy
}
return ACLPolicyExtension{
return ACLPolicyExt{
Policies: policies,
}, nil
}
const ACLPolicyExtType = ExtType("ACL_POLICIES")
func (ext ACLPolicyExtension) Type() ExtType {
func (ext ACLPolicyExt) Type() ExtType {
return ACLPolicyExtType
}
// Check if the extension allows the principal to perform action on node
func (ext ACLPolicyExtension) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
func (ext ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for _, policy := range(ext.Policies) {
if policy.Allows(context, principal, action, node) == true {
return true

@ -8,6 +8,19 @@ import (
"encoding/json"
)
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
type ThreadExtContext struct {
Loads map[InfoType]func([]byte)ThreadInfo
}
func NewThreadExtContext() *ThreadExtContext {
return nil
}
type ThreadExt struct {
Actions ThreadActions
Handlers ThreadHandlers
@ -19,7 +32,7 @@ type ThreadExt struct {
ActiveLock sync.Mutex
Active bool
StateName string
State string
Parent *Node
Children map[NodeID]ChildInfo
@ -28,15 +41,94 @@ type ThreadExt struct {
NextAction *QueuedAction
}
type ThreadExtJSON struct {
State string `json:"state"`
Parent string `json:"parent"`
Children map[string][]byte `json:"children"`
ActionQueue []QueuedAction
}
func (ext *ThreadExt) Serialize() ([]byte, error) {
return nil, fmt.Errorf("NOT_IMPLEMENTED")
}
const THREAD_BUFFER_SIZE int = 1024
func LoadThreadExt(ctx *Context, data []byte) (Extension, error) {
var j ThreadExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
parent, err := RestoreNode(ctx, j.Parent)
if err != nil {
return nil, err
}
children := map[NodeID]ChildInfo{}
for id_str, _ := range(j.Children) {
child_node, err := RestoreNode(ctx, id_str)
if err != nil {
return nil, err
}
//TODO: Restore child info based off context
children[child_node.ID] = ChildInfo{
Child: child_node,
Infos: map[InfoType]ThreadInfo{},
}
}
next_action, timeout_chan := SoonestAction(j.ActionQueue)
extension := ThreadExt{
Actions: BaseThreadActions,
Handlers: BaseThreadHandlers,
SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan,
Active: false,
State: j.State,
Parent: parent,
Children: children,
ActionQueue: j.ActionQueue,
NextAction: next_action,
}
return &extension, nil
}
const ThreadExtType = ExtType("THREAD")
func (ext *ThreadExt) Type() ExtType {
return ThreadExtType
}
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = SoonestAction(ext.ActionQueue)
}
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func SoonestAction(actions []QueuedAction) (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(actions) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
}
}
if soonest_action != nil {
return soonest_action, time.After(time.Until(soonest_action.Timeout))
} else {
return nil, nil
}
}
func (ext *ThreadExt) ChildList() []*Node {
ret := make([]*Node, len(ext.Children))
i := 0
@ -235,38 +327,6 @@ func NewChildInfo(child *Node, infos map[InfoType]ThreadInfo) ChildInfo {
}
}
type QueuedAction struct {
Timeout time.Time `json:"time"`
Action string `json:"action"`
}
func (ext *ThreadExt) QueueAction(end time.Time, action string) {
ext.ActionQueue = append(ext.ActionQueue, QueuedAction{end, action})
ext.NextAction, ext.TimeoutChan = ext.SoonestAction()
}
func (ext *ThreadExt) ClearActionQueue() {
ext.ActionQueue = []QueuedAction{}
ext.NextAction = nil
ext.TimeoutChan = nil
}
func (ext *ThreadExt) SoonestAction() (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction
var soonest_time time.Time
for _, action := range(ext.ActionQueue) {
if action.Timeout.Compare(soonest_time) == -1 || soonest_action == nil {
soonest_action = &action
soonest_time = action.Timeout
}
}
if soonest_action != nil {
return soonest_action, time.After(time.Until(soonest_action.Timeout))
} else {
return nil, nil
}
}
var deserializers = map[InfoType]func(interface{})(interface{}, error) {
"parent": func(raw interface{})(interface{}, error) {
m, ok := raw.(map[string]interface{})
@ -294,14 +354,14 @@ var deserializers = map[InfoType]func(interface{})(interface{}, error) {
},
}
func NewThreadExt(buffer int, name string, state_name string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
func NewThreadExt(buffer int, name string, state string, actions ThreadActions, handlers ThreadHandlers) ThreadExt {
return ThreadExt{
Actions: actions,
Handlers: handlers,
SignalChan: make(chan GraphSignal, buffer),
TimeoutChan: nil,
Active: false,
StateName: state_name,
State: state,
Parent: nil,
Children: map[NodeID]ChildInfo{},
ActionQueue: []QueuedAction{},
@ -322,7 +382,7 @@ func (ext *ThreadExt) SetActive(active bool) error {
}
func (ext *ThreadExt) SetState(state string) error {
ext.StateName = state
ext.State = state
return nil
}
@ -485,7 +545,7 @@ func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start boo
}
parent_info := info.Infos["parent"].(*ParentThreadInfo)
if parent_info.Start == true && child_ext.StateName != "finished" {
if parent_info.Start == true && child_ext.State != "finished" {
ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID)
if start == true {
ChildGo(ctx, thread_ext, info.Child, parent_info.StartAction)
@ -537,7 +597,7 @@ func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, err
context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLMap(NewACLInfo(thread, []string{"timeout"})), func(context *StateContext) error {
timeout_action = thread_ext.NextAction.Action
thread_ext.NextAction, thread_ext.TimeoutChan = thread_ext.SoonestAction()
thread_ext.NextAction, thread_ext.TimeoutChan = SoonestAction(thread_ext.ActionQueue)
return nil
})
if err != nil {

@ -8,47 +8,47 @@ import (
"crypto/x509"
)
type GroupNode interface {
Node
Users() map[NodeID]*User
}
type User struct {
Lockable
type ECDHExt struct {
Granted time.Time
Pubkey *ecdsa.PublicKey
Shared []byte
Tags []string
}
type UserJSON struct {
LockableJSON
type ECDHExtJSON struct {
Granted time.Time `json:"granted"`
Pubkey []byte `json:"pubkey"`
Shared []byte `json:"shared"`
}
func (user *User) Type() NodeType {
return NodeType("user")
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}
func (user *User) Serialize() ([]byte, error) {
lockable_json := NewLockableJSON(&user.Lockable)
pubkey, err := x509.MarshalPKIXPublicKey(user.Pubkey)
const ECDHExtType = ExtType("ECDH")
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
pubkey, err := x509.MarshalPKIXPublicKey(ext.Pubkey)
if err != nil {
return nil, err
}
return json.MarshalIndent(&UserJSON{
LockableJSON: lockable_json,
Granted: user.Granted,
Shared: user.Shared,
return json.MarshalIndent(&ECDHExtJSON{
Granted: ext.Granted,
Pubkey: pubkey,
Shared: ext.Shared,
}, "", " ")
}
var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) {
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var j ECDHExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
pub, err := x509.ParsePKIXPublicKey(j.Pubkey)
if err != nil {
return nil, err
@ -59,83 +59,56 @@ var LoadUser = LoadJSONNode(func(id NodeID, j UserJSON) (Node, error) {
case *ecdsa.PublicKey:
pubkey = pub.(*ecdsa.PublicKey)
default:
return nil, fmt.Errorf("Invalid key type")
return nil, fmt.Errorf("Invalid key type: %+v", pub)
}
user := NewUser(j.Name, j.Granted, pubkey, j.Shared)
return &user, nil
}, func(ctx *Context, user *User, j UserJSON, nodes NodeMap) error {
return RestoreLockable(ctx, user, j.LockableJSON, nodes)
})
func NewUser(name string, granted time.Time, pubkey *ecdsa.PublicKey, shared []byte) User {
id := KeyID(pubkey)
return User{
Lockable: NewLockable(id, name),
Granted: granted,
extension := ECDHExt{
Granted: j.Granted,
Pubkey: pubkey,
Shared: shared,
Shared: j.Shared,
}
}
type Group struct {
Lockable
UserMap map[NodeID]*User
return &extension, nil
}
func NewGroup(id NodeID, name string) Group {
return Group{
Lockable: NewLockable(id, name),
UserMap: map[NodeID]*User{},
}
type GroupExt struct {
Members NodeMap
}
type GroupJSON struct {
LockableJSON
Users []string `json:"users"`
const GroupExtType = ExtType("GROUP")
func (ext *GroupExt) Type() ExtType {
return GroupExtType
}
func (group *Group) Type() NodeType {
return NodeType("group")
func (ext *GroupExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&struct{
Members []string `json:"members"`
}{
Members: SaveNodeList(ext.Members),
}, "", " ")
}
func (group *Group) Serialize() ([]byte, error) {
users := make([]string, len(group.UserMap))
i := 0
for id, _ := range(group.UserMap) {
users[i] = id.String()
i += 1
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
var j struct {
Members []string `json:"members"`
}
return json.MarshalIndent(&GroupJSON{
LockableJSON: NewLockableJSON(&group.Lockable),
Users: users,
}, "", " ")
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
var LoadGroup = LoadJSONNode(func(id NodeID, j GroupJSON) (Node, error) {
group := NewGroup(id, j.Name)
return &group, nil
}, func(ctx *Context, group *Group, j GroupJSON, nodes NodeMap) error {
for _, id_str := range(j.Users) {
id, err := ParseID(id_str)
if err != nil {
return err
}
user_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
}
user, ok := user_node.(*User)
if ok == false {
return fmt.Errorf("%s is not a *User", id_str)
}
group.UserMap[id] = user
members, err := RestoreNodeList(ctx, j.Members)
if err != nil {
return nil, err
}
extension := GroupExt{
Members: members,
}
return &extension, nil
}
return RestoreLockable(ctx, group, j.LockableJSON, nodes)
})
func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error {
return nil
}