diff --git a/gql.go b/gql.go index c3f28f3..628ba35 100644 --- a/gql.go +++ b/gql.go @@ -708,6 +708,8 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e } thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) + nodes[id] = &thread + thread.Users = map[NodeID]*User{} for _, id_str := range(j.Users) { id, err := ParseID(id_str) @@ -720,7 +722,6 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e } thread.Users[id] = user.(*User) } - nodes[id] = &thread err = RestoreSimpleThread(ctx, &thread, j.SimpleThreadJSON, nodes) if err != nil { diff --git a/gql_graph.go b/gql_graph.go index 0e2c86e..ba504d4 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -894,14 +894,14 @@ func GQLMutationSendUpdate() *graphql.Field { }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLThread) - if ok == false { - return nil, fmt.Errorf("Failed to cast context gql_server to GQLServer: %+v", p.Context.Value("gql_server")) + ctx, server, user, err := GQLPrepResolve(p) + if err != nil { + return nil, err } - ctx, ok := p.Context.Value("graph_context").(*Context) - if ok == false { - return nil, fmt.Errorf("Failed to cast context graph_context to Context: %+v", p.Context.Value("graph_context")) + err = server.Allowed("signal", "self", user.ID()) + if err != nil { + return nil, err } signal_map, ok := p.Args["signal"].(map[string]interface{}) @@ -950,15 +950,40 @@ func GQLMutationSendUpdate() *graphql.Field { return gql_mutation_send_update } +func GQLPrepResolve(p graphql.ResolveParams) (*Context, *GQLThread, *User, error) { + context, ok := p.Context.Value("graph_context").(*Context) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast graph_context to *Context") + } + + server, ok := p.Context.Value("gql_server").(*GQLThread) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast gql_server to *GQLThread") + } + + user, ok := p.Context.Value("user").(*User) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast user to *User") + } + + return context, server, user, nil +} + var gql_query_self *graphql.Field = nil func GQLQuerySelf() *graphql.Field { if gql_query_self == nil { gql_query_self = &graphql.Field{ Type: GQLTypeGQLThread(), Resolve: func(p graphql.ResolveParams) (interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLThread) - if ok == false { - return nil, fmt.Errorf("failed to cast gql_server to GQLThread") + _, server, user, err := GQLPrepResolve(p) + + if err != nil { + return nil, err + } + + err = server.Allowed("enumerate", "self", user.ID()) + if err != nil { + return nil, fmt.Errorf("User %s is not allowed to perform self.enumerate on %s", user.ID(), server.ID()) } return server, nil diff --git a/gql_test.go b/gql_test.go index 7944f2c..88420d7 100644 --- a/gql_test.go +++ b/gql_test.go @@ -22,6 +22,10 @@ func TestGQLThread(t * testing.T) { ctx := logTestContext(t, []string{}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) + + p1_r := NewPerNodePolicy(RandID(), nil, NewNodeActions(nil, []string{"enumerate"})) + p1 := &p1_r + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r @@ -30,9 +34,13 @@ func TestGQLThread(t * testing.T) { t2_r := NewSimpleThread(RandID(), "Test thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) t2 := &t2_r - err = UpdateStates(ctx, []Node{gql_t, t1, t2}, func(nodes NodeMap) error { + err = UpdateStates(ctx, []Node{gql_t, t1, t2, p1}, func(nodes NodeMap) error { + err := gql_t.AddPolicy(p1) + if err != nil { + return err + } i1 := NewParentThreadInfo(true, "start", "restore") - err := LinkThreads(ctx, gql_t, t1, &i1, nodes) + err = LinkThreads(ctx, gql_t, t1, &i1, nodes) if err != nil { return err } @@ -71,7 +79,7 @@ func TestGQLDBLoad(t * testing.T) { u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared) u1 := &u1_r - p1_r := NewPerNodePolicy(RandID(), nil, NewNodeActions([]string{"*"})) + p1_r := NewPerNodePolicy(RandID(), nil, NewNodeActions(nil, []string{"enumerate"})) p1 := &p1_r key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -81,7 +89,7 @@ func TestGQLDBLoad(t * testing.T) { info := NewParentThreadInfo(true, "start", "restore") err = UpdateStates(ctx, []Node{gql, t1, l1, u1, p1}, func(nodes NodeMap) error { - err := u1.AddPolicy(p1) + err := gql.AddPolicy(p1) if err != nil { return err } @@ -152,13 +160,20 @@ func TestGQLDBLoad(t * testing.T) { } func TestGQLAuth(t * testing.T) { - ctx := logTestContext(t, []string{"test", "gql"}) + ctx := logTestContext(t, []string{"test", "gql", "db"}) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) + p1_r := NewPerNodePolicy(RandID(), nil, NewNodeActions(nil, []string{"*"})) + p1 := &p1_r + gql_t_r := NewGQLThread(RandID(), "GQL Thread", "init", ":0", ecdh.P256(), key, nil, nil) gql_t := &gql_t_r + err = UpdateStates(ctx, []Node{gql_t, p1}, func(nodes NodeMap) error { + return gql_t.AddPolicy(p1) + }) + done := make(chan error, 1) var update_channel chan GraphSignal diff --git a/node.go b/node.go index cbdbb77..6fd863e 100644 --- a/node.go +++ b/node.go @@ -69,7 +69,7 @@ type Node interface { ID() NodeID Type() NodeType - Allowed(action string, resource string, principal NodeID) bool + Allowed(action string, resource string, principal NodeID) error AddPolicy(Policy) error RemovePolicy(Policy) error @@ -100,13 +100,13 @@ func (node * GraphNode) Serialize() ([]byte, error) { return json.MarshalIndent(&node_json, "", " ") } -func (node *GraphNode) Allowed(action string, resource string, principal NodeID) bool { +func (node *GraphNode) Allowed(action string, resource string, principal NodeID) error { for _, policy := range(node.policies) { if policy.Allows(action, resource, principal) == true { - return true + return nil } } - return false + return fmt.Errorf("%s is not allowed to perform %s.%s on %s", principal.String(), resource, action, node.ID().String()) } func (node *GraphNode) AddPolicy(policy Policy) error { diff --git a/policy.go b/policy.go index 15b972c..6ddbbbb 100644 --- a/policy.go +++ b/policy.go @@ -4,17 +4,17 @@ import ( "encoding/json" ) -// A policy represents a set of rules attached to a Node that allow it to preform actions +// A policy represents a set of rules attached to a Node that allow principals to perform actions on it type Policy interface { Node - // Returns true if the policy allows the action on the given principal + // Returns true if the principal is allowed to perform the action on the resource Allows(action string, resource string, principal NodeID) bool } type NodeActions map[string][]string func (actions NodeActions) Allows(action string, resource string) bool { for _, a := range(actions[""]) { - if a == action { + if a == action || a == "*" { return true } } @@ -22,7 +22,7 @@ func (actions NodeActions) Allows(action string, resource string) bool { resource_actions, exists := actions[resource] if exists == true { for _, a := range(resource_actions) { - if a == action { + if a == action || a == "*" { return true } } @@ -31,14 +31,16 @@ func (actions NodeActions) Allows(action string, resource string) bool { return false } -func NewNodeActions(wildcard_actions []string) NodeActions { - actions := NodeActions{} +func NewNodeActions(resource_actions NodeActions, wildcard_actions []string) NodeActions { + if resource_actions == nil { + resource_actions = NodeActions{} + } // Wildcard actions, all actions in "" will be allowed on all resources if wildcard_actions == nil { wildcard_actions = []string{} } - actions[""] = wildcard_actions - return actions + resource_actions[""] = wildcard_actions + return resource_actions } type PerNodePolicy struct { @@ -76,7 +78,7 @@ func NewPerNodePolicy(id NodeID, node_actions map[NodeID]NodeActions, wildcard_a } if wildcard_actions == nil { - wildcard_actions = NewNodeActions(nil) + wildcard_actions = NewNodeActions(nil, nil) } return PerNodePolicy{ @@ -115,7 +117,6 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod } func (policy *PerNodePolicy) Allows(action string, resource string, principal NodeID) bool { - // Check wildcard actions if policy.WildcardActions.Allows(action, resource) == true { return true }