Major cleanup

gql_cataclysm
noah metz 2024-03-04 17:30:42 -07:00
parent e5776e0a14
commit 6942dc02db
21 changed files with 330 additions and 4574 deletions

233
acl.go

@ -1,233 +0,0 @@
package graphvent
import (
"github.com/google/uuid"
"slices"
"time"
)
type ACLSignal struct {
SignalHeader
Principal NodeID `gv:"principal"`
Action Tree `gv:"tree"`
}
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
return &ACLSignal{
SignalHeader: NewSignalHeader(Direct),
Principal: principal,
Action: action,
}
}
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ACLSignal]()): nil,
})
func (signal ACLSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ACLSignal]()): nil,
}
}
type ACLExt struct {
Policies []Policy `gv:"policies"`
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
Pending map[uuid.UUID]PendingACLSignal `gv:"pending"`
}
func NewACLExt(policies []Policy) *ACLExt {
return &ACLExt{
Policies: policies,
PendingACLs: map[uuid.UUID]PendingACL{},
Pending: map[uuid.UUID]PendingACLSignal{},
}
}
func (ext *ACLExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *ACLExt) Unload(ctx *Context, node *Node) {
}
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
response, is_response := signal.(ResponseSignal)
if is_response == true {
var messages Messages = nil
var changes = Changes{}
info, waiting := ext.Pending[response.ResponseID()]
if waiting == true {
changes.Add("pending")
delete(ext.Pending, response.ResponseID())
if response.ID() != info.Timeout {
err := node.DequeueSignal(info.Timeout)
if err != nil {
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
}
}
acl_info, found := ext.PendingACLs[info.ID]
if found == true {
acl_info.Counter -= 1
acl_info.Responses = append(acl_info.Responses, response)
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
return policy.ID() == info.Policy
})
if policy_index == -1 {
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
delete(ext.PendingACLs, info.ID)
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
changes.Add("pending_acls")
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed allow")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else if acl_info.Counter == 0 {
changes.Add("pending_acls")
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed deny")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
node.PendingACLs[info.ID] = acl_info
changes.Add("pending_acls")
}
}
}
}
return messages, changes
}
var messages Messages = nil
var changes = Changes{}
switch sig := signal.(type) {
case *ACLSignal:
var acl_messages map[uuid.UUID]Messages = nil
denied := true
for _, policy := range(ext.Policies) {
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
if result == Allow {
denied = false
break
} else if result == Pending {
if len(policy_messages) == 0 {
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
continue
} else if acl_messages == nil {
acl_messages = map[uuid.UUID]Messages{}
denied = false
}
acl_messages[policy.ID()] = policy_messages
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
}
}
if denied == true {
ctx.Log.Logf("acl", "Request denied")
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
} else if acl_messages != nil {
ctx.Log.Logf("acl", "Request pending")
changes.Add("pending")
total_messages := 0
// TODO: reasonable timeout/configurable
timeout_time := time.Now().Add(time.Second)
for policy_id, policy_messages := range(acl_messages) {
total_messages += len(policy_messages)
for _, message := range(policy_messages) {
timeout_signal := NewTimeoutSignal(message.Signal.ID())
ext.Pending[message.Signal.ID()] = PendingACLSignal{
Policy: policy_id,
Timeout: timeout_signal.Id,
ID: sig.Id,
}
node.QueueSignal(timeout_time, timeout_signal)
messages = append(messages, message)
}
}
acl_timeout := NewACLTimeoutSignal(sig.Id)
node.QueueSignal(timeout_time, acl_timeout)
ext.PendingACLs[sig.Id] = PendingACL{
Counter: total_messages,
Responses: []ResponseSignal{},
TimeoutID: acl_timeout.Id,
Action: sig.Action,
Principal: sig.Principal,
Source: source,
Signal: signal,
}
} else {
ctx.Log.Logf("acl", "Request allowed")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
case *ACLTimeoutSignal:
acl_info, exists := ext.PendingACLs[sig.ReqID]
if exists == true {
delete(ext.PendingACLs, sig.ReqID)
changes.Add("pending_acls")
ctx.Log.Logf("acl", "Request timeout deny")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
}
// Delete from PendingACLs
}
return messages, changes
}
type ACLProxyPolicy struct {
PolicyHeader
Proxies []NodeID `gv:"proxies"`
}
func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy {
return ACLProxyPolicy{
NewPolicyHeader(),
proxies,
}
}
func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
if len(policy.Proxies) == 0 {
return nil, Deny
}
messages := Messages{}
for _, proxy := range(policy.Proxies) {
messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
}
return messages, Pending
}
func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
_, is_success := signal.(*SuccessSignal)
if is_success == true {
return Allow
}
return Deny
}

@ -1,141 +0,0 @@
package graphvent
import (
"testing"
"time"
"reflect"
"runtime/debug"
)
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
response_casted, cast_ok := signal.(S)
if cast_ok == false {
error_signal, is_error := signal.(*ErrorSignal)
if is_error {
t.Log(string(debug.Stack()))
t.Fatal(error_signal.Error)
}
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
}
check(response_casted)
}
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
fatalErr(t, err)
acl_signal := NewACLSignal(listener.ID, action)
response, _ := testSend(t, ctx, acl_signal, listener, acl_node)
checkSignal(t, response, check)
}
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
return func(response *ErrorSignal) {
if response.Error != error_string {
t.Fatalf("Wrong error: %s", response.Error)
}
}
}
func testSuccess(*SuccessSignal){}
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
source_listener, err := GetExt[ListenerExt](source)
fatalErr(t, err)
messages := Messages{}
messages = messages.Add(ctx, destination.ID, source, nil, signal)
fatalErr(t, ctx.Send(messages))
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response, signals
}
func TestACLBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"})
listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100))
fatalErr(t, err)
ctx.Log.Logf("test", "testing fail")
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
ctx.Log.Logf("test", "testing allow all")
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
group, err := NewNode(ctx, nil, "Base", 100, []Policy{
DefaultGroupPolicy,
NewPerNodePolicy(map[NodeID]Tree{
listener.ID: {
SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil,
SerializedType(SignalTypeFor[AddMemberSignal]()): nil,
},
}),
}, NewGroupExt(nil))
fatalErr(t, err)
ctx.Log.Logf("test", "testing empty groups")
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}, testErrorSignal(t, "acl_denied"))
ctx.Log.Logf("test", "testing adding group")
add_subgroup_signal := NewAddSubGroupSignal("test_group")
add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group)
checkSignal(t, add_subgroup_response, testSuccess)
ctx.Log.Logf("test", "testing adding member")
add_member_signal := NewAddMemberSignal("test_group", listener.ID)
add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group)
checkSignal(t, add_member_response, testSuccess)
ctx.Log.Logf("test", "testing group membership")
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}, testSuccess)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy(nil),
}, testErrorSignal(t, "acl_denied"))
acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}),
}, testErrorSignal(t, "acl_denied"))
acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)}))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}),
}, testSuccess)
acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy},
NewACLExt([]Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}),
)
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}),
}, testSuccess)
}

@ -1,17 +1,17 @@
package graphvent
import (
"crypto/ecdh"
"errors"
"fmt"
"reflect"
"runtime"
"sync"
"time"
"github.com/google/uuid"
"github.com/graphql-go/graphql"
badger "github.com/dgraph-io/badger/v3"
"crypto/ecdh"
"errors"
"fmt"
"reflect"
"runtime"
"sync"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
badger "github.com/dgraph-io/badger/v3"
)
var (
@ -19,161 +19,81 @@ var (
ECDH = ecdh.X25519()
)
type ExtensionInfo struct {
Reflect reflect.Type
Interface graphql.Interface
type TypeInfo struct {
Type graphql.Type
}
type ExtensionInfo struct {
Interface *graphql.Interface
Fields map[string][]int
Data interface{}
}
type SignalInfo struct {
Type graphql.Type
}
type FieldIndex struct {
Extension ExtType
Field string
}
type NodeInfo struct {
GQL *graphql.Object
Extensions []ExtType
Policies []Policy
Fields map[string]FieldIndex
}
type GQLValueConverter func(*Context, interface{})(reflect.Value, error)
type TypeInfo struct {
Reflect reflect.Type
GQL graphql.Type
Type SerializedType
TypeSerialize TypeSerializeFn
Serialize SerializeFn
TypeDeserialize TypeDeserializeFn
Deserialize DeserializeFn
GQLValue GQLValueConverter
}
type KindInfo struct {
Reflect reflect.Kind
Base reflect.Type
Type SerializedType
TypeSerialize TypeSerializeFn
Serialize SerializeFn
TypeDeserialize TypeDeserializeFn
Deserialize DeserializeFn
}
// A Context stores all the data to run a graphvent process
type Context struct {
// DB is the database connection used to load and write nodes
DB * badger.DB
// Logging interface
Log Logger
// Mapped types
TypeMap map[SerializedType]TypeInfo
TypeTypes map[reflect.Type]SerializedType
// Map between database extension hashes and the registered info
Extensions map[ExtType]ExtensionInfo
ExtensionTypes map[reflect.Type]ExtType
// Map between databse policy hashes and the registered info
Policies map[PolicyType]reflect.Type
PolicyTypes map[reflect.Type]PolicyType
// Map between serialized signal hashes and the registered info
Signals map[SignalType]reflect.Type
SignalTypes map[reflect.Type]SignalType
// Map between database type hashes and the registered info
Nodes map[NodeType]NodeInfo
NodeTypes map[string]NodeType
// Map between go types and registered info
Types map[SerializedType]*TypeInfo
TypeReflects map[reflect.Type]*TypeInfo
Kinds map[reflect.Kind]*KindInfo
KindTypes map[SerializedType]*KindInfo
// Routing map to all the nodes local to this context
nodeMapLock sync.RWMutex
nodeMap map[NodeID]*Node
}
// Register a NodeType to the context, with the list of extensions it requires
func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings map[string]FieldIndex) error {
node_type := NodeTypeFor(name, extensions, mappings)
_, exists := ctx.Nodes[node_type]
if exists == true {
return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type)
}
ext_found := map[ExtType]bool{}
for _, extension := range(extensions) {
_, in_ctx := ctx.Extensions[extension]
if in_ctx == false {
return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension)
}
_, duplicate := ext_found[extension]
if duplicate == true {
return fmt.Errorf("Duplicate extension %+v found in extension list", extension)
}
ext_found[extension] = true
}
ctx.Nodes[node_type] = NodeInfo{
Extensions: extensions,
Fields: mappings,
}
ctx.NodeTypes[name] = node_type
return nil
}
func RegisterPolicy[P Policy](ctx *Context) error {
reflect_type := reflect.TypeFor[P]()
policy_type := PolicyTypeFor[P]()
func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, error) {
types := []graphql.Type{}
_, exists := ctx.Policies[policy_type]
if exists == true {
return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type)
}
subscription := graphql.NewObject(graphql.ObjectConfig{
policy_info, err := GetStructInfo(ctx, reflect_type)
if err != nil {
return err
}
})
err = RegisterType[P](ctx, nil, SerializeStruct(policy_info), nil, DeserializeStruct(policy_info))
if err != nil {
return err
}
ctx.Log.Logf("serialize_types", "Registered PolicyType: %+v - %+v", reflect_type, policy_type)
ctx.Policies[policy_type] = reflect_type
ctx.PolicyTypes[reflect_type] = policy_type
return nil
return graphql.NewSchema(graphql.SchemaConfig{
Types: types,
Query: query,
Subscription: subscription,
Mutation: mutation,
})
}
func RegisterSignal[S Signal](ctx *Context) error {
reflect_type := reflect.TypeFor[S]()
signal_type := SignalTypeFor[S]()
_, exists := ctx.Signals[signal_type]
if exists == true {
return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type)
}
signal_info, err := GetStructInfo(ctx, reflect_type)
if err != nil {
return err
}
err = RegisterType[S](ctx, nil, SerializeStruct(signal_info), nil, DeserializeStruct(signal_info))
err := RegisterObject[S](ctx)
if err != nil {
return err
}
ctx.Log.Logf("serialize_types", "Registered SignalType: %+v - %+v", reflect_type, signal_type)
ctx.Signals[signal_type] = reflect_type
ctx.SignalTypes[reflect_type] = signal_type
return nil
}
@ -185,13 +105,51 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i
return fmt.Errorf("Cannot register extension %+v of type %+v, type already exists in context", reflect_type, ext_type)
}
elem_type := reflect_type.Elem()
elem_info, err := GetStructInfo(ctx, elem_type)
if err != nil {
return err
gql_interface := graphql.NewInterface(graphql.InterfaceConfig{
Name: reflect_type.String(),
ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil
}
node, ok := p.Value.(NodeResult)
if ok == false {
return nil
}
type_info, type_exists := ctx.Context.Nodes[node.NodeType]
if type_exists == false {
return ctx.Context.Nodes[ctx.Context.NodeTypes["Base"]].GQL
}
return type_info.GQL
},
Fields: graphql.Fields{
"ID": &graphql.Field{
Type: graphql.String,
},
},
})
fields := map[string][]int{}
for _, field := range reflect.VisibleFields(reflect.TypeFor[E]()) {
gv_tag, tagged_gv := field.Tag.Lookup("gv")
if tagged_gv {
fields[gv_tag] = field.Index
type_ser, type_mapped := ctx.TypeTypes[field.Type]
if type_mapped == false {
return fmt.Errorf("Extension %s has field %s of unregistered type %s", reflect_type, gv_tag, field.Type)
}
gql_interface.AddFieldConfig(gv_tag, &graphql.Field{
Type: ctx.TypeMap[type_ser].Type,
})
}
}
err = RegisterType[E](ctx, nil, SerializeStruct(elem_info), nil, DeserializeStruct(elem_info))
err := RegisterObject[E](ctx)
if err != nil {
return err
}
@ -199,111 +157,121 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i
ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type)
ctx.Extensions[ext_type] = ExtensionInfo{
Reflect: reflect_type,
Interface: gql_interface,
Data: data,
Fields: fields,
}
ctx.ExtensionTypes[reflect_type] = ext_type
return nil
}
func RegisterKind(ctx *Context, kind reflect.Kind, base reflect.Type, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error {
ctx_type := SerializedKindFor(kind)
_, exists := ctx.Kinds[kind]
if exists == true {
return fmt.Errorf("Cannot register kind %+v, kind already exists in context", kind)
}
_, exists = ctx.KindTypes[ctx_type]
func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings map[string]FieldIndex) error {
node_type := NodeTypeFor(name, extensions, mappings)
_, exists := ctx.Nodes[node_type]
if exists == true {
return fmt.Errorf("0x%x is already registered, cannot use for %+v", ctx_type, kind)
}
if deserialize == nil {
return fmt.Errorf("Cannot register field without deserialize function")
}
if serialize == nil {
return fmt.Errorf("Cannot register field without serialize function")
return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type)
}
info := KindInfo{
Reflect: kind,
Type: ctx_type,
Base: base,
TypeSerialize: type_serialize,
Serialize: serialize,
TypeDeserialize: type_deserialize,
Deserialize: deserialize,
ext_found := map[ExtType]bool{}
for _, extension := range(extensions) {
_, in_ctx := ctx.Extensions[extension]
if in_ctx == false {
return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", node_type, extension)
}
_, duplicate := ext_found[extension]
if duplicate == true {
return fmt.Errorf("Duplicate extension %+v found in extension list", extension)
}
ext_found[extension] = true
}
ctx.KindTypes[ctx_type] = &info
ctx.Kinds[kind] = &info
ctx.Log.Logf("serialize_types", "Registered kind %+v, %+v", kind, ctx_type)
ctx.Nodes[node_type] = NodeInfo{
Extensions: extensions,
Fields: mappings,
}
ctx.NodeTypes[name] = node_type
return nil
}
func RegisterType[T any](ctx *Context, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error {
func RegisterObject[T any](ctx *Context) error {
reflect_type := reflect.TypeFor[T]()
ctx_type := SerializedTypeFor[T]()
_, exists := ctx.Types[ctx_type]
if exists == true {
return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type)
}
_, exists = ctx.TypeReflects[reflect_type]
if exists == true {
return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type)
}
if type_serialize == nil || type_deserialize == nil {
kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()]
if kind_registered == true {
if type_serialize == nil {
type_serialize = kind_info.TypeSerialize
}
if type_deserialize == nil {
type_deserialize = kind_info.TypeDeserialize
serialized_type := SerializedTypeFor[T]()
_, exists := ctx.TypeTypes[reflect_type]
if exists {
return fmt.Errorf("%+v already registered in TypeMap", reflect_type)
}
gql := graphql.NewObject(graphql.ObjectConfig{
Name: reflect_type.String(),
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
return reflect_type == reflect.TypeOf(p.Value)
},
Fields: graphql.Fields{},
})
for _, field := range(reflect.VisibleFields(reflect_type)) {
gv_tag, tagged_gv := field.Tag.Lookup("gv")
if tagged_gv {
field_type, mapped := ctx.TypeTypes[field.Type]
if mapped == false {
return fmt.Errorf("Object %+v has field %s of unknown type %+v", reflect_type, gv_tag, field_type)
}
gql.AddFieldConfig(gv_tag, &graphql.Field{
Type: ctx.TypeMap[field_type].Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
val, ok := p.Source.(T)
if ok == false {
return nil, fmt.Errorf("%s is not %s", reflect.TypeOf(p.Source), reflect_type)
}
value, err := reflect.ValueOf(val).FieldByIndexErr(field.Index)
if err != nil {
return nil, err
}
return value.Interface(), nil
},
})
}
}
if serialize == nil || deserialize == nil {
kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()]
if kind_registered == false {
return fmt.Errorf("No serialize/deserialize passed and none registered for kind %+v", reflect_type.Kind())
} else {
if serialize == nil {
serialize = kind_info.Serialize
}
if deserialize == nil {
deserialize = kind_info.Deserialize
}
}
ctx.TypeTypes[reflect_type] = serialized_type
ctx.TypeMap[serialized_type] = TypeInfo{
Type: gql,
}
type_info := TypeInfo{
Reflect: reflect_type,
Type: ctx_type,
TypeSerialize: type_serialize,
Serialize: serialize,
TypeDeserialize: type_deserialize,
Deserialize: deserialize,
return nil
}
func RegisterScalar[T any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}) error {
reflect_type := reflect.TypeFor[T]()
serialized_type := SerializedTypeFor[T]()
_, exists := ctx.TypeTypes[reflect_type]
if exists {
return fmt.Errorf("%+v already registered in TypeMap", reflect_type)
}
ctx.Types[ctx_type] = &type_info
ctx.TypeReflects[reflect_type] = &type_info
gql := graphql.NewScalar(graphql.ScalarConfig{
Name: reflect_type.String(),
Serialize: to_json,
ParseValue: from_json,
ParseLiteral: from_ast,
})
ctx.Log.Logf("serialize_types", "Registered Type: %+v - %+v", reflect_type, ctx_type)
ctx.TypeTypes[reflect_type] = serialized_type
ctx.TypeMap[serialized_type] = TypeInfo{
Type: gql,
}
return nil
}
func RegisterStruct[T any](ctx *Context) error {
struct_info, err := GetStructInfo(ctx, reflect.TypeFor[T]())
if err != nil {
return err
}
return RegisterType[T](ctx, nil, SerializeStruct(struct_info), nil, DeserializeStruct(struct_info))
}
func (ctx *Context) AddNode(id NodeID, node *Node) {
ctx.nodeMapLock.Lock()
@ -364,7 +332,7 @@ func (ctx *Context) getNode(id NodeID) (*Node, error) {
}
// Route Messages to dest. Currently only local context routing is supported
func (ctx *Context) Send(messages Messages) error {
func (ctx *Context) Send(node *Node, messages []SendMsg) error {
for _, msg := range(messages) {
ctx.Log.Logf("signal", "Sending %s -> %+v", msg.Dest, msg)
if msg.Dest == ZeroID {
@ -373,7 +341,7 @@ func (ctx *Context) Send(messages Messages) error {
target, err := ctx.getNode(msg.Dest)
if err == nil {
select {
case target.MsgChan <- msg:
case target.MsgChan <- RecvMsg{node.ID, msg.Signal}:
ctx.Log.Logf("signal", "Sent %s -> %+v", target.ID, msg)
default:
buf := make([]byte, 4096)
@ -396,255 +364,27 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
ctx := &Context{
DB: db,
Log: log,
Policies: map[PolicyType]reflect.Type{},
PolicyTypes: map[reflect.Type]PolicyType{},
TypeMap: map[SerializedType]TypeInfo{},
TypeTypes: map[reflect.Type]SerializedType{},
Extensions: map[ExtType]ExtensionInfo{},
ExtensionTypes: map[reflect.Type]ExtType{},
Signals: map[SignalType]reflect.Type{},
SignalTypes: map[reflect.Type]SignalType{},
Nodes: map[NodeType]NodeInfo{},
NodeTypes: map[string]NodeType{},
Types: map[SerializedType]*TypeInfo{},
TypeReflects: map[reflect.Type]*TypeInfo{},
Kinds: map[reflect.Kind]*KindInfo{},
KindTypes: map[SerializedType]*KindInfo{},
nodeMap: map[NodeID]*Node{},
}
var err error
err = RegisterKind(ctx, reflect.Pointer, nil, SerializeTypeElem, SerializePointer, DeserializeTypePointer, DeserializePointer)
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Bool, reflect.TypeFor[bool](), nil, SerializeBool, nil, DeserializeBool[bool])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.String, reflect.TypeFor[string](), nil, SerializeString, nil, DeserializeString[string])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Float32, reflect.TypeFor[float32](), nil, SerializeFloat32, nil, DeserializeFloat32[float32])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Float64, reflect.TypeFor[float64](), nil, SerializeFloat64, nil, DeserializeFloat64[float64])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Uint, reflect.TypeFor[uint](), nil, SerializeUint32, nil, DeserializeUint32[uint])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Uint8, reflect.TypeFor[uint8](), nil, SerializeUint8, nil, DeserializeUint8[uint8])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Uint16, reflect.TypeFor[uint16](), nil, SerializeUint16, nil, DeserializeUint16[uint16])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Uint32, reflect.TypeFor[uint32](), nil, SerializeUint32, nil, DeserializeUint32[uint32])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Uint64, reflect.TypeFor[uint64](), nil, SerializeUint64, nil, DeserializeUint64[uint64])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Int, reflect.TypeFor[int](), nil, SerializeInt32, nil, DeserializeUint32[int])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Int8, reflect.TypeFor[int8](), nil, SerializeInt8, nil, DeserializeUint8[int8])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Int16, reflect.TypeFor[int16](), nil, SerializeInt16, nil, DeserializeUint16[int16])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Int32, reflect.TypeFor[int32](), nil, SerializeInt32, nil, DeserializeUint32[int32])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Int64, reflect.TypeFor[int64](), nil, SerializeInt64, nil, DeserializeUint64[int64])
if err != nil {
return nil, err
}
err = RegisterType[WaitReason](ctx, nil, nil, nil, DeserializeString[WaitReason])
if err != nil {
return nil, err
}
err = RegisterType[EventCommand](ctx, nil, nil, nil, DeserializeString[EventCommand])
if err != nil {
return nil, err
}
err = RegisterType[EventState](ctx, nil, nil, nil, DeserializeString[EventState])
if err != nil {
return nil, err
}
err = RegisterStruct[WaitInfo](ctx)
if err != nil {
return nil, err
}
err = RegisterType[time.Duration](ctx, nil, nil, nil, DeserializeUint64[time.Duration])
if err != nil {
return nil, err
}
err = RegisterType[time.Time](ctx, nil, SerializeGob, nil, DeserializeGob[time.Time])
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Map, nil, SerializeTypeMap, SerializeMap, DeserializeTypeMap, DeserializeMap)
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Array, nil, SerializeTypeArray, SerializeArray, DeserializeTypeArray, DeserializeArray)
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Slice, nil, SerializeTypeElem, SerializeSlice, DeserializeTypeSlice, DeserializeSlice)
if err != nil {
return nil, err
}
err = RegisterKind(ctx, reflect.Interface, reflect.TypeFor[interface{}](), nil, SerializeInterface, nil, DeserializeInterface)
if err != nil {
return nil, err
}
err = RegisterType[SerializedType](ctx, nil, SerializeUint64, nil, DeserializeUint64[SerializedType])
if err != nil {
return nil, err
}
err = RegisterType[Changes](ctx, SerializeTypeStub, SerializeMap, DeserializeTypeStub[Changes], DeserializeMap)
if err != nil {
return nil, err
}
err = RegisterType[ExtType](ctx, nil, SerializeUint64, nil, DeserializeUint64[ExtType])
if err != nil {
return nil, err
}
err = RegisterType[NodeType](ctx, nil, SerializeUint64, nil, DeserializeUint64[NodeType])
if err != nil {
return nil, err
}
err = RegisterType[PolicyType](ctx, nil, SerializeUint64, nil, DeserializeUint64[PolicyType])
if err != nil {
return nil, err
}
err = RegisterType[NodeID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[NodeID], DeserializeUUID[NodeID])
if err != nil {
return nil, err
}
err = RegisterType[uuid.UUID](ctx, SerializeTypeStub, SerializeUUID, DeserializeTypeStub[uuid.UUID], DeserializeUUID[uuid.UUID])
if err != nil {
return nil, err
}
err = RegisterType[SignalDirection](ctx, nil, SerializeUint8, nil, DeserializeUint8[SignalDirection])
if err != nil {
return nil, err
}
err = RegisterType[ReqState](ctx, nil, SerializeUint8, nil, DeserializeUint8[ReqState])
if err != nil {
return nil, err
}
err = RegisterType[Tree](ctx, SerializeTypeStub, nil, DeserializeTypeStub[Tree], nil)
if err != nil {
return nil, err
}
err = RegisterType[Extension](ctx, nil, SerializeInterface, nil, DeserializeInterface)
if err != nil {
return nil, err
}
err = RegisterType[Policy](ctx, nil, SerializeInterface, nil, DeserializeInterface)
if err != nil {
return nil, err
}
err = RegisterType[Signal](ctx, nil, SerializeInterface, nil, DeserializeInterface)
if err != nil {
return nil, err
}
err = RegisterStruct[PendingACL](ctx)
if err != nil {
return nil, err
}
err = RegisterStruct[PendingACLSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterStruct[QueuedSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterStruct[Node](ctx)
if err != nil {
return nil, err
}
err = RegisterExtension[LockableExt](ctx, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[ListenerExt](ctx, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[GroupExt](ctx, nil)
if err != nil {
return nil, err
}
gql_ctx := NewGQLExtContext()
err = RegisterExtension[GQLExt](ctx, gql_ctx)
if err != nil {
return nil, err
}
err = RegisterExtension[ACLExt](ctx, nil)
err = RegisterExtension[LockableExt](ctx, nil)
if err != nil {
return nil, err
}
@ -654,127 +394,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterPolicy[OwnerOfPolicy](ctx)
if err != nil {
return nil, err
}
err = RegisterPolicy[ParentOfPolicy](ctx)
if err != nil {
return nil, err
}
err = RegisterPolicy[MemberOfPolicy](ctx)
if err != nil {
return nil, err
}
err = RegisterPolicy[AllNodesPolicy](ctx)
err = RegisterExtension[GQLExt](ctx, nil)
if err != nil {
return nil, err
}
err = RegisterPolicy[PerNodePolicy](ctx)
if err != nil {
return nil, err
}
err = RegisterPolicy[ACLProxyPolicy](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[AddSubGroupSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[RemoveSubGroupSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[ACLTimeoutSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[ACLSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[RemoveMemberSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[AddMemberSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[StatusSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[ReadSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[LockSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[TimeoutSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[LinkSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[ErrorSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[SuccessSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[ReadResultSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[EventControlSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterSignal[EventStateSignal](ctx)
if err != nil {
return nil, err
}
err = RegisterNodeType(ctx, "Base", []ExtType{}, map[string]FieldIndex{})
if err != nil {
return nil, err
}
schema, err := BuildSchema(gql_ctx)
if err != nil {
return nil, err
}
gql_ctx.Schema = schema
return ctx, nil
}

@ -8,40 +8,6 @@ import (
type EventCommand string
type EventState string
type ParentOfPolicy struct {
PolicyHeader
Policy Tree
}
func NewParentOfPolicy(policy Tree) *ParentOfPolicy {
return &ParentOfPolicy{
PolicyHeader: NewPolicyHeader(),
Policy: policy,
}
}
func (policy ParentOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
event_ext, err := GetExt[EventExt](node)
if err != nil {
ctx.Log.Logf("event", "ParentOfPolicy, node not event %s", node.ID)
return nil, Deny
}
if event_ext.Parent == principal_id {
return nil, policy.Policy.Allows(action)
}
return nil, Deny
}
func (policy ParentOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
var DefaultEventPolicy = NewParentOfPolicy(Tree{
SerializedType(SignalTypeFor[EventControlSignal]()): nil,
})
type EventExt struct {
Name string `gv:"name"`
State EventState `gv:"state"`
@ -71,19 +37,13 @@ type EventStateSignal struct {
Time time.Time `gv:"time"`
}
func (signal EventStateSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[StatusSignal]()): nil,
}
}
func (signal EventStateSignal) String() string {
return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time)
}
func NewEventStateSignal(source NodeID, state EventState, t time.Time) *EventStateSignal {
return &EventStateSignal{
SignalHeader: NewSignalHeader(Up),
SignalHeader: NewSignalHeader(),
Source: source,
State: state,
Time: t,
@ -101,19 +61,11 @@ func (signal EventControlSignal) String() string {
func NewEventControlSignal(command EventCommand) *EventControlSignal {
return &EventControlSignal{
NewSignalHeader(Direct),
NewSignalHeader(),
command,
}
}
func (signal EventControlSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[EventControlSignal]()): {
Hash("command", string(signal.Command)): nil,
},
}
}
func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) {
if ext.State != state {
ext.StateStart = state_start
@ -123,14 +75,10 @@ func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState,
}
}
func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
var messages Messages = nil
func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
if signal.Direction() == Up && ext.Parent != node.ID {
messages = messages.Add(ctx, ext.Parent, node, nil, signal)
}
return messages, changes
}
@ -165,27 +113,27 @@ var test_event_commands = EventCommandMap{
}
func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
var messages Messages = nil
func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
switch sig := signal.(type) {
case *EventControlSignal:
event_ext, err := GetExt[EventExt](node)
if err != nil {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_event"))
messages = append(messages, SendMsg{source, NewErrorSignal(sig.Id, "not_event")})
} else {
ctx.Log.Logf("event", "%s got %s EventControlSignal while in %s", node.ID, sig.Command, event_ext.State)
new_state, error_signal := event_ext.ValidateEventCommand(sig, test_event_commands)
if error_signal != nil {
messages = messages.Add(ctx, source, node, nil, error_signal)
messages = append(messages, SendMsg{source, error_signal})
} else {
switch sig.Command {
case "start":
node.QueueSignal(time.Now().Add(ext.Length), NewEventControlSignal("finish"))
}
event_ext.UpdateState(node, changes, new_state, time.Now())
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
messages = append(messages, SendMsg{source, NewSuccessSignal(sig.Id)})
}
}
}

@ -7,7 +7,7 @@ import (
// Extensions are data attached to nodes that process signals
type Extension interface {
// Called to process incoming signals, returning changes and messages to send
Process(*Context, *Node, NodeID, Signal) (Messages, Changes)
Process(*Context, *Node, NodeID, Signal) ([]SendMsg, Changes)
// Called when the node is loaded into a context(creation or move), so extension data can be initialized
Load(*Context, *Node) error

548
gql.go

@ -73,45 +73,6 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy
}
}
func NodeInterfaceResolveType(required_extensions []ExtType, default_type **graphql.Object)func(graphql.ResolveTypeParams) *graphql.Object {
return func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil
}
node, ok := p.Value.(NodeResult)
if ok == false {
return nil
}
gql_type, exists := ctx.GQLContext.NodeTypes[node.NodeType]
ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type)
if exists == false {
node_type_def, exists := ctx.Context.Nodes[node.NodeType]
if exists == false {
return nil
} else {
for _, ext := range(required_extensions) {
found := false
for _, e := range(node_type_def.Extensions) {
if e == ext {
found = true
break
}
}
if found == false {
return nil
}
}
}
return *default_type
}
return gql_type
}
}
func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
@ -315,9 +276,6 @@ type ResolveContext struct {
// Graph Context this resolver is running under
Context *Context
// GQL Extension context this resolver is running under
GQLContext *GQLExtContext
// Pointer to the node that's currently processing this request
Server *Node
@ -326,9 +284,6 @@ type ResolveContext struct {
// Cache of resolved nodes
NodeCache map[NodeID]NodeResult
// Authorization from the user that started this request
Authorization *ClientAuthorization
}
func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) {
@ -409,141 +364,14 @@ func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (st
return base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":"))), nil
}
func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) {
joined_b64, err := base64.StdEncoding.DecodeString(auth_base64)
if err != nil {
return nil, err
}
auth_parts := strings.Split(string(joined_b64), ":")
if len(auth_parts) != 6 {
return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts))
}
id_bytes, err := base64.StdEncoding.DecodeString(auth_parts[0])
if err != nil {
return nil, err
}
iv, err := base64.StdEncoding.DecodeString(auth_parts[1])
if err != nil {
return nil, err
}
public_key, err := base64.StdEncoding.DecodeString(auth_parts[2])
if err != nil {
return nil, err
}
key_encrypted, err := base64.StdEncoding.DecodeString(auth_parts[3])
if err != nil {
return nil, err
}
start_bytes, err := base64.StdEncoding.DecodeString(auth_parts[4])
if err != nil {
return nil, err
}
signature, err := base64.StdEncoding.DecodeString(auth_parts[5])
if err != nil {
return nil, err
}
var start time.Time
err = start.UnmarshalBinary(start_bytes)
if err != nil {
return nil, err
}
client_id := ed25519.PublicKey(id_bytes)
if err != nil {
return nil, err
}
client_point, err := (&edwards25519.Point{}).SetBytes(public_key)
if err != nil {
return nil, err
}
ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery())
if err != nil {
return nil, err
}
h := sha512.Sum512(server_id.Seed())
ecdh_server, err := ECDH.NewPrivateKey(h[:32])
if err != nil {
return nil, err
}
secret, err := ecdh_server.ECDH(ecdh_client)
if err != nil {
return nil, err
} else if len(secret) != 32 {
return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret))
}
block, err := aes.NewCipher(secret)
if err != nil {
return nil, err
}
encrypted_reader := bytes.NewReader(key_encrypted)
stream := cipher.NewOFB(block, iv)
reader := cipher.StreamReader{S: stream, R: encrypted_reader}
var decrypted_key bytes.Buffer
_, err = io.Copy(&decrypted_key, reader)
if err != nil {
return nil, err
}
session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes())
digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...)
if ed25519.Verify(client_id, digest, signature) == false {
return nil, fmt.Errorf("Failed to verify digest/signature against client_id")
}
return &ClientAuthorization{
AuthInfo: AuthInfo{
Identity: client_id,
Start: start,
Signature: signature,
},
Key: session_key,
}, nil
}
func ValidateAuthorization(auth Authorization, valid time.Duration) error {
// Check that the time + valid < now
// Check that Signature is public_key + start signed with client_id
if auth.Start.Add(valid).Compare(time.Now()) != 1 {
return fmt.Errorf("authorization expired")
}
time_bytes, err := auth.Start.MarshalBinary()
if err != nil {
return err
}
digest := append(auth.Key, time_bytes...)
if ed25519.Verify(auth.Identity, digest, auth.Signature) != true {
return fmt.Errorf("verification failed")
}
return nil
}
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) {
return &ResolveContext{
ID: uuid.New(),
Ext: gql_ext,
Chans: map[uuid.UUID]chan Signal{},
Context: ctx,
GQLContext: ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext),
NodeCache: map[NodeID]NodeResult{},
Server: server,
Authorization: nil,
}, nil
}
@ -557,13 +385,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(""))
return
}
resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
@ -571,8 +392,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
return
}
resolve_context.Authorization = auth
req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
@ -585,10 +404,10 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
query := GQLPayload{}
json.Unmarshal(str, &query)
gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext)
schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema)
params := graphql.Params{
Schema: gql_context.Schema,
Schema: schema,
Context: req_ctx,
RequestString: query.Query,
}
@ -716,14 +535,6 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
break
}
authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key)
if err != nil {
ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err)
break
}
resolve_context.Authorization = authorization
conn_state = "ready"
err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}"))
if err != nil {
@ -739,9 +550,9 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
}
} else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext)
schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema)
params := graphql.Params{
Schema: gql_context.Schema,
Schema: schema,
Context: req_ctx,
RequestString: msg.Payload.Query,
}
@ -829,165 +640,10 @@ type Field struct {
Field *graphql.Field
}
// GQL Specific Context information
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// Custom graphql types, mapped to NodeTypes
NodeTypes map[NodeType]*graphql.Object
Interfaces map[string]*Interface
Fields map[string]Field
// Schema parameters
Types []graphql.Type
Query *graphql.Object
Mutation *graphql.Object
}
func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) {
ext_fields := map[ExtType][]string{}
for _, name := range(names) {
switch name {
case "ID":
case "TypeHash":
default:
field, exists := ctx.Fields[name]
if exists == false {
continue
}
ext, exists := ext_fields[field.Ext]
if exists == false {
ext = []string{}
}
ext = append(ext, field.Name)
ext_fields[field.Ext] = ext
}
}
return ext_fields, nil
}
func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
schemaConfig := graphql.SchemaConfig{
Types: ctx.Types,
Query: ctx.Query,
Mutation: ctx.Mutation,
}
return graphql.NewSchema(schemaConfig)
}
func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, gv_tag string, resolve_fn func(graphql.ResolveParams, *ResolveContext, reflect.Value)(interface{}, error)) error {
if ctx == nil {
return fmt.Errorf("ctx is nil")
}
if resolve_fn == nil {
return fmt.Errorf("resolve_fn cannot be nil")
}
_, exists := ctx.Fields[gql_name]
if exists == true {
return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name)
}
// Resolver has p.Source.(NodeResult) = read result of current node
resolver := func(p graphql.ResolveParams)(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(NodeResult)
if ok == false {
return nil, fmt.Errorf("p.Value is not NodeResult")
}
ext, ext_exists := node.Data[ext_type]
if ext_exists == false {
return nil, fmt.Errorf("%+v is not in the extensions of the result: %+v", ext_type, node.Data)
}
val_ser, field_exists := ext[gv_tag]
if field_exists == false {
return nil, fmt.Errorf("%s is not in the fields of %+v in the result for %s - %+v", gv_tag, ext_type, gql_name, node)
}
if val_ser.TypeStack[0] == SerializedTypeFor[error]() {
return nil, fmt.Errorf(string(val_ser.Data))
}
field_type, _, err := DeserializeType(ctx.Context, val_ser.TypeStack)
if err != nil {
return nil, err
}
field_value, _, err := DeserializeValue(ctx.Context, field_type, val_ser.Data)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql", "Resolving %+v", field_value)
return resolve_fn(p, ctx, field_value)
}
ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{
Type: gql_type,
Resolve: resolver,
}}
return nil
}
func GQLInterfaces(ctx *GQLExtContext, interface_names []string) ([]*graphql.Interface, error) {
ret := make([]*graphql.Interface, len(interface_names))
for i, in := range(interface_names) {
ctx_interface, exists := ctx.Interfaces[in]
if exists == false {
return nil, fmt.Errorf("%s is not in GQLExtContext.Interfaces", in)
}
ret[i] = ctx_interface.Interface
}
return ret, nil
}
func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtType, error) {
fields := graphql.Fields{
"ID": &graphql.Field{
Type: graphql.String,
Resolve: ResolveNodeID,
},
"TypeHash": &graphql.Field{
Type: graphql.String,
Resolve: ResolveNodeTypeHash,
},
}
exts := map[ExtType]ExtType{}
ext_list := []ExtType{}
for _, name := range(field_names) {
field, exists := ctx.Fields[name]
if exists == false {
return nil, nil, fmt.Errorf("%s is not in GQLExtContext.Fields", name)
}
fields[name] = field.Field
_, exists = exts[field.Ext]
if exists == false {
ext_list = append(ext_list, field.Ext)
exts[field.Ext] = field.Ext
}
}
return fields, ext_list, nil
}
type NodeResult struct {
NodeID NodeID
NodeType NodeType
Data map[ExtType]map[string]SerializedValue
Data map[ExtType]map[string]interface{}
}
type ListField struct {
@ -1002,193 +658,6 @@ type SelfField struct {
ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error)
}
func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error {
if interfaces == nil {
return fmt.Errorf("interfaces is nil")
}
if fields == nil {
return fmt.Errorf("fields is nil")
}
_, exists := ctx.Interfaces[name]
if exists == true {
return fmt.Errorf("%s is already an interface in ctx", name)
}
node_interfaces, err := GQLInterfaces(ctx, interfaces)
if err != nil {
return err
}
node_fields, node_exts, err := GQLFields(ctx, fields)
if err != nil {
return err
}
ctx_interface := Interface{}
ctx_interface.Interface = graphql.NewInterface(graphql.InterfaceConfig{
Name: name,
ResolveType: NodeInterfaceResolveType(node_exts, &ctx_interface.Default),
Fields: node_fields,
})
ctx_interface.List = graphql.NewList(ctx_interface.Interface)
for field_name, field := range(self_fields) {
self_field := field
err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
id, err := self_field.ResolveFn(p, ctx, value)
if err != nil {
return nil, err
}
if id != nil {
nodes, err := ResolveNodes(ctx, p, []NodeID{*id})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes[0], nil
} else {
return nil, nil
}
})
if err != nil {
return err
}
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
for field_name, field := range(list_fields) {
list_field := field
resolve_fn := func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
var zero NodeID
ids, err := list_field.ResolveFn(p, ctx, value)
if err != nil {
return zero, err
}
nodes, err := ResolveNodes(ctx, p, ids)
if err != nil {
return nil, err
} else if len(nodes) != len(ids) {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes, nil
}
err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil {
return err
}
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
ctx_interface.Default = graphql.NewObject(graphql.ObjectConfig{
Name: default_name,
Interfaces: append(node_interfaces, ctx_interface.Interface),
IsTypeOf: NodeInterfaceDefaultIsType(node_exts),
Fields: node_fields,
})
ctx.Interfaces[name] = &ctx_interface
ctx.Types = append(ctx.Types, ctx_interface.Default)
return nil
}
func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, interface_names []string, field_names []string) error {
if field_names == nil {
return fmt.Errorf("fields is nil")
}
_, exists := ctx.NodeTypes[node_type]
if exists == true {
return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type)
}
node_interfaces, err := GQLInterfaces(ctx, interface_names)
if err != nil {
return err
}
gql_fields, _, err := GQLFields(ctx, field_names)
if err != nil {
return err
}
gql_type := graphql.NewObject(graphql.ObjectConfig{
Name: name,
Interfaces: node_interfaces,
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(NodeResult)
if ok == false {
return false
}
return node.NodeType == node_type
},
Fields: gql_fields,
})
ctx.NodeTypes[node_type] = gql_type
ctx.Types = append(ctx.Types, gql_type)
return nil
}
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{
"Test": &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return "Test Data", nil
},
},
},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{
"Test": &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return "Test Mutation Data", nil
},
},
},
})
context := GQLExtContext{
Schema: graphql.Schema{},
Types: []graphql.Type{},
Query: query,
Mutation: mutation,
NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: map[string]*Interface{},
Fields: map[string]Field{},
}
schema, err := BuildSchema(&context)
if err != nil {
panic(err)
}
context.Schema = schema
return &context
}
type SubscriptionInfo struct {
ID uuid.UUID
Channel chan interface{}
@ -1295,9 +764,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
return response_chan
}
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
// Process ReadResultSignalType by forwarding it to the waiting resolver
var changes = Changes{}
var changes Changes = nil
var messages []SendMsg = nil
switch sig := signal.(type) {
case *SuccessSignal:
@ -1355,7 +825,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
ext.subscriptions_lock.RUnlock()
}
return nil, changes
return messages, changes
}
var ecdsa_curves = map[uint8]elliptic.Curve{

@ -1,11 +1,9 @@
package graphvent
import (
"time"
"reflect"
"fmt"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
"github.com/google/uuid"
)
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
@ -54,136 +52,3 @@ func GetResolveFields(ctx *Context, p graphql.ResolveParams) []string {
return names
}
func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([]NodeResult, error) {
fields := GetResolveFields(ctx.Context, p)
ctx.Context.Log.Logf("gql_resolve_node", "RESOLVE_NODES(%+v): %+v", ids, fields)
resp_channels := map[uuid.UUID]chan Signal{}
indices := map[uuid.UUID]int{}
// Get a list of fields that will be written
ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql_resolve_node", "ACL Fields from request: %+v", ext_fields)
responses := make([]NodeResult, len(ids))
for i, id := range(ids) {
var read_signal *ReadSignal = nil
node, cached := ctx.NodeCache[id]
if cached == true {
resolve := false
missing_exts := map[ExtType][]string{}
for ext_type, fields := range(ext_fields) {
cached_ext, exists := node.Data[ext_type]
if exists == true {
missing_fields := []string{}
for _, field_name := range(fields) {
_, found := cached_ext[field_name]
if found == false {
missing_fields = append(missing_fields, field_name)
}
}
if len(missing_fields) > 0 {
missing_exts[ext_type] = missing_fields
resolve = true
}
} else {
missing_exts[ext_type] = fields
resolve = true
}
}
if resolve == true {
read_signal = NewReadSignal(missing_exts)
ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v because of missing fields %+v", id, missing_exts)
} else {
ctx.Context.Log.Logf("gql_resolve_node", "Using cached response for %+v(%d)", id, i)
responses[i] = node
continue
}
} else {
ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v", id)
read_signal = NewReadSignal(ext_fields)
}
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
msgs := Messages{}
msgs = msgs.Add(ctx.Context, id, ctx.Server, ctx.Authorization, read_signal)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan
indices[read_signal.ID()] = i
// TODO: Send all at once instead of creating Messages for each
err = ctx.Context.Send(msgs)
if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err
}
}
errors := ""
for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout
response, other, err := WaitForResponse(response_chan, time.Millisecond*100, sig_id)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql_resolve_node", "GQL node response: %+v", response)
ctx.Context.Log.Logf("gql_resolve_node", "GQL node other messages: %+v", other)
// for now, just put signals we didn't want back into the 'queue'
for _, other_signal := range(other) {
response_chan <- other_signal
}
error_signal, is_error := response.(*ErrorSignal)
if is_error {
errors = fmt.Sprintf("%s, %s", errors, error_signal.Error)
continue
}
read_response, is_read_response := response.(*ReadResultSignal)
if is_read_response == false {
errors = fmt.Sprintf("%s, wrong response type %+v", errors, reflect.TypeOf(response))
continue
}
idx := indices[sig_id]
responses[idx] = NodeResult{
read_response.NodeID,
read_response.NodeType,
read_response.Extensions,
}
cache, exists := ctx.NodeCache[read_response.NodeID]
if exists == true {
ctx.Context.Log.Logf("gql_resolve_node", "Merging new response with cached: %s, %+v - %+v", read_response.NodeID, cache, read_response.Extensions)
for ext_type, fields := range(read_response.Extensions) {
cached_fields, exists := cache.Data[ext_type]
if exists == false {
cached_fields = map[string]SerializedValue{}
cache.Data[ext_type] = cached_fields
}
for field_name, field_value := range(fields) {
cached_fields[field_name] = field_value
}
}
responses[idx] = cache
} else {
ctx.Context.Log.Logf("gql_resolve_node", "Adding new response to node cache: %s, %+v", read_response.NodeID, read_response.Extensions)
ctx.NodeCache[read_response.NodeID] = responses[idx]
}
}
if errors != "" {
return nil, fmt.Errorf(errors)
}
ctx.Context.Log.Logf("gql_resolve_node", "RESOLVED_NODES %+v - %+v", ids, responses)
return responses, nil
}

@ -1,145 +0,0 @@
package graphvent
import (
graphql "github.com/graphql-go/graphql"
"github.com/google/uuid"
"reflect"
"fmt"
"time"
)
type StructFieldInfo struct {
Name string
Type *TypeInfo
Index []int
}
func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) {
type_info, mapped := ctx.TypeReflects[field.Type]
if mapped == false {
return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type)
}
return StructFieldInfo{
Name: gv_tag,
Type: type_info,
Index: field.Index,
}, nil
}
func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) {
fmt.Printf("FIELD: %+v\n", fields)
signal_value := reflect.New(signal_type)
id_field := signal_value.Elem().FieldByIndex(id_index)
id_field.Set(reflect.ValueOf(uuid.New()))
direction_field := signal_value.Elem().FieldByIndex(direction_index)
direction_field.Set(reflect.ValueOf(Direct))
for _, field := range(fields) {
arg, arg_exists := args[field.Name]
if arg_exists == false {
return nil, fmt.Errorf("No arg provided named %s", field.Name)
}
field_value := signal_value.Elem().FieldByIndex(field.Index)
if field_value.CanConvert(field.Type.Reflect) == false {
return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect)
}
value, err := field.Type.GQLValue(ctx, arg)
if err != nil {
return nil, err
}
fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type())
field_value.Set(value)
}
return signal_value.Interface().(Signal), nil
}
func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) {
args := graphql.FieldConfigArgument{}
arg_info := []StructFieldInfo{}
var id_index []int = nil
var direction_index []int = nil
for _, field := range(reflect.VisibleFields(signal_type)) {
gv_tag, tagged_gv := field.Tag.Lookup("gv")
if tagged_gv {
if gv_tag == "id" {
id_index = field.Index
} else if gv_tag == "direction" {
direction_index = field.Index
} else {
_, exists := args[gv_tag]
if exists == true {
return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag)
} else {
info, err := ArgumentInfo(ctx, field, gv_tag)
if err != nil {
return nil, err
}
args[gv_tag] = &graphql.ArgumentConfig{
}
arg_info = append(arg_info, info)
}
}
}
}
_, send_exists := args[send_id_key]
if send_exists == false {
args[send_id_key] = &graphql.ArgumentConfig{
Type: graphql.String,
}
}
resolve_signal := func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
send_id, err := ExtractID(p, send_id_key)
if err != nil {
return nil, err
}
signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index)
if err != nil {
return nil, err
}
msgs := Messages{}
msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
err = ctx.Context.Send(msgs)
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return nil, err
}
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return nil, err
}
_, is_success := response.(*SuccessSignal)
if is_success == true {
return "success", nil
}
error_response, is_error := response.(*ErrorSignal)
if is_error == true {
return "error", fmt.Errorf(error_response.Error)
}
return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response))
}
return &graphql.Field{
Type: graphql.String,
Args: args,
Resolve: resolve_signal,
}, nil
}

@ -3,6 +3,7 @@ package graphvent
import (
"testing"
"runtime/debug"
"time"
badger "github.com/dgraph-io/badger/v3"
)
@ -44,3 +45,16 @@ func fatalErr(t * testing.T, err error) {
t.Fatal(err)
}
}
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
source_listener, err := GetExt[ListenerExt](source)
fatalErr(t, err)
messages := []SendMsg{{destination.ID, signal}}
fatalErr(t, ctx.Send(source, messages))
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response, signals
}

@ -1,296 +0,0 @@
package graphvent
import (
"slices"
)
type AddSubGroupSignal struct {
SignalHeader
Name string `gv:"name"`
}
func NewAddSubGroupSignal(name string) *AddSubGroupSignal {
return &AddSubGroupSignal{
NewSignalHeader(Direct),
name,
}
}
func (signal AddSubGroupSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[AddSubGroupSignal]()): {
Hash("name", signal.Name): nil,
},
}
}
type RemoveSubGroupSignal struct {
SignalHeader
Name string `gv:"name"`
}
func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal {
return &RemoveSubGroupSignal{
NewSignalHeader(Direct),
name,
}
}
func (signal RemoveSubGroupSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): {
Hash("command", signal.Name): nil,
},
}
}
type AddMemberSignal struct {
SignalHeader
SubGroup string `gv:"sub_group"`
MemberID NodeID `gv:"member_id"`
}
type SubGroupGQL struct {
Name string
Members []NodeID
}
func (signal AddMemberSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[AddMemberSignal]()): {
Hash("sub_group", signal.SubGroup): nil,
},
}
}
func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal {
return &AddMemberSignal{
NewSignalHeader(Direct),
sub_group,
member_id,
}
}
type RemoveMemberSignal struct {
SignalHeader
SubGroup string `gv:"sub_group"`
MemberID NodeID `gv:"member_id"`
}
func (signal RemoveMemberSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[RemoveMemberSignal]()): {
Hash("sub_group", signal.SubGroup): nil,
},
}
}
func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal {
return &RemoveMemberSignal{
NewSignalHeader(Direct),
sub_group,
member_id,
}
}
var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ReadSignal]()): {
SerializedType(ExtTypeFor[GroupExt]()): {
SerializedType(GetFieldTag("sub_groups")): nil,
},
},
})
type SubGroup struct {
Members []NodeID
Permissions Tree
}
type MemberOfPolicy struct {
PolicyHeader
Groups map[NodeID]map[string]Tree
}
func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy {
return MemberOfPolicy{
PolicyHeader: NewPolicyHeader(),
Groups: groups,
}
}
func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
}
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()]
if ok == false {
return Deny
}
sub_groups_ser, ok := group_ext_data["sub_groups"]
if ok == false {
return Deny
}
sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack)
if err != nil {
ctx.Log.Logf("group", "Type deserialize error: %s", err)
return Deny
}
sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data)
if err != nil {
ctx.Log.Logf("group", "Value deserialize error: %s", err)
return Deny
}
ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID)
if ok == false {
return Deny
}
group, exists := policy.Groups[sig.NodeID]
if exists == false {
return Deny
}
for sub_group_name, permissions := range(group) {
ext_sub_group, exists := ext_sub_groups[sub_group_name]
if exists == true {
for _, member_id := range(ext_sub_group) {
if member_id == current.Principal {
if permissions.Allows(current.Action) == Allow {
return Allow
}
}
}
}
}
return Deny
}
// Send a read signal to Group to check if principal_id is a member of it
func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
var messages Messages = nil
for group_id, sub_groups := range(policy.Groups) {
if group_id == node.ID {
ext, err := GetExt[GroupExt](node)
if err != nil {
ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err)
} else {
for sub_group_name, permission := range(sub_groups) {
ext_sub_group, exists := ext.SubGroups[sub_group_name]
if exists == true {
for _, member := range(ext_sub_group) {
if member == principal_id {
if permission.Allows(action) == Allow {
return nil, Allow
}
break
}
}
}
}
}
} else {
// Send the read request to the group so that ContinueAllows can parse the response to check membership
messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"sub_groups"},
}))
}
}
if len(messages) > 0 {
return messages, Pending
} else {
return nil, Deny
}
}
type GroupExt struct {
SubGroups map[string][]NodeID `gv:"sub_groups"`
}
func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt {
if sub_groups == nil {
sub_groups = map[string][]NodeID{}
}
return &GroupExt{
SubGroups: sub_groups,
}
}
func (ext *GroupExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *GroupExt) Unload(ctx *Context, node *Node) {
}
func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
var messages Messages = nil
var changes = Changes{}
switch sig := signal.(type) {
case *AddMemberSignal:
sub_group, exists := ext.SubGroups[sig.SubGroup]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
if slices.Contains(sub_group, sig.MemberID) == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
} else {
sub_group = append(sub_group, sig.MemberID)
ext.SubGroups[sig.SubGroup] = sub_group
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes.Add("sub_groups")
}
}
case *RemoveMemberSignal:
sub_group, exists := ext.SubGroups[sig.SubGroup]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
idx := slices.Index(sub_group, sig.MemberID)
if idx == -1 {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
} else {
sub_group = slices.Delete(sub_group, idx, idx+1)
ext.SubGroups[sig.SubGroup] = sub_group
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes.Add("sub_groups")
}
}
case *AddSubGroupSignal:
_, exists := ext.SubGroups[sig.Name]
if exists == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup"))
} else {
ext.SubGroups[sig.Name] = []NodeID{}
changes.Add("sub_groups")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
case *RemoveSubGroupSignal:
_, exists := ext.SubGroups[sig.Name]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
delete(ext.SubGroups, sig.Name)
changes.Add("sub_groups")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
}
return messages, changes
}

@ -1,94 +0,0 @@
package graphvent
import (
"testing"
"time"
)
func TestGroupAdd(t *testing.T) {
ctx := logTestContext(t, []string{"listener", "test"})
group_listener := NewListenerExt(10)
group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil))
fatalErr(t, err)
add_subgroup_signal := NewAddSubGroupSignal("test_group")
messages := Messages{}
messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal)
fatalErr(t, ctx.Send(messages))
resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id)
fatalErr(t, err)
error_1, is_error := resp_1.(*ErrorSignal)
if is_error {
t.Fatalf("Error returned: %s", error_1.Error)
}
user_id := RandID()
add_member_signal := NewAddMemberSignal("test_group", user_id)
messages = Messages{}
messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
fatalErr(t, ctx.Send(messages))
resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
fatalErr(t, err)
error_2, is_error := resp_2.(*ErrorSignal)
if is_error {
t.Fatalf("Error returned: %s", error_2.Error)
}
read_signal := NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"sub_groups"},
})
messages = Messages{}
messages = messages.Add(ctx, group.ID, group, nil, read_signal)
fatalErr(t, ctx.Send(messages))
response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)
fatalErr(t, err)
read_response := response.(*ReadResultSignal)
sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"]
sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack)
fatalErr(t, err)
if len(remaining_types) > 0 {
t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types))
}
sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data)
fatalErr(t, err)
if len(remaining) > 0 {
t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types))
}
sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID)
if ok != true {
t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type())
}
if len(sub_groups) != 1 {
t.Fatalf("sub_groups wrong length %d", len(sub_groups))
}
test_subgroup, exists := sub_groups["test_group"]
if exists == false {
t.Fatal("test_group not in subgroups")
}
if len(test_subgroup) != 1 {
t.Fatalf("test_group wrong size %d/1", len(test_subgroup))
}
if test_subgroup[0] != user_id {
t.Fatalf("sub_groups wrong value %s", test_subgroup[0])
}
ctx.Log.Logf("test", "Read Response: %+v", read_response)
}

@ -11,17 +11,13 @@ type ListenerExt struct {
}
func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
ext.Chan = make(chan Signal, ext.Buffer)
return nil
}
func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
}
func (ext *ListenerExt) PostDeserialize(ctx *Context) error {
ext.Chan = make(chan Signal, ext.Buffer)
return nil
}
// Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{
@ -31,7 +27,7 @@ func NewListenerExt(buffer int) *ListenerExt {
}
// Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
select {

@ -5,18 +5,6 @@ import (
"time"
)
var AllowParentUnlockPolicy = NewOwnerOfPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): {
Hash(LockStateBase, "unlock"): nil,
},
})
var AllowAnyLockPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): {
Hash(LockStateBase, "lock"): nil,
},
})
type ReqState byte
const (
Unlocked = ReqState(0)
@ -62,17 +50,15 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
}
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{}
signal := NewLockSignal("unlock")
messages = messages.Add(ctx, node.ID, node, nil, signal)
return signal.ID(), ctx.Send(messages)
messages := []SendMsg{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages)
}
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{}
signal := NewLockSignal("lock")
messages = messages.Add(ctx, node.ID, node, nil, signal)
return signal.ID(), ctx.Send(messages)
messages := []SendMsg{{node.ID, signal}}
return signal.ID(), ctx.Send(node, messages)
}
func (ext *LockableExt) Load(ctx *Context, node *Node) error {
@ -82,8 +68,8 @@ func (ext *LockableExt) Load(ctx *Context, node *Node) error {
func (ext *LockableExt) Unload(ctx *Context, node *Node) {
}
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes Changes = nil
info, info_found := node.ProcessResponse(ext.WaitInfos, signal)
@ -126,7 +112,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, id, node, nil, lock_signal)
messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
}
}
@ -153,43 +139,43 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI
return messages, changes
}
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
if ext.State == Unlocked {
switch signal.Action {
case "add":
_, exists := ext.Requirements[signal.NodeID]
if exists == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "already_requirement")})
} else {
if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{}
}
ext.Requirements[signal.NodeID] = Unlocked
changes.Add("requirements")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID()))
messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())})
}
case "remove":
_, exists := ext.Requirements[signal.NodeID]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "can't link: not_requirement")})
} else {
delete(ext.Requirements, signal.NodeID)
changes.Add("requirements")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID()))
messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())})
}
default:
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "unknown_action"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "unknown_action")})
}
} else {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")})
}
return messages, changes
}
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
if source == node.ID {
return messages, changes
@ -218,7 +204,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.State = Locked
ext.Owner = ext.PendingOwner
changes.Add("state", "owner", "requirements")
messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID))
messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(ext.PendingID)})
} else {
changes.Add("requirements")
ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements))
@ -228,7 +214,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", info.Destination, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, info.Destination, node, nil, lock_signal)
messages = append(messages, SendMsg{info.Destination, lock_signal})
ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.Destination, node.ID)
}
@ -254,10 +240,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.Owner = ext.PendingOwner
ext.ReqID = nil
changes.Add("state", "owner", "req_id")
messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID))
messages = append(messages, SendMsg{previous_owner, NewSuccessSignal(ext.PendingID)})
} else if old_state == AbortingLock {
changes.Add("state", "pending_owner")
messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked"))
messages = append(messages, SendMsg{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked")})
ext.PendingOwner = ext.Owner
}
} else {
@ -272,8 +258,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
switch signal.State {
@ -286,7 +272,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = &new_owner
ext.Owner = &new_owner
changes.Add("state", "pending_owner", "owner")
messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID()))
messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())})
} else {
ext.State = Locking
id := signal.ID()
@ -304,11 +290,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("lock", id, lock_signal, 500*time.Millisecond)
ext.Requirements[id] = Locking
messages = messages.Add(ctx, id, node, nil, lock_signal)
messages = append(messages, SendMsg{id, lock_signal})
}
}
default:
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")})
ctx.Log.Logf("lockable", "Tried to lock %s while %s", node.ID, ext.State)
}
case "unlock":
@ -319,7 +305,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = nil
ext.Owner = nil
changes.Add("state", "pending_owner", "owner")
messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID()))
messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())})
} else if source == *ext.Owner {
ext.State = Unlocking
id := signal.ID()
@ -336,11 +322,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
ext.Requirements[id] = Unlocking
messages = messages.Add(ctx, id, node, nil, lock_signal)
messages = append(messages, SendMsg{id, lock_signal})
}
}
} else {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_locked"))
messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_locked")})
}
default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
@ -348,8 +334,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
return messages, changes
}
func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
wait_info, found := node.ProcessResponse(ext.WaitInfos, signal)
@ -380,7 +366,7 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, id, node, nil, lock_signal)
messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
}
}
@ -405,124 +391,32 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod
return messages, changes
}
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
// LockableExts process status signals by forwarding them to it's owner
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
var messages Messages = nil
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages []SendMsg = nil
var changes = Changes{}
switch signal.Direction() {
case Up:
switch sig := signal.(type) {
case *StatusSignal:
if ext.Owner != nil {
if *ext.Owner != node.ID {
messages = messages.Add(ctx, *ext.Owner, node, nil, signal)
messages = append(messages, SendMsg{*ext.Owner, signal})
}
}
case Down:
for requirement := range(ext.Requirements) {
messages = messages.Add(ctx, requirement, node, nil, signal)
}
case Direct:
switch sig := signal.(type) {
case *LinkSignal:
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
case *LockSignal:
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case *SuccessSignal:
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
case *TimeoutSignal:
messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig)
default:
}
case *LinkSignal:
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
case *LockSignal:
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case *SuccessSignal:
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
case *TimeoutSignal:
messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig)
default:
}
return messages, changes
}
type OwnerOfPolicy struct {
PolicyHeader
Rules Tree `gv:"rules"`
}
func NewOwnerOfPolicy(rules Tree) OwnerOfPolicy {
return OwnerOfPolicy{
PolicyHeader: NewPolicyHeader(),
Rules: rules,
}
}
func (policy OwnerOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func (policy OwnerOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
l_ext, err := GetExt[LockableExt](node)
if err != nil {
ctx.Log.Logf("lockable", "OwnerOfPolicy.Allows called on node without LockableExt")
return nil, Deny
}
if l_ext.Owner == nil {
return nil, Deny
}
if principal_id == *l_ext.Owner {
return nil, Allow
}
return nil, Deny
}
type RequirementOfPolicy struct {
PerNodePolicy
}
func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy {
return RequirementOfPolicy {
PerNodePolicy: NewPerNodePolicy(dep_rules),
}
return messages, changes
}
func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
}
ext, ok := sig.Extensions[ExtTypeFor[LockableExt]()]
if ok == false {
return Deny
}
reqs_ser, ok := ext["requirements"]
if ok == false {
return Deny
}
reqs_type, _, err := DeserializeType(ctx, reqs_ser.TypeStack)
if err != nil {
return Deny
}
reqs_if, _, err := DeserializeValue(ctx, reqs_type, reqs_ser.Data)
if err != nil {
return Deny
}
requirements, ok := reqs_if.Interface().(map[NodeID]ReqState)
if ok == false {
return Deny
}
for req := range(requirements) {
if req == current.Principal {
return policy.NodeRules[sig.NodeID].Allows(current.Action)
}
}
return Deny
}

@ -3,8 +3,6 @@ package graphvent
import (
"testing"
"time"
"crypto/ed25519"
"crypto/rand"
)
func lockableTestContext(t *testing.T, logs []string) *Context {
@ -19,32 +17,19 @@ func lockableTestContext(t *testing.T, logs []string) *Context {
func TestLink(t *testing.T) {
ctx := lockableTestContext(t, []string{"lockable", "listener"})
l1_pub, l1_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
l1_id := KeyID(l1_pub)
policy := NewPerNodePolicy(map[NodeID]Tree{
l1_id: nil,
})
l2_listener := NewListenerExt(10)
l2, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy},
l2_listener,
NewLockableExt(nil),
)
l2, err := NewNode(ctx, nil, "Lockable", 10, l2_listener, NewLockableExt(nil))
fatalErr(t, err)
l1_lockable := NewLockableExt(nil)
l1_listener := NewListenerExt(10)
l1, err := NewNode(ctx, l1_key, "Lockable", 10, nil,
l1_listener,
l1_lockable,
)
l1, err := NewNode(ctx, nil, "Lockable", 10, l1_listener, l1_lockable)
fatalErr(t, err)
msgs := Messages{}
link_signal := NewLinkSignal("add", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1, nil, link_signal)
err = ctx.Send(msgs)
msgs := []SendMsg{{l1.ID, link_signal}}
err = ctx.Send(l1, msgs)
fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
@ -57,10 +42,9 @@ func TestLink(t *testing.T) {
t.Fatalf("l2 in bad requirement state in l1: %+v", state)
}
msgs = Messages{}
unlink_signal := NewLinkSignal("remove", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1, nil, unlink_signal)
err = ctx.Send(msgs)
msgs = []SendMsg{{l1.ID, unlink_signal}}
err = ctx.Send(l1, msgs)
fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
@ -70,18 +54,8 @@ func TestLink(t *testing.T) {
func Test1000Lock(t *testing.T) {
ctx := lockableTestContext(t, []string{"test", "lockable"})
l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
listener_id := KeyID(l_pub)
child_policy := NewPerNodePolicy(map[NodeID]Tree{
listener_id: {
SerializedType(SignalTypeFor[LockSignal]()): nil,
},
})
NewLockable := func()(*Node) {
l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{child_policy},
NewLockableExt(nil),
)
l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil))
fatalErr(t, err)
return l
}
@ -93,15 +67,8 @@ func Test1000Lock(t *testing.T) {
}
ctx.Log.Logf("test", "CREATED_1000")
l_policy := NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): nil,
})
listener := NewListenerExt(5000)
node, err := NewNode(ctx, listener_key, "Lockable", 5000, []Policy{l_policy},
listener,
NewLockableExt(reqs),
)
node, err := NewNode(ctx, nil, "Lockable", 5000, listener, NewLockableExt(reqs))
fatalErr(t, err)
ctx.Log.Logf("test", "CREATED_LISTENER")
@ -123,14 +90,9 @@ func Test1000Lock(t *testing.T) {
func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{"test", "lockable"})
policy := NewAllNodesPolicy(nil)
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
listener := NewListenerExt(1000)
l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy},
listener,
NewLockableExt(reqs),
)
l, err := NewNode(ctx, nil, "Lockable", 10, listener, NewLockableExt(reqs))
fatalErr(t, err)
return l, listener
}

@ -1,114 +1,11 @@
package graphvent
import (
"time"
"crypto/ed25519"
"crypto/rand"
"crypto"
)
type AuthInfo struct {
// The Node that issued the authorization
Identity ed25519.PublicKey
// Time the authorization was generated
Start time.Time
// Signature of Start + Principal with Identity private key
Signature []byte
}
type AuthorizationToken struct {
AuthInfo
// The private key generated by the client, encrypted with the servers public key
KeyEncrypted []byte
}
type ClientAuthorization struct {
AuthInfo
// The private key generated by the client
Key ed25519.PrivateKey
}
// Authorization structs can be passed in a message that originated from a different node than the sender
type Authorization struct {
AuthInfo
// The public key generated for this authorization
Key ed25519.PublicKey
}
type Message struct {
Dest NodeID
Source ed25519.PublicKey
Authorization *Authorization
type SendMsg struct {
Dest NodeID
Signal Signal
Signature []byte
}
type Messages []*Message
func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
msg, err := NewMessage(ctx, dest, source, authorization, signal)
if err != nil {
panic(err)
} else {
msgs = append(msgs, msg)
}
return msgs
}
func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages {
messages := Messages{}
for _, signal := range(signals) {
messages = messages.Add(ctx, dest, source, authorization, signal)
}
return messages
}
func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) {
signal_ser, err := SerializeAny(ctx, signal)
if err != nil {
return nil, err
}
signal_chunks, err := signal_ser.Chunks()
if err != nil {
return nil, err
}
dest_ser, err := dest.MarshalBinary()
if err != nil {
return nil, err
}
source_ser, err := source.ID.MarshalBinary()
if err != nil {
return nil, err
}
sig_data := append(dest_ser, source_ser...)
sig_data = append(sig_data, signal_chunks.Slice()...)
var message_auth *Authorization = nil
if authorization != nil {
sig_data = append(sig_data, authorization.Signature...)
message_auth = &Authorization{
authorization.AuthInfo,
authorization.Key.Public().(ed25519.PublicKey),
}
}
sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return nil, err
}
return &Message{
Dest: dest,
Source: source.Key.Public().(ed25519.PublicKey),
Authorization: message_auth,
Signal: signal,
Signature: sig,
}, nil
type RecvMsg struct {
Source NodeID
Signal Signal
}

@ -10,7 +10,7 @@ import (
"sync/atomic"
"time"
badger "github.com/dgraph-io/badger/v3"
_ "github.com/dgraph-io/badger/v3"
"github.com/google/uuid"
)
@ -57,24 +57,6 @@ func (q QueuedSignal) String() string {
return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time)
}
type PendingACL struct {
Counter int
Responses []ResponseSignal
TimeoutID uuid.UUID
Action Tree
Principal NodeID
Signal Signal
Source NodeID
}
type PendingACLSignal struct {
Policy uuid.UUID
Timeout uuid.UUID
ID uuid.UUID
}
// Default message channel size for nodes
// Nodes represent a group of extensions that can be collectively addressed
type Node struct {
@ -84,13 +66,8 @@ type Node struct {
// TODO: move each extension to it's own db key, and extend changes to notify which extension was changed
Extensions map[ExtType]Extension
Policies []Policy `gv:"policies"`
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
PendingACLSignals map[uuid.UUID]PendingACLSignal `gv:"pending_signal"`
// Channel for this node to receive messages from the Context
MsgChan chan *Message
MsgChan chan RecvMsg
// Size of MsgChan
BufferSize uint32 `gv:"buffer_size"`
// Channel for this node to process delayed signals
@ -110,34 +87,11 @@ func (node *Node) PostDeserialize(ctx *Context) error {
public := node.Key.Public().(ed25519.PublicKey)
node.ID = KeyID(public)
node.MsgChan = make(chan *Message, node.BufferSize)
node.MsgChan = make(chan RecvMsg, node.BufferSize)
return nil
}
type RuleResult int
const (
Allow RuleResult = iota
Deny
Pending
)
func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[uuid.UUID]Messages, RuleResult) {
pends := map[uuid.UUID]Messages{}
for _, policy := range(node.Policies) {
msgs, resp := policy.Allows(ctx, principal_id, action, node)
if resp == Allow {
return nil, Allow
} else if resp == Pending {
pends[policy.ID()] = msgs
}
}
if len(pends) != 0 {
return pends, Pending
}
return nil, Deny
}
type WaitReason string
type WaitInfo struct {
Destination NodeID `gv:"destination"`
@ -250,37 +204,23 @@ func (err StringError) MarshalBinary() ([]byte, error) {
return []byte(string(err)), nil
}
func NewErrorField(fstring string, args ...interface{}) SerializedValue {
str := StringError(fmt.Sprintf(fstring, args...))
str_ser, err := str.MarshalBinary()
if err != nil {
panic(err)
}
return SerializedValue{
TypeStack: []SerializedType{SerializedTypeFor[error]()},
Data: str_ser,
}
}
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue {
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]any {
ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID)
exts := map[ExtType]map[string]SerializedValue{}
exts := map[ExtType]map[string]any{}
for ext_type, field_reqs := range(reqs) {
fields := map[string]SerializedValue{}
for _, req := range(field_reqs) {
ext, exists := node.Extensions[ext_type]
if exists == false {
fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type)
} else {
f, err := SerializeField(ctx, ext, req)
if err != nil {
fields[req] = NewErrorField(err.Error())
ext_info, ext_known := ctx.Extensions[ext_type]
if ext_known {
fields := map[string]any{}
for _, req := range(field_reqs) {
ext, exists := node.Extensions[ext_type]
if exists == false {
fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type)
} else {
fields[req] = f
fields[req] = reflect.ValueOf(ext).FieldByIndex(ext_info.Fields[req]).Interface()
}
}
exts[ext_type] = fields
}
exts[ext_type] = fields
}
return exts
}
@ -306,92 +246,8 @@ func nodeLoop(ctx *Context, node *Node) error {
var source NodeID
select {
case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
signal_ser, err := SerializeAny(ctx, msg.Signal)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal)
}
chunks, err := signal_ser.Chunks()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser)
continue
}
dst_id_ser, err := msg.Dest.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err)
continue
}
src_id_ser, err := KeyID(msg.Source).MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err)
continue
}
sig_data := append(dst_id_ser, src_id_ser...)
sig_data = append(sig_data, chunks.Slice()...)
if msg.Authorization != nil {
sig_data = append(sig_data, msg.Authorization.Signature...)
}
validated := ed25519.Verify(msg.Source, sig_data, msg.Signature)
if validated == false {
ctx.Log.Logf("signal_verify", "SIGNAL_VERIFY_ERR: %s - %s", node.ID, reflect.TypeOf(msg.Signal))
continue
}
var princ_id NodeID
if msg.Authorization == nil {
princ_id = KeyID(msg.Source)
} else {
err := ValidateAuthorization(*msg.Authorization, time.Hour)
if err != nil {
ctx.Log.Logf("node", "Authorization validation failed: %s", err)
continue
}
princ_id = KeyID(msg.Authorization.Identity)
}
if princ_id != node.ID {
pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission())
if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal)
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{}
msgs = msgs.Add(ctx, KeyID(msg.Source), node, nil, NewErrorSignal(msg.Signal.ID(), "acl denied"))
ctx.Send(msgs)
continue
} else if resp == Pending {
ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends)
timeout_signal := NewACLTimeoutSignal(msg.Signal.ID())
node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal)
msgs := Messages{}
for policy_type, sigs := range(pends) {
for _, m := range(sigs) {
msgs = append(msgs, m)
timeout_signal := NewTimeoutSignal(m.Signal.ID())
node.QueueSignal(time.Now().Add(time.Second), timeout_signal)
node.PendingACLSignals[m.Signal.ID()] = PendingACLSignal{policy_type, timeout_signal.Id, msg.Signal.ID()}
}
}
node.PendingACLs[msg.Signal.ID()] = PendingACL{
Counter: len(msgs),
TimeoutID: timeout_signal.ID(),
Action: msg.Signal.Permission(),
Principal: princ_id,
Responses: []ResponseSignal{},
Signal: msg.Signal,
Source: KeyID(msg.Source),
}
ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs)
ctx.Send(msgs)
continue
} else if resp == Allow {
ctx.Log.Logf("policy", "SIGNAL_POLICY_ALLOW: %s->%s - %s", princ_id, node.ID, reflect.TypeOf(msg.Signal))
}
} else {
ctx.Log.Logf("policy", "SIGNAL_POLICY_SELF: %s - %s", node.ID, reflect.TypeOf(msg.Signal))
}
signal = msg.Signal
source = KeyID(msg.Source)
source = msg.Source
case <-node.TimeoutChan:
signal = node.NextSignal.Signal
@ -425,68 +281,12 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
response, ok := signal.(ResponseSignal)
if ok == true {
info, waiting := node.PendingACLSignals[response.ResponseID()]
if waiting == true {
delete(node.PendingACLSignals, response.ResponseID())
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID]
if exists == true {
req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, response)
idx := -1
for i, p := range(node.Policies) {
if p.ID() == info.Policy {
idx = i
break
}
}
if idx == -1 {
ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy)
delete(node.PendingACLs, info.ID)
} else {
allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal)
if allowed == Allow {
ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal)
signal = req_info.Signal
source = req_info.Source
err := node.DequeueSignal(req_info.TimeoutID)
if err != nil {
ctx.Log.Logf("node", "dequeue error: %s", err)
}
delete(node.PendingACLs, info.ID)
} else if req_info.Counter == 0 {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response
msgs := Messages{}
msgs = msgs.Add(ctx, req_info.Source, node, nil, NewErrorSignal(req_info.Signal.ID(), "acl_denied"))
err := ctx.Send(msgs)
if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err)
}
err = node.DequeueSignal(req_info.TimeoutID)
if err != nil {
ctx.Log.Logf("node", "ACL_DEQUEUE_ERROR: timeout signal not in queue when trying to clear after counter hit 0 %s, %s - %s", err, signal.ID(), req_info.TimeoutID)
}
delete(node.PendingACLs, info.ID)
} else {
node.PendingACLs[info.ID] = req_info
continue
}
}
}
}
}
switch sig := signal.(type) {
case *ReadSignal:
result := node.ReadFields(ctx, sig.Extensions)
msgs := Messages{}
msgs = msgs.Add(ctx, source, node, nil, NewReadResultSignal(sig.ID(), node.ID, node.Type, result))
ctx.Send(msgs)
msgs := []SendMsg{}
msgs = append(msgs, SendMsg{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)})
ctx.Send(node, msgs)
default:
err := node.Process(ctx, source, signal)
@ -522,7 +322,7 @@ func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal)
messages := Messages{}
messages := []SendMsg{}
changes := map[ExtType]Changes{}
for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
@ -537,7 +337,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("changes", "Changes for %s after %+v - %+v", node.ID, reflect.TypeOf(signal), changes)
if len(messages) != 0 {
send_err := ctx.Send(messages)
send_err := ctx.Send(node, messages)
if send_err != nil {
return send_err
}
@ -596,7 +396,7 @@ func KeyID(pub ed25519.PublicKey) NodeID {
}
// Create a new node in memory and start it's event loop
func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) {
func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, extensions ...Extension) (*Node, error) {
node_type, known_type := ctx.NodeTypes[type_name]
if known_type == false {
return nil, fmt.Errorf("%s is not a known node type", type_name)
@ -643,17 +443,12 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size
}
}
policies = append(policies, DefaultPolicy)
node := &Node{
Key: key,
ID: id,
Type: node_type,
Extensions: ext_map,
Policies: policies,
PendingACLs: map[uuid.UUID]PendingACL{},
PendingACLSignals: map[uuid.UUID]PendingACLSignal{},
MsgChan: make(chan *Message, buffer_size),
MsgChan: make(chan RecvMsg, buffer_size),
BufferSize: buffer_size,
SignalQueue: []QueuedSignal{},
}
@ -685,256 +480,17 @@ func ExtTypeSuffix(ext_type ExtType) []byte {
}
func WriteNodeExtList(ctx *Context, node *Node) error {
ext_list := make([]ExtType, len(node.Extensions))
i := 0
for ext_type := range(node.Extensions) {
ext_list[i] = ext_type
i += 1
}
ctx.Log.Logf("db", "Writing ext_list for %s - %+v", node.ID, ext_list)
id_bytes, err := node.ID.MarshalBinary()
if err != nil {
return err
}
ext_list_serialized, err := SerializeAny(ctx, ext_list)
if err != nil {
return err
}
return ctx.DB.Update(func(txn *badger.Txn) error {
return txn.Set(append(id_bytes, extension_suffix...), ext_list_serialized.Data)
})
return fmt.Errorf("TODO: write node list")
}
func WriteNodeInit(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "Writing initial entry for %s - %+v", node.ID, node)
ext_serialized := map[ExtType]SerializedValue{}
for ext_type, ext := range(node.Extensions) {
serialized_ext, err := SerializeAny(ctx, ext)
if err != nil {
return err
}
ext_serialized[ext_type] = serialized_ext
}
sq_serialized, err := SerializeAny(ctx, node.SignalQueue)
if err != nil {
return err
}
node_serialized, err := SerializeAny(ctx, node)
if err != nil {
return err
}
id_bytes, err := node.ID.MarshalBinary()
return ctx.DB.Update(func(txn *badger.Txn) error {
err := txn.Set(id_bytes, node_serialized.Data)
if err != nil {
return nil
}
err = txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data)
if err != nil {
return err
}
for ext_type, data := range(ext_serialized) {
err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data)
if err != nil {
return err
}
}
return nil
})
return fmt.Errorf("TODO: write initial node entry")
}
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
ctx.Log.Logf("db", "Writing changes for %s - %+v", node.ID, changes)
ext_serialized := map[ExtType]SerializedValue{}
for ext_type := range(changes) {
ext, ext_exists := node.Extensions[ext_type]
if ext_exists == false {
ctx.Log.Logf("db", "extension 0x%x does not exist for %s", ext_type, node.ID)
} else {
serialized_ext, err := SerializeAny(ctx, ext)
if err != nil {
return err
}
ext_serialized[ext_type] = serialized_ext
}
}
var sq_serialized *SerializedValue = nil
if node.writeSignalQueue == true {
node.writeSignalQueue = false
ser, err := SerializeAny(ctx, node.SignalQueue)
if err != nil {
return err
}
sq_serialized = &ser
}
node_serialized, err := SerializeAny(ctx, node)
if err != nil {
return err
}
id_bytes, err := node.ID.MarshalBinary()
return ctx.DB.Update(func(txn *badger.Txn) error {
err := txn.Set(id_bytes, node_serialized.Data)
if err != nil {
return err
}
if sq_serialized != nil {
err := txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data)
if err != nil {
return err
}
}
for ext_type, data := range(ext_serialized) {
err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data)
if err != nil {
return err
}
}
return nil
})
return fmt.Errorf("TODO: write changes to node(and any signal queue changes)")
}
func LoadNode(ctx *Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOADING_NODE: %s", id)
var node_bytes []byte = nil
var sq_bytes []byte = nil
var ext_bytes = map[ExtType][]byte{}
err := ctx.DB.View(func(txn *badger.Txn) error {
id_bytes, err := id.MarshalBinary()
if err != nil {
return err
}
node_item, err := txn.Get(id_bytes)
if err != nil {
ctx.Log.Logf("db", "node key not found")
return err
}
node_bytes, err = node_item.ValueCopy(nil)
if err != nil {
return err
}
sq_item, err := txn.Get(append(id_bytes, signal_queue_suffix...))
if err != nil {
ctx.Log.Logf("db", "sq key not found")
return err
}
sq_bytes, err = sq_item.ValueCopy(nil)
if err != nil {
return err
}
ext_list_item, err := txn.Get(append(id_bytes, extension_suffix...))
if err != nil {
ctx.Log.Logf("db", "ext_list key not found")
return err
}
ext_list_bytes, err := ext_list_item.ValueCopy(nil)
if err != nil {
return err
}
ext_list_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]ExtType{}), ext_list_bytes)
if err != nil {
return err
} else if len(remaining) > 0 {
return fmt.Errorf("Data remaining after ext_list deserialize %d", len(remaining))
}
ext_list, ok := ext_list_value.Interface().([]ExtType)
if ok == false {
return fmt.Errorf("deserialize returned wrong type %s", ext_list_value.Type())
}
for _, ext_type := range(ext_list) {
ext_item, err := txn.Get(append(id_bytes, ExtTypeSuffix(ext_type)...))
if err != nil {
ctx.Log.Logf("db", "ext %s key not found", ext_type)
return err
}
ext_bytes[ext_type], err = ext_item.ValueCopy(nil)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
node_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf((*Node)(nil)), node_bytes)
if err != nil {
return nil, err
} else if len(remaining) != 0 {
return nil, fmt.Errorf("data left after deserializing node %d", len(remaining))
}
node, node_ok := node_value.Interface().(*Node)
if node_ok == false {
return nil, fmt.Errorf("node wrong type %s", node_value.Type())
}
ctx.Log.Logf("db", "Deserialized node bytes %+v", node)
signal_queue_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]QueuedSignal{}), sq_bytes)
if err != nil {
return nil, err
} else if len(remaining) != 0 {
return nil, fmt.Errorf("data left after deserializing signal_queue %d", len(remaining))
}
signal_queue, sq_ok := signal_queue_value.Interface().([]QueuedSignal)
if sq_ok == false {
return nil, fmt.Errorf("signal queue wrong type %s", signal_queue_value.Type())
}
for ext_type, data := range(ext_bytes) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return nil, fmt.Errorf("0x%0x is not a known extension type", ext_type)
}
ext_value, remaining, err := DeserializeValue(ctx, ext_info.Reflect, data)
if err != nil {
return nil, err
} else if len(remaining) > 0 {
return nil, fmt.Errorf("data left after deserializing ext(0x%x) %d", ext_type, len(remaining))
}
ext, ext_ok := ext_value.Interface().(Extension)
if ext_ok == false {
return nil, fmt.Errorf("extension wrong type %s", ext_value.Type())
}
node.Extensions[ext_type] = ext
}
node.SignalQueue = signal_queue
node.NextSignal, node.TimeoutChan = SoonestSignal(signal_queue)
ctx.AddNode(id, node)
ctx.Log.Logf("db", "loaded %+v", node)
go runNode(ctx, node)
return node, nil
return nil, fmt.Errorf("TODO: load node + extensions from DB")
}

@ -12,7 +12,7 @@ func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"node", "db"})
node_listener := NewListenerExt(10)
node, err := NewNode(ctx, nil, "Base", 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener)
node, err := NewNode(ctx, nil, "Base", 10, nil, NewLockableExt(nil), node_listener)
fatalErr(t, err)
_, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool {
@ -45,25 +45,18 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id)
n1_policy := NewPerNodePolicy(map[NodeID]Tree{
n2_id: {
SerializedType(SignalTypeFor[ReadSignal]()): nil,
},
})
n2_listener := NewListenerExt(10)
n2, err := NewNode(ctx, n2_key, "Base", 10, nil, NewGroupExt(nil), n2_listener)
n2, err := NewNode(ctx, n2_key, "Base", 10, n2_listener)
fatalErr(t, err)
n1, err := NewNode(ctx, n1_key, "Base", 10, []Policy{n1_policy}, NewGroupExt(nil))
n1, err := NewNode(ctx, n1_key, "Base", 10, NewListenerExt(10))
fatalErr(t, err)
read_sig := NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"members"},
ExtTypeFor[ListenerExt](): {"buffer"},
})
msgs := Messages{}
msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig)
err = ctx.Send(msgs)
msgs := []SendMsg{{n1.ID, read_sig}}
err = ctx.Send(n2, msgs)
fatalErr(t, err)
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool {

@ -1,139 +0,0 @@
package graphvent
import (
"github.com/google/uuid"
)
type Policy interface {
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
ID() uuid.UUID
}
type PolicyHeader struct {
UUID uuid.UUID `gv:"uuid"`
}
func (header PolicyHeader) ID() uuid.UUID {
return header.UUID
}
func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
return nil, policy.Rules.Allows(action)
}
func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
for id, actions := range(policy.NodeRules) {
if id != principal_id {
continue
}
return nil, actions.Allows(action)
}
return nil, Deny
}
func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func CopyTree(tree Tree) Tree {
if tree == nil {
return nil
}
ret := Tree{}
for name, sub := range(tree) {
ret[name] = CopyTree(sub)
}
return ret
}
func MergeTrees(first Tree, second Tree) Tree {
if first == nil || second == nil {
return nil
}
ret := CopyTree(first)
for name, sub := range(second) {
current, exists := ret[name]
if exists == true {
ret[name] = MergeTrees(current, sub)
} else {
ret[name] = CopyTree(sub)
}
}
return ret
}
type Tree map[SerializedType]Tree
func (rule Tree) Allows(action Tree) RuleResult {
// If the current rule is nil, it's a wildcard and any action being processed is allowed
if rule == nil {
return Allow
// If the rule isn't "allow all" but the action is "request all", deny
} else if action == nil {
return Deny
// If the current action has no children, it's allowed
} else if len(action) == 0 {
return Allow
// If the current rule has no children but the action goes further, it's not allowed
} else if len(rule) == 0 {
return Deny
// If the current rule and action have children, all the children of action must be allowed by rule
} else {
for sub, subtree := range(action) {
subrule, exists := rule[sub]
if exists == false {
return Deny
} else if subrule.Allows(subtree) == Deny {
return Deny
}
}
return Allow
}
}
func NewPolicyHeader() PolicyHeader {
return PolicyHeader{
UUID: uuid.New(),
}
}
func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy {
if node_actions == nil {
node_actions = map[NodeID]Tree{}
}
return PerNodePolicy{
PolicyHeader: NewPolicyHeader(),
NodeRules: node_actions,
}
}
type PerNodePolicy struct {
PolicyHeader
NodeRules map[NodeID]Tree `gv:"node_rules"`
}
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
return AllNodesPolicy{
PolicyHeader: NewPolicyHeader(),
Rules: rules,
}
}
type AllNodesPolicy struct {
PolicyHeader
Rules Tree `gv:"rules"`
}
var DefaultPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): nil,
SerializedType(SignalTypeFor[StatusSignal]()): nil,
})

File diff suppressed because it is too large Load Diff

@ -1,247 +0,0 @@
package graphvent
import (
"fmt"
"reflect"
"testing"
"time"
)
func TestSerializeTest(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}})
testSerialize(t, ctx, map[NodeID]ReqState{
RandID(): Locked,
RandID(): Unlocked,
})
}
func TestSerializeBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
testSerializeComparable[bool](t, ctx, true)
type bool_wrapped bool
err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped])
fatalErr(t, err)
testSerializeComparable[bool_wrapped](t, ctx, true)
testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false})
testSerializeComparable[string](t, ctx, "test")
testSerializeComparable[float32](t, ctx, 0.05)
testSerializeComparable[float64](t, ctx, 0.05)
testSerializeComparable[uint](t, ctx, uint(1234))
testSerializeComparable[uint8] (t, ctx, uint8(123))
testSerializeComparable[uint16](t, ctx, uint16(1234))
testSerializeComparable[uint32](t, ctx, uint32(12345))
testSerializeComparable[uint64](t, ctx, uint64(123456))
testSerializeComparable[int](t, ctx, 1234)
testSerializeComparable[int8] (t, ctx, int8(-123))
testSerializeComparable[int16](t, ctx, int16(-1234))
testSerializeComparable[int32](t, ctx, int32(-12345))
testSerializeComparable[int64](t, ctx, int64(-123456))
testSerializeComparable[time.Duration](t, ctx, time.Duration(100))
testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0))
testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112})
testSerializeSlice[[]int](t, ctx, ([]int)(nil))
testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil})
testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil})
testSerialize(t, ctx, map[int8]map[*int8]string{})
testSerialize(t, ctx, map[int8]time.Time{
1: time.Now(),
3: time.Now().Add(time.Second),
0: time.Now().Add(time.Second*2),
4: time.Now().Add(time.Second*3),
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[NodeType](): nil,
SerializedTypeFor[SerializedType](): {
SerializedTypeFor[NodeType](): Tree{},
},
})
var i interface{} = nil
testSerialize(t, ctx, i)
testSerializeMap(t, ctx, map[int8]interface{}{
0: "abcd",
1: uint32(12345678),
2: i,
3: 123,
})
testSerializeMap(t, ctx, map[int8]int32{
0: 1234,
2: 5678,
4: 9101,
6: 1121,
})
type test_struct struct {
Int int `gv:"int"`
String string `gv:"string"`
}
err = RegisterStruct[test_struct](ctx)
fatalErr(t, err)
testSerialize(t, ctx, test_struct{
12345,
"test_string",
})
testSerialize(t, ctx, Tree{
SerializedKindFor(reflect.Map): nil,
SerializedKindFor(reflect.String): nil,
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[Tree](): nil,
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[Tree](): {
SerializedTypeFor[error](): Tree{},
SerializedKindFor(reflect.Map): nil,
},
SerializedKindFor(reflect.String): nil,
})
type test_slice []string
err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice)
fatalErr(t, err)
testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"})
testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"})
}
type test struct {
Int int `gv:"int"`
Str string `gv:"string"`
}
func (s test) String() string {
return fmt.Sprintf("%d:%s", s.Int, s.Str)
}
func TestSerializeStructTags(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
err := RegisterStruct[test](ctx)
fatalErr(t, err)
test_int := 10
test_string := "test"
ret := testSerialize(t, ctx, test{
test_int,
test_string,
})
if ret.Int != test_int {
t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int)
} else if ret.Str != test_string {
t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string)
}
testSerialize(t, ctx, []test{
{
test_int,
test_string,
},
{
test_int * 2,
fmt.Sprintf("%s%s", test_string, test_string),
},
{
test_int * 4,
fmt.Sprintf("%s%s%s", test_string, test_string, test_string),
},
})
}
func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) {
v := testSerialize(t, ctx, val)
for key, value := range(val) {
recreated, exists := v[key]
if exists == false {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
} else if recreated != value {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
if len(v) != len(val) {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) {
v := testSerialize(t, ctx, val)
for i, original := range(val) {
if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
for j, o := range(original) {
if v[i][j] != o {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
}
}
func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) {
v := testSerialize(t, ctx, val)
for i, original := range(val) {
if v[i] != original {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
}
}
}
func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) {
v := testSerialize(t, ctx, val)
if v != val {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
}
}
func testSerialize[T any](t *testing.T, ctx *Context, val T) T {
value := reflect.ValueOf(&val).Elem()
type_stack, err := SerializeType(ctx, value.Type())
chunks, err := SerializeValue(ctx, value)
value_serialized := SerializedValue{type_stack, chunks.Slice()}
fatalErr(t, err)
ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data))
value_chunks, err := value_serialized.Chunks()
fatalErr(t, err)
ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice())
val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice())
fatalErr(t, err)
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
if len(remaining_parse) != 0 {
t.Fatal("Data remaining after deserializing value")
}
val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack)
deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data)
fatalErr(t, err)
if len(remaining_deserialize) != 0 {
t.Fatal("Data remaining after deserializing value")
} else if len(remaining_types) != 0 {
t.Fatal("TypeStack remaining after deserializing value")
} else if val_type != value.Type() {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val)))
} else if deserialized_value.CanConvert(val_type) == false {
t.Fatal("DeserializeValue returned value that can't convert to original value")
}
ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface())
if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil {
var zero T
return zero
}
return deserialized_value.Interface().(T)
}

@ -7,20 +7,13 @@ import (
"github.com/google/uuid"
)
type SignalDirection uint8
const (
Up SignalDirection = iota
Down
Direct
)
type TimeoutSignal struct {
ResponseHeader
}
func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal {
return &TimeoutSignal{
NewResponseHeader(req_id, Direct),
NewResponseHeader(req_id),
}
}
@ -28,26 +21,23 @@ func (signal TimeoutSignal) String() string {
return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader)
}
// Timeouts are internal only, no permission allows sending them
func (signal TimeoutSignal) Permission() Tree {
return nil
}
type SignalDirection int
const (
Up SignalDirection = iota
Down
Direct
)
type SignalHeader struct {
Id uuid.UUID `gv:"id"`
Dir SignalDirection `gv:"direction"`
}
func (signal SignalHeader) ID() uuid.UUID {
return signal.Id
}
func (signal SignalHeader) Direction() SignalDirection {
return signal.Dir
}
func (header SignalHeader) String() string {
return fmt.Sprintf("SignalHeader(%d, %s)", header.Dir, header.Id)
return fmt.Sprintf("SignalHeader(%s)", header.Id)
}
type ResponseSignal interface {
@ -65,14 +55,12 @@ func (header ResponseHeader) ResponseID() uuid.UUID {
}
func (header ResponseHeader) String() string {
return fmt.Sprintf("ResponseHeader(%d, %s->%s)", header.Dir, header.Id, header.ReqID)
return fmt.Sprintf("ResponseHeader(%s, %s)", header.Id, header.ReqID)
}
type Signal interface {
fmt.Stringer
ID() uuid.UUID
Direction() SignalDirection
Permission() Tree
}
func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) {
@ -129,16 +117,15 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check
return zero, fmt.Errorf("LOOP_ENDED")
}
func NewSignalHeader(direction SignalDirection) SignalHeader {
func NewSignalHeader() SignalHeader {
return SignalHeader{
uuid.New(),
direction,
}
}
func NewResponseHeader(req_id uuid.UUID, direction SignalDirection) ResponseHeader {
func NewResponseHeader(req_id uuid.UUID) ResponseHeader {
return ResponseHeader{
NewSignalHeader(direction),
NewSignalHeader(),
req_id,
}
}
@ -151,16 +138,9 @@ func (signal SuccessSignal) String() string {
return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader)
}
func (signal SuccessSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[SuccessSignal]()): nil,
},
}
}
func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal {
return &SuccessSignal{
NewResponseHeader(req_id, Direct),
NewResponseHeader(req_id),
}
}
@ -171,16 +151,9 @@ type ErrorSignal struct {
func (signal ErrorSignal) String() string {
return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error)
}
func (signal ErrorSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[ErrorSignal]()): nil,
},
}
}
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal {
return &ErrorSignal{
NewResponseHeader(req_id, Direct),
NewResponseHeader(req_id),
fmt.Sprintf(fmt_string, args...),
}
}
@ -188,14 +161,9 @@ func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *E
type ACLTimeoutSignal struct {
ResponseHeader
}
func (signal ACLTimeoutSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ACLTimeoutSignal]()): nil,
}
}
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
sig := &ACLTimeoutSignal{
NewResponseHeader(req_id, Direct),
NewResponseHeader(req_id),
}
return sig
}
@ -205,17 +173,12 @@ type StatusSignal struct {
Source NodeID `gv:"source"`
Changes map[ExtType]Changes `gv:"changes"`
}
func (signal StatusSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[StatusSignal]()): nil,
}
}
func (signal StatusSignal) String() string {
return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes)
}
func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal {
return &StatusSignal{
NewSignalHeader(Up),
NewSignalHeader(),
source,
changes,
}
@ -232,17 +195,9 @@ const (
LinkActionAdd = "ADD"
)
func (signal LinkSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[LinkSignal]()): Tree{
Hash(LinkActionBase, signal.Action): nil,
},
}
}
func NewLinkSignal(action string, id NodeID) Signal {
return &LinkSignal{
NewSignalHeader(Direct),
NewSignalHeader(),
id,
action,
}
@ -256,21 +211,9 @@ func (signal LockSignal) String() string {
return fmt.Sprintf("LockSignal(%s, %s)", signal.SignalHeader, signal.State)
}
const (
LockStateBase = "LOCK_STATE"
)
func (signal LockSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[LockSignal]()): Tree{
Hash(LockStateBase, signal.State): nil,
},
}
}
func NewLockSignal(state string) *LockSignal {
return &LockSignal{
NewSignalHeader(Direct),
NewSignalHeader(),
state,
}
}
@ -284,21 +227,9 @@ func (signal ReadSignal) String() string {
return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions)
}
func (signal ReadSignal) Permission() Tree {
ret := Tree{}
for ext, fields := range(signal.Extensions) {
field_tree := Tree{}
for _, field := range(fields) {
field_tree[SerializedType(GetFieldTag(field))] = nil
}
ret[SerializedType(ext)] = field_tree
}
return Tree{SerializedType(SignalTypeFor[ReadSignal]()): ret}
}
func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
return &ReadSignal{
NewSignalHeader(Direct),
NewSignalHeader(),
exts,
}
}
@ -307,23 +238,16 @@ type ReadResultSignal struct {
ResponseHeader
NodeID NodeID
NodeType NodeType
Extensions map[ExtType]map[string]SerializedValue
Extensions map[ExtType]map[string]any
}
func (signal ReadResultSignal) String() string {
return fmt.Sprintf("ReadResultSignal(%s, %s, %+v, %+v)", signal.ResponseHeader, signal.NodeID, signal.NodeType, signal.Extensions)
}
func (signal ReadResultSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[ReadResultSignal]()): nil,
},
}
}
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal {
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]any) *ReadResultSignal {
return &ReadResultSignal{
NewResponseHeader(req_id, Direct),
NewResponseHeader(req_id),
node_id,
node_type,
exts,