diff --git a/event.go b/event.go deleted file mode 100644 index 83c9f5d..0000000 --- a/event.go +++ /dev/null @@ -1,607 +0,0 @@ -package graphvent - -import ( - "fmt" - "time" - "errors" - "reflect" - "sort" - "sync" - badger "github.com/dgraph-io/badger/v3" -) - -// Update the events listeners, and notify the parent to do the same -func (event * BaseEvent) PropagateUpdate(signal GraphSignal) { - event.state_lock.RLock() - defer event.state_lock.RUnlock() - state := event.state.(*EventState) - - if signal.Downwards() == false { - // Child->Parent - if state.parent != nil { - SendUpdate(state.parent, signal) - } - - for _, resource := range(state.resources) { - SendUpdate(resource, signal) - } - } else { - // Parent->Child - for _, child := range(state.children) { - SendUpdate(child, signal) - } - } - event.signal <- signal -} - -type EventInfo interface { -} - -type BaseEventInfo interface { - EventInfo -} - -type EventQueueInfo struct { - EventInfo - priority int - state string -} - -func NewEventQueueInfo(priority int) * EventQueueInfo { - info := &EventQueueInfo{ - priority: priority, - state: "queued", - } - - return info -} - -// Event is the interface that event tree nodes must implement -type Event interface { - GraphNode - Children() []Event - LockChildren() - UnlockChildren() - InfoType() reflect.Type - LockInfo() - UnlockInfo() - ChildInfo(event Event) EventInfo - Parent() Event - LockParent() - UnlockParent() - Action(action string) (func()(string, error), bool) - Handler(signal_type string) (func(GraphSignal) (string, error), bool) - Resources() []Resource - Resource(id string) Resource - AddResource(Resource) error - DoneResource() Resource - SetTimeout(end_time time.Time, action string) - ClearTimeout() - Timeout() <-chan time.Time - TimeoutAction() string - Signal() chan GraphSignal - - finish() error - - addChild(child Event, info EventInfo) - setParent(parent Event) -} - -func (event * BaseEvent) AddResource(resource Resource) error { - event.state_lock.Lock() - defer event.state_lock.Unlock() - state := event.state.(*EventState) - - _, exists := state.resources[resource.ID()] - if exists == true { - return fmt.Errorf("%s is already required for %s, cannot add again", resource.Name(), state.name) - } - - state.resources[resource.ID()] = resource - return nil -} - -func (event * BaseEvent) Signal() chan GraphSignal { - return event.signal -} - -func (event * BaseEvent) TimeoutAction() string { - return event.timeout_action -} - -func (event * BaseEvent) Timeout() <-chan time.Time { - return event.timeout -} - -func (event * BaseEvent) ClearTimeout() { - event.timeout_action = "" - event.timeout = nil -} - -func (event * BaseEvent) SetTimeout(end_time time.Time, action string) { - event.timeout_action = action - event.timeout = time.After(time.Until(end_time)) -} - -func (event * BaseEvent) Handler(signal_type string) (func(GraphSignal)(string, error), bool) { - handler, exists := event.Handlers[signal_type] - return handler, exists -} - -func FindChild(event Event, id string) Event { - if id == event.ID() { - return event - } - - for _, child := range event.Children() { - result := FindChild(child, id) - if result != nil { - return result - } - } - - return nil -} - -func CheckInfoType(event Event, info EventInfo) bool { - if event.InfoType() == nil || info == nil { - if event.InfoType() == nil && info == nil { - return true - } else { - return false - } - } - - return event.InfoType() == reflect.TypeOf(info) -} - -func LinkEvent(event Event, child Event, info EventInfo) error { - if CheckInfoType(event, info) == false { - return errors.New("LinkEvents got wrong type") - } - - event.LockParent() - child.LockParent() - if child.Parent() != nil { - child.UnlockParent() - event.UnlockParent() - return errors.New(fmt.Sprintf("Parent already registered: %s->%s already %s", child.Name(), event.Name(), child.Parent().Name())) - } - - event.LockChildren() - - for _, c := range(event.Children()) { - if c.ID() == child.ID() { - event.UnlockChildren() - child.UnlockParent() - event.UnlockParent() - return errors.New("Child already in event") - } - } - - // After all the checks are done, update the state of child + parent, then unlock and update - child.setParent(event) - event.addChild(child, info) - - event.UnlockChildren() - child.UnlockParent() - event.UnlockParent() - - SendUpdate(event, NewSignal(event, "child_added")) - return nil -} - -func RunEvent(event Event) error { - Log.Logf("event", "EVENT_RUN: %s", event.Name()) - - for _, resource := range(event.Resources()) { - if resource.Owner() == nil { - return fmt.Errorf("EVENT_RUN_RESOURCE_NOT_LOCKED: %s, %s", event.Name(), resource.Name()) - } else if resource.Owner().ID() != event.ID() { - return fmt.Errorf("EVENT_RUN_RESOURCE_ALREADY_LOCKED: %s, %s, %s", event.Name(), resource.Name(), resource.Owner().Name()) - } - } - - SendUpdate(event, NewSignal(event, "event_start")) - next_action := "start" - var err error = nil - for next_action != "" { - action, exists := event.Action(next_action) - if exists == false { - error_str := fmt.Sprintf("%s is not a valid action", next_action) - return errors.New(error_str) - } - - Log.Logf("event", "EVENT_ACTION: %s - %s", event.Name(), next_action) - next_action, err = action() - if err != nil { - return err - } - } - - err = FinishEvent(event) - if err != nil { - Log.Logf("event", "EVENT_RUN_FINISH_ERR: %s", err) - return err - } - - Log.Logf("event", "EVENT_RUN_DONE: %s", event.Name()) - - return nil -} - -func EventAbort(event Event) func(signal GraphSignal) (string, error) { - return func(signal GraphSignal) (string, error) { - return "", errors.New(fmt.Sprintf("%s aborted by signal", event.ID())) - } -} - -func EventCancel(event Event) func(signal GraphSignal) (string, error) { - return func(signal GraphSignal) (string, error) { - return "", nil - } -} - -func LockResources(event Event) error { - Log.Logf("event", "RESOURCE_LOCKING for %s - %+v", event.Name(), event.Resources()) - locked_resources := []Resource{} - var lock_err error = nil - for _, resource := range(event.Resources()) { - err := LockResource(resource, event) - if err != nil { - lock_err = err - break - } - locked_resources = append(locked_resources, resource) - } - - if lock_err != nil { - for _, resource := range(locked_resources) { - UnlockResource(resource, event) - } - Log.Logf("event", "RESOURCE_LOCK_FAIL for %s: %s", event.Name(), lock_err) - return lock_err - } - - Log.Logf("event", "RESOURCE_LOCK_SUCCESS for %s", event.Name()) - signal := NewDownSignal(event, "locked") - SendUpdate(event, signal) - - return nil -} - -func FinishEvent(event Event) error { - Log.Logf("event", "EVENT_FINISH: %s", event.Name()) - for _, resource := range(event.Resources()) { - err := UnlockResource(resource, event) - if err != nil { - panic(err) - } - } - - err := UnlockResource(event.DoneResource(), event) - if err != nil { - return err - } - - SendUpdate(event, NewDownSignal(event, "unlocked")) - SendUpdate(event.DoneResource(), NewDownSignal(event, "unlocked")) - - err = event.finish() - if err != nil { - return err - } - - SendUpdate(event, NewSignal(event, "event_done")) - return nil -} - -// BaseEvent is the most basic event that can exist in the event tree. -// On start it automatically transitions to completion. -// It can optionally require events, which will all need to be locked to start it -// It can optionally create resources, which will be locked by default and unlocked on completion -// This node by itself doesn't implement any special behaviours for children, so they will be ignored. -// When starter, this event automatically transitions to completion and unlocks all it's resources(including created) -type BaseEvent struct { - BaseNode - - resources_lock sync.Mutex - children_lock sync.Mutex - info_lock sync.Mutex - parent_lock sync.Mutex - - Actions map[string]func() (string, error) - Handlers map[string]func(GraphSignal) (string, error) - - timeout <-chan time.Time - timeout_action string -} - -type EventState struct { - BaseNodeState - children []Event - child_info map[string]EventInfo - resources map[string]Resource - parent Event -} - -func (event * BaseEvent) LockInfo() { - event.info_lock.Lock() -} - -func (event * BaseEvent) UnlockInfo() { - event.info_lock.Unlock() -} - -func (event * BaseEvent) Action(action string) (func() (string, error), bool) { - action_fn, exists := event.Actions[action] - return action_fn, exists -} - -func EventWait(event Event) (func() (string, error)) { - return func() (string, error) { - Log.Logf("event", "EVENT_WAIT: %s TIMEOUT: %+v", event.Name(), event.Timeout()) - select { - case signal := <- event.Signal(): - Log.Logf("event", "EVENT_SIGNAL: %s %+v", event.Name(), signal) - signal_fn, exists := event.Handler(signal.Type()) - if exists == true { - Log.Logf("event", "EVENT_HANDLER: %s - %s", event.Name(), signal.Type()) - return signal_fn(signal) - } - return "wait", nil - case <- event.Timeout(): - Log.Logf("event", "EVENT_TIMEOUT %s - NEXT_STATE: %s", event.Name(), event.TimeoutAction()) - return event.TimeoutAction(), nil - } - } -} - -func NewBaseEvent(name string, description string) (BaseEvent) { - event := BaseEvent{ - BaseNode: NewBaseNode(randid()), - Actions: map[string]func()(string, error){}, - Handlers: map[string]func(GraphSignal)(string, error){}, - timeout: nil, - timeout_action: "", - } - return event -} - -func AddResources(event Event, resources []Resource) error { - for _, r := range(resources) { - err := event.AddResource(r) - if err != nil { - return err - } - } - return nil -} - -func NewEventState(name string, description string) *EventState{ - return &EventState{ - BaseNodeState: BaseNodeState{ - name: name, - description: description, - delegation_map: map[string]GraphNode{}, - }, - children: []Event{}, - child_info: map[string]EventInfo{}, - resources: map[string]Resource{}, - parent: nil, - } -} - -func NewEvent(db *badger.DB, name string, description string, resources []Resource) (* BaseEvent, error) { - event := NewBaseEvent(name, description) - event_ptr := &event - event_ptr.state = NewEventState(name, description) - - err := AddResources(event_ptr, resources) - if err != nil { - return nil, err - } - - event_ptr.Actions["wait"] = EventWait(event_ptr) - event_ptr.Handlers["abort"] = EventAbort(event_ptr) - event_ptr.Handlers["cancel"] = EventCancel(event_ptr) - - event_ptr.Actions["start"] = func() (string, error) { - return "", nil - } - - return event_ptr, nil -} - -func (event * BaseEvent) finish() error { - return nil -} - -func (event * BaseEvent) InfoType() reflect.Type { - return nil -} - -// EventQueue is a basic event that can have children. -// On start, it attempts to start it's children from the highest 'priority' -type EventQueue struct { - BaseEvent - listened_resources map[string]Resource - queue_lock sync.Mutex -} - -func (queue * EventQueue) finish() error { - for _, resource := range(queue.listened_resources) { - resource.UnregisterChannel(queue.signal) - } - return nil -} - -func (queue * EventQueue) InfoType() reflect.Type { - return reflect.TypeOf((*EventQueueInfo)(nil)) -} - -func NewEventQueue(name string, description string, resources []Resource) (* EventQueue, error) { - queue := &EventQueue{ - BaseEvent: NewBaseEvent(name, description), - listened_resources: map[string]Resource{}, - } - - queue.state = NewEventState(name, description) - - AddResources(queue, resources) - - queue.Actions["wait"] = EventWait(queue) - queue.Handlers["abort"] = EventAbort(queue) - queue.Handlers["cancel"] = EventCancel(queue) - - queue.Actions["start"] = func() (string, error) { - return "queue_event", nil - } - - queue.Actions["queue_event"] = func() (string, error) { - // Copy the events to sort the list - queue.LockChildren() - copied_events := make([]Event, len(queue.Children())) - copy(copied_events, queue.Children()) - less := func(i int, j int) bool { - info_i := queue.ChildInfo(copied_events[i]).(*EventQueueInfo) - info_j := queue.ChildInfo(copied_events[j]).(*EventQueueInfo) - return info_i.priority < info_j.priority - } - sort.SliceStable(copied_events, less) - - needed_resources := map[string]Resource{} - for _, event := range(copied_events) { - // make sure all the required resources are registered to update the event - for _, resource := range(event.Resources()) { - needed_resources[resource.ID()] = resource - } - - info := queue.ChildInfo(event).(*EventQueueInfo) - event.LockInfo() - defer event.UnlockInfo() - if info.state == "queued" { - err := LockResources(event) - // start in new goroutine - if err != nil { - } else { - info.state = "running" - Log.Logf("event", "EVENT_START: %s", event.Name()) - go func(event Event, info * EventQueueInfo, queue Event) { - Log.Logf("event", "EVENT_GOROUTINE: %s", event.Name()) - err := RunEvent(event) - if err != nil { - Log.Logf("event", "EVENT_ERROR: %s", err) - } - event.LockInfo() - defer event.UnlockInfo() - info.state = "done" - }(event, info, queue) - } - } - } - - - for _, resource := range(needed_resources) { - _, exists := queue.listened_resources[resource.ID()] - if exists == false { - Log.Logf("event", "REGISTER_RESOURCE: %s - %s", queue.Name(), resource.Name()) - queue.listened_resources[resource.ID()] = resource - resource.RegisterChannel(queue.signal) - } - } - - queue.UnlockChildren() - - return "wait", nil - } - - queue.Handlers["resource_connected"] = func(signal GraphSignal) (string, error) { - return "queue_event", nil - } - - queue.Handlers["child_added"] = func(signal GraphSignal) (string, error) { - return "queue_event", nil - } - - queue.Handlers["lock_changed"] = func(signal GraphSignal) (string, error) { - return "queue_event", nil - } - - queue.Handlers["event_done"] = func(signal GraphSignal) (string, error) { - return "queue_event", nil - } - - return queue, nil -} - -func (event * BaseEvent) Allowed() []GraphNode { - event.state_lock.RLock() - defer event.state_lock.RUnlock() - state := event.state.(*EventState) - - ret := make([]GraphNode, len(state.children)) - for i, v := range(state.children) { - ret[i] = v - } - return ret -} - -func (event * BaseEvent) Resources() []Resource { - resources := []Resource{} - for _, val := range(event.resources) { - resources = append(resources, val) - } - return resources -} - -func (event * BaseEvent) Resource(id string) Resource { - resource, _ := event.resources[id] - return resource -} - -func (event * BaseEvent) DoneResource() Resource { - return event.done_resource -} - -func (event * BaseEvent) Children() []Event { - return event.children -} - -func (event * BaseEvent) ChildInfo(idx Event) EventInfo { - val, ok := event.child_info[idx.ID()] - if ok == false { - return nil - } - return val -} - -func (event * BaseEvent) LockChildren() { - event.children_lock.Lock() -} - -func (event * BaseEvent) UnlockChildren() { - event.children_lock.Unlock() -} - -func (event * BaseEvent) LockParent() { - event.parent_lock.Lock() -} - -func (event * BaseEvent) UnlockParent() { - event.parent_lock.Unlock() -} - -func (event * BaseEvent) setParent(parent Event) { - event.parent = parent -} - -func (event * BaseEvent) addChild(child Event, info EventInfo) { - event.children = append(event.children, child) - event.child_info[child.ID()] = info -} - -type GQLEvent struct { - BaseEvent - abort chan error -} diff --git a/gql.go b/gql.go deleted file mode 100644 index 1cc31f4..0000000 --- a/gql.go +++ /dev/null @@ -1,1150 +0,0 @@ -package graphvent - -import ( - "net/http" - "github.com/graphql-go/graphql" - "github.com/graphql-go/graphql/language/parser" - "github.com/graphql-go/graphql/language/source" - "github.com/graphql-go/graphql/language/ast" - "context" - "encoding/json" - "io" - "reflect" - "errors" - "fmt" - "sync" - "time" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" -) - -func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r * http.Request) { - graphiql_string := fmt.Sprintf(` - - - - - GraphiQL - - - - - - - - - - - -
Loading...
- - - - - `) - - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusOK) - io.WriteString(w, graphiql_string) - } - -} - -var gql_type_base_node *graphql.Object = nil -func GQLTypeBaseNode() *graphql.Object { - if gql_type_base_node == nil { - gql_type_base_node = graphql.NewObject(graphql.ObjectConfig{ - Name: "BaseNode", - Interfaces: []*graphql.Interface{ - GQLInterfaceNode(), - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*BaseNode) - return ok - }, - Fields: graphql.Fields{}, - }) - - gql_type_base_node.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceID, - }) - - gql_type_base_node.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceName, - }) - - gql_type_base_node.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceDescription, - }) - } - - return gql_type_base_node -} - -var gql_interface_node *graphql.Interface = nil -func GQLInterfaceNode() *graphql.Interface { - if gql_interface_node == nil { - gql_interface_node = graphql.NewInterface(graphql.InterfaceConfig{ - Name: "Node", - ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - valid_events, ok := p.Context.Value("valid_events").(map[reflect.Type]*graphql.Object) - if ok == false { - return nil - } - - valid_resources, ok := p.Context.Value("valid_resources").(map[reflect.Type]*graphql.Object) - if ok == false { - return nil - } - - for key, value := range(valid_events) { - if reflect.TypeOf(p.Value) == key { - return value - } - } - - for key, value := range(valid_resources) { - if reflect.TypeOf(p.Value) == key { - return value - } - } - - return GQLTypeBaseNode() - }, - Fields: graphql.Fields{}, - }) - - gql_interface_node.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_node.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_node.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - }) - - } - - return gql_interface_node -} - -type GQLWSPayload struct { - OperationName string `json:"operationName,omitempty"` - Query string `json:"query,omitempty"` - Variables map[string]interface{} `json:"variables,omitempty"` - Extensions map[string]interface{} `json:"extensions,omitempty"` - Data string `json:"data,omitempty"` -} - -type GQLWSMsg struct { - ID string `json:"id,omitempty"` - Type string `json:"type"` - Payload GQLWSPayload `json:"payload,omitempty"` -} - -func enableCORS(w *http.ResponseWriter) { - (*w).Header().Set("Access-Control-Allow-Origin", "*") - (*w).Header().Set("Access-Control-Allow-Credentials", "true") - (*w).Header().Set("Access-Control-Allow-Headers", "*") - (*w).Header().Set("Access-Control-Allow-Methods", "*") -} - -func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r * http.Request) { - Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) - enableCORS(&w) - header_map := map[string]interface{}{} - for header, value := range(r.Header) { - header_map[header] = value - } - Log.Logm("gql", header_map, "REQUEST_HEADERS") - - str, err := io.ReadAll(r.Body) - if err != nil { - Log.Logf("gql", "failed to read request body: %s", err) - return - } - query := GQLWSPayload{} - json.Unmarshal(str, &query) - - params := graphql.Params{ - Schema: schema, - Context: ctx, - RequestString: query.Query, - } - if query.OperationName != "" { - params.OperationName = query.OperationName - } - if len(query.Variables) > 0 { - params.VariableValues = query.Variables - } - result := graphql.Do(params) - if len(result.Errors) > 0 { - extra_fields := map[string]interface{}{} - extra_fields["body"] = string(str) - extra_fields["headers"] = r.Header - Log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors) - } - json.NewEncoder(w).Encode(result) - } -} - -func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result { - resultChannel := make(chan *graphql.Result) - go func() { - resultChannel <- res - close(resultChannel) - }() - return resultChannel -} - - -func getOperationTypeOfReq(p graphql.Params) string{ - source := source.NewSource(&source.Source{ - Body: []byte(p.RequestString), - Name: "GraphQL request", - }) - - AST, err := parser.Parse(parser.ParseParams{Source: source}) - if err != nil { - return "" - } - - for _, node := range AST.Definitions { - if operationDef, ok := node.(*ast.OperationDefinition); ok { - name := "" - if operationDef.Name != nil { - name = operationDef.Name.Value - } - if name == p.OperationName || p.OperationName == "" { - return operationDef.Operation - } - } - } - return "" -} - -func GQLWSDo(p graphql.Params) chan *graphql.Result { - operation := getOperationTypeOfReq(p) - Log.Logf("gqlws", "GQLWSDO_OPERATION: %s %+v", operation, p.RequestString) - - if operation == ast.OperationTypeSubscription { - return graphql.Subscribe(p) - } - - res := graphql.Do(p) - return sendOneResultAndClose(res) -} - -func GQLWSHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r * http.Request) { - Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) - header_map := map[string]interface{}{} - for header, value := range(r.Header) { - header_map[header] = value - } - Log.Logm("gql", header_map, "REQUEST_HEADERS") - u := ws.HTTPUpgrader{ - Protocol: func(protocol string) bool { - Log.Logf("gqlws", "UPGRADE_PROTOCOL: %s", string(protocol)) - return string(protocol) == "graphql-transport-ws" - }, - } - conn, _, _, err := u.Upgrade(r, w) - if err == nil { - defer conn.Close() - conn_state := "init" - for { - // TODO: Make this a select between reading client data and getting updates from the event to push to clients" - msg_raw, op, err := wsutil.ReadClientData(conn) - Log.Logf("gqlws_hb", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err) - msg := GQLWSMsg{} - json.Unmarshal(msg_raw, &msg) - if err != nil { - Log.Logf("gqlws", "WS_CLIENT_ERROR") - break - } - if msg.Type == "connection_init" { - if conn_state != "init" { - Log.Logf("gqlws", "WS_CLIENT_ERROR: INIT WHILE IN %s", conn_state) - break - } - conn_state = "ready" - err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}")) - if err != nil { - Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND connection_ack") - break - } - } else if msg.Type == "ping" { - Log.Logf("gqlws_hb", "PING FROM %s", r.RemoteAddr) - err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"pong\"}")) - if err != nil { - Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND PONG") - } - } else if msg.Type == "subscribe" { - Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) - params := graphql.Params{ - Schema: schema, - Context: ctx, - RequestString: msg.Payload.Query, - } - if msg.Payload.OperationName != "" { - params.OperationName = msg.Payload.OperationName - } - if len(msg.Payload.Variables) > 0 { - params.VariableValues = msg.Payload.Variables - } - - res_chan := GQLWSDo(params) - - go func(res_chan chan *graphql.Result) { - for { - next, ok := <-res_chan - if ok == false { - Log.Logf("gqlws", "response channel was closed") - return - } - if next == nil { - Log.Logf("gqlws", "NIL_ON_CHANNEL") - return - } - if len(next.Errors) > 0 { - extra_fields := map[string]interface{}{} - extra_fields["query"] = string(msg.Payload.Query) - Log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors) - continue - } - Log.Logf("gqlws", "DATA: %+v", next.Data) - data, err := json.Marshal(next.Data) - if err != nil { - Log.Logf("gqlws", "ERROR: %+v", err) - continue - } - msg, err := json.Marshal(GQLWSMsg{ - ID: msg.ID, - Type: "next", - Payload: GQLWSPayload{ - Data: string(data), - }, - }) - if err != nil { - Log.Logf("gqlws", "ERROR: %+v", err) - continue - } - - err = wsutil.WriteServerMessage(conn, 1, msg) - if err != nil { - Log.Logf("gqlws", "ERROR: %+v", err) - continue - } - } - }(res_chan) - } else { - } - } - return - } else { - panic("Failed to upgrade websocket") - } - } -} - -func GQLEventFn(p graphql.ResolveParams, fn func(Event, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - if event, ok := p.Source.(Event); ok { - return fn(event, p) - } - return nil, errors.New("Failed to cast source to event") -} - -func GQLEventID(p graphql.ResolveParams) (interface{}, error) { - return GQLEventFn(p, func(event Event, p graphql.ResolveParams)(interface{}, error) { - return event.ID(), nil - }) -} - -func GQLEventName(p graphql.ResolveParams) (interface{}, error) { - return GQLEventFn(p, func(event Event, p graphql.ResolveParams)(interface{}, error) { - return event.Name(), nil - }) -} - -func GQLEventDescription(p graphql.ResolveParams) (interface{}, error) { - return GQLEventFn(p, func(event Event, p graphql.ResolveParams)(interface{}, error) { - return event.Description(), nil - }) -} - -func GQLEventChildren(p graphql.ResolveParams) (interface{}, error) { - return GQLEventFn(p, func(event Event, p graphql.ResolveParams)(interface{}, error) { - return event.Children(), nil - }) -} - -func GQLEventResources(p graphql.ResolveParams) (interface{}, error) { - return GQLEventFn(p, func(event Event, p graphql.ResolveParams)(interface{}, error) { - return event.Resources(), nil - }) -} - -var gql_list_resource * graphql.List = nil -func GQLListResource() * graphql.List { - if gql_list_resource == nil { - gql_list_resource = graphql.NewList(GQLInterfaceResource()) - } - - return gql_list_resource -} - -var gql_interface_resource * graphql.Interface = nil -func GQLInterfaceResource() * graphql.Interface { - if gql_interface_resource == nil { - gql_interface_resource = graphql.NewInterface(graphql.InterfaceConfig{ - Name: "Resource", - ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - if p.Value == nil { - return GQLTypeBaseResource() - } - valid_resources, ok := p.Context.Value("valid_resources").(map[reflect.Type]*graphql.Object) - if ok == false { - return nil - } - for key, value := range(valid_resources) { - if reflect.TypeOf(p.Value) == key { - return value - } - } - return nil - }, - Fields: graphql.Fields{}, - }) - - if gql_list_resource == nil { - gql_list_resource = graphql.NewList(gql_interface_resource) - } - - gql_interface_resource.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_resource.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_resource.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_resource.AddFieldConfig("Parents", &graphql.Field{ - Type: GQLListResource(), - }) - - gql_interface_resource.AddFieldConfig("Children", &graphql.Field{ - Type: GQLListResource(), - }) - - gql_interface_resource.AddFieldConfig("Owner", &graphql.Field{ - Type: GQLInterfaceNode(), - }) - - } - - return gql_interface_resource -} - -func GQLResourceFn(p graphql.ResolveParams, fn func(Resource, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - if resource, ok := p.Source.(Resource); ok { - return fn(resource, p) - } - return nil, errors.New(fmt.Sprintf("Failed to cast source to resource, %+v", p.Source)) -} - -func GQLResourceID(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.ID(), nil - }) -} - -func GQLResourceName(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.Name(), nil - }) -} - -func GQLResourceDescription(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.Description(), nil - }) -} - -func GQLResourceParents(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.Parents(), nil - }) -} - -func GQLResourceOwner(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.Owner(), nil - }) -} - -func GQLResourceChildren(p graphql.ResolveParams) (interface{}, error) { - return GQLResourceFn(p, func(resource Resource, p graphql.ResolveParams) (interface{}, error) { - return resource.Children(), nil - }) -} - -var gql_type_gql_server *graphql.Object = nil -func GQLTypeGQLServer() * graphql.Object { - if gql_type_gql_server == nil { - gql_type_gql_server = graphql.NewObject(graphql.ObjectConfig{ - Name: "GQLServer", - Interfaces: []*graphql.Interface{ - GQLInterfaceResource(), - GQLInterfaceNode(), - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*GQLServer) - return ok - }, - Fields: graphql.Fields{}, - }) - - gql_type_gql_server.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceID, - }) - - gql_type_gql_server.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceName, - }) - - gql_type_gql_server.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceDescription, - }) - - gql_type_gql_server.AddFieldConfig("Parents", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLResourceParents, - }) - - gql_type_gql_server.AddFieldConfig("Children", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLResourceChildren, - }) - - gql_type_gql_server.AddFieldConfig("Owner", &graphql.Field{ - Type: GQLInterfaceNode(), - Resolve: GQLResourceOwner, - }) - } - return gql_type_gql_server -} - -var gql_type_base_resource *graphql.Object = nil -func GQLTypeBaseResource() * graphql.Object { - if gql_type_base_resource == nil { - gql_type_base_resource = graphql.NewObject(graphql.ObjectConfig{ - Name: "BaseResource", - Interfaces: []*graphql.Interface{ - GQLInterfaceResource(), - GQLInterfaceNode(), - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*BaseResource) - return ok - }, - Fields: graphql.Fields{}, - }) - - gql_type_base_resource.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceID, - }) - - gql_type_base_resource.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceName, - }) - - gql_type_base_resource.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - Resolve: GQLResourceDescription, - }) - - gql_type_base_resource.AddFieldConfig("Parents", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLResourceParents, - }) - - gql_type_base_resource.AddFieldConfig("Children", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLResourceChildren, - }) - - gql_type_base_resource.AddFieldConfig("Owner", &graphql.Field{ - Type: GQLInterfaceNode(), - Resolve: GQLResourceOwner, - }) - } - - return gql_type_base_resource -} - -var gql_list_event * graphql.List = nil -func GQLListEvent() * graphql.List { - if gql_list_event == nil { - gql_list_event = graphql.NewList(GQLInterfaceEvent()) - } - return gql_list_event -} - -var gql_interface_event * graphql.Interface = nil -func GQLInterfaceEvent() * graphql.Interface { - if gql_interface_event == nil { - gql_interface_event = graphql.NewInterface(graphql.InterfaceConfig{ - Name: "Event", - ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - valid_events, ok := p.Context.Value("valid_events").(map[reflect.Type]*graphql.Object) - if ok == false { - return nil - } - for key, value := range(valid_events) { - if reflect.TypeOf(p.Value) == key { - return value - } - } - return nil - }, - Fields: graphql.Fields{}, - }) - - if gql_list_event == nil { - gql_list_event = graphql.NewList(gql_interface_event) - } - - if gql_list_resource == nil { - gql_list_resource = graphql.NewList(GQLInterfaceResource()) - } - - gql_interface_event.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_event.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_event.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - }) - - gql_interface_event.AddFieldConfig("Children", &graphql.Field{ - Type: gql_list_event, - }) - - gql_interface_event.AddFieldConfig("Resources", &graphql.Field{ - Type: gql_list_resource, - }) - } - - return gql_interface_event -} - -var gql_type_base_event * graphql.Object = nil -func GQLTypeBaseEvent() * graphql.Object { - if gql_type_base_event == nil { - gql_type_base_event = graphql.NewObject(graphql.ObjectConfig{ - Name: "BaseEvent", - Interfaces: []*graphql.Interface{ - GQLInterfaceEvent(), - GQLInterfaceNode(), - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*BaseEvent) - return ok - }, - Fields: graphql.Fields{}, - }) - gql_type_base_event.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventID, - }) - - gql_type_base_event.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventName, - }) - - gql_type_base_event.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventDescription, - }) - - gql_type_base_event.AddFieldConfig("Children", &graphql.Field{ - Type: GQLListEvent(), - Resolve: GQLEventChildren, - }) - - gql_type_base_event.AddFieldConfig("Resources", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLEventResources, - }) - } - - return gql_type_base_event -} - -var gql_type_event_queue * graphql.Object = nil -func GQLTypeEventQueue() * graphql.Object { - if gql_type_event_queue == nil { - gql_type_event_queue = graphql.NewObject(graphql.ObjectConfig{ - Name: "EventQueue", - Interfaces: []*graphql.Interface{ - GQLInterfaceEvent(), - GQLInterfaceNode(), - }, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(*EventQueue) - return ok - }, - Fields: graphql.Fields{}, - }) - gql_type_event_queue.AddFieldConfig("ID", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventID, - }) - gql_type_event_queue.AddFieldConfig("Name", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventName, - }) - gql_type_event_queue.AddFieldConfig("Description", &graphql.Field{ - Type: graphql.String, - Resolve: GQLEventDescription, - }) - gql_type_event_queue.AddFieldConfig("Children", &graphql.Field{ - Type: GQLListEvent(), - Resolve: GQLEventChildren, - }) - gql_type_event_queue.AddFieldConfig("Resources", &graphql.Field{ - Type: GQLListResource(), - Resolve: GQLEventResources, - }) - } - return gql_type_event_queue -} - -func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - if signal, ok := p.Source.(GraphSignal); ok { - return fn(signal, p) - } - return nil, errors.New("Failed to cast source to event") -} - -func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.Type(), nil - }) -} - -func GQLSignalSource(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.Source(), nil - }) -} - -func GQLSignalDownwards(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.Downwards(), nil - }) -} - -func GQLSignalString(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.String(), nil - }) -} - - -var gql_type_signal *graphql.Object = nil -func GQLTypeSignal() *graphql.Object { - if gql_type_signal == nil { - gql_type_signal = graphql.NewObject(graphql.ObjectConfig{ - Name: "SignalOut", - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(GraphSignal) - return ok - }, - Fields: graphql.Fields{}, - }) - - gql_type_signal.AddFieldConfig("Type", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalType, - }) - gql_type_signal.AddFieldConfig("Source", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalSource, - }) - gql_type_signal.AddFieldConfig("Downwards", &graphql.Field{ - Type: graphql.Boolean, - Resolve: GQLSignalDownwards, - }) - gql_type_signal.AddFieldConfig("String", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalString, - }) - } - return gql_type_signal -} - -var gql_type_signal_input *graphql.InputObject = nil -func GQLTypeSignalInput() *graphql.InputObject { - if gql_type_signal_input == nil { - gql_type_signal_input = graphql.NewInputObject(graphql.InputObjectConfig{ - Name: "SignalIn", - Fields: graphql.InputObjectConfigFieldMap{}, - }) - gql_type_signal_input.AddFieldConfig("Type", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - }) - gql_type_signal_input.AddFieldConfig("Description", &graphql.InputObjectFieldConfig{ - Type: graphql.String, - DefaultValue: "", - }) - gql_type_signal_input.AddFieldConfig("Time", &graphql.InputObjectFieldConfig{ - Type: graphql.DateTime, - DefaultValue: time.Now(), - }) - } - return gql_type_signal_input -} - -var gql_mutation_update_event *graphql.Field = nil -func GQLMutationUpdateEvent() *graphql.Field { - if gql_mutation_update_event == nil { - gql_mutation_update_event = &graphql.Field{ - Type: GQLTypeSignal(), - Args: graphql.FieldConfigArgument{ - "id": &graphql.ArgumentConfig{ - Type: graphql.String, - }, - "signal": &graphql.ArgumentConfig{ - Type: GQLTypeSignalInput(), - }, - }, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLServer) - if ok == false { - return nil, errors.New(fmt.Sprintf("Failed to cast context gql_server to GQLServer: %+v", p.Context.Value("gql_server"))) - } - - signal_map, ok := p.Args["signal"].(map[string]interface{}) - if ok == false { - return nil, errors.New(fmt.Sprintf("Failed to cast arg signal to GraphSignal: %+v", p.Args["signal"])) - } - var signal GraphSignal = nil - if signal_map["Downwards"] == false { - signal = NewSignal(server, signal_map["Type"].(string)) - } else { - signal = NewDownSignal(server, signal_map["Type"].(string)) - } - - id , ok := p.Args["id"].(string) - if ok == false { - return nil, errors.New("Failed to cast arg id to string") - } - - owner := server.Owner() - if owner == nil { - return nil, errors.New("Cannot send update without owner") - } - - root_event, ok := owner.(Event) - if ok == false { - return nil, errors.New("Cannot send update to Event unless owned by an Event") - } - - node := FindChild(root_event, id) - if node == nil { - return nil, errors.New("Failed to find id in event tree from server") - } - - SendUpdate(node, signal) - return signal, nil - }, - } - } - - return gql_mutation_update_event -} - -type GQLServer struct { - BaseResource - abort chan error - listen string - gql_channel chan error - extended_types map[reflect.Type]*graphql.Object - extended_queries map[string]*graphql.Field - extended_subscriptions map[string]*graphql.Field - extended_mutations map[string]*graphql.Field -} - -type ObjTypeMap map[reflect.Type]*graphql.Object -type FieldMap map[string]*graphql.Field - -func NewGQLServer(listen string, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) * GQLServer { - server := &GQLServer{ - BaseResource: NewBaseResource("GQL Server", "graphql server for event signals"), - listen: listen, - abort: make(chan error, 1), - gql_channel: make(chan error, 1), - extended_types: extended_types, - extended_queries: extended_queries, - extended_mutations: extended_mutations, - extended_subscriptions: extended_subscriptions, - } - - go func() { - Log.Logf("gql", "GOROUTINE_START for %s", server.ID()) - - mux := http.NewServeMux() - http_handler, ws_handler := MakeGQLHandlers(server) - mux.HandleFunc("/gql", http_handler) - mux.HandleFunc("/gqlws", ws_handler) - mux.HandleFunc("/graphiql", GraphiQLHandler()) - fs := http.FileServer(http.Dir("./site")) - mux.Handle("/site/", http.StripPrefix("/site", fs)) - - srv := &http.Server{ - Addr: server.listen, - Handler: mux, - } - - http_done := &sync.WaitGroup{} - http_done.Add(1) - go func(srv *http.Server, http_done *sync.WaitGroup) { - defer http_done.Done() - err := srv.ListenAndServe() - if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) - } - }(srv, http_done) - - for true { - select { - case signal:=<-server.signal: - if signal.Type() == "abort" || signal.Type() == "cancel" { - err := srv.Shutdown(context.Background()) - if err != nil{ - panic(fmt.Sprintf("Failed to shutdown gql server: %s", err)) - } - http_done.Wait() - break - } - Log.Logf("gql", "GOROUTINE_SIGNAL for %s: %+v", server.ID(), signal) - // Take signals to resource and send to GQL subscriptions - } - } - }() - - return server -} - -func (server * GQLServer) PropagateUpdate(signal GraphSignal) { - server.signal <- signal - server.BaseResource.PropagateUpdate(signal) -} - -func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) { - return GQLSubscribeFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error) { - return signal, nil - }) -} - -func GQLSubscribeFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLServer) - if ok == false { - return nil, fmt.Errorf("Failed to get gql_Server from context and cast to GQLServer") - } - - c := make(chan interface{}) - go func(c chan interface{}, server *GQLServer) { - sig_c := server.UpdateChannel() - for { - val, ok := <- sig_c - if ok == false { - return - } - ret, err := fn(val, p) - if err != nil { - Log.Logf("gqlws", "type convertor error %s", err) - return - } - c <- ret - } - }(c, server) - return c, nil -} - -var gql_subscription_update * graphql.Field = nil -func GQLSubscriptionUpdate() * graphql.Field { - if gql_subscription_update == nil { - gql_subscription_update = &graphql.Field{ - Type: GQLTypeSignal(), - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return p.Source, nil - }, - Subscribe: GQLSubscribeSignal, - } - } - - return gql_subscription_update -} - -func MakeGQLHandlers(server * GQLServer) (func(http.ResponseWriter, *http.Request), func(http.ResponseWriter, *http.Request)) { - valid_events := map[reflect.Type]*graphql.Object{} - valid_events[reflect.TypeOf((*BaseEvent)(nil))] = GQLTypeBaseEvent() - valid_events[reflect.TypeOf((*EventQueue)(nil))] = GQLTypeEventQueue() - - valid_resources := map[reflect.Type]*graphql.Object{} - valid_resources[reflect.TypeOf((*BaseResource)(nil))] = GQLTypeBaseResource() - valid_resources[reflect.TypeOf((*GQLServer)(nil))] = GQLTypeGQLServer() - - gql_types := []graphql.Type{GQLTypeBaseEvent(), GQLTypeEventQueue(), GQLTypeSignal(), GQLTypeSignalInput(), GQLTypeBaseNode(), GQLTypeGQLServer(), GQLTypeBaseResource()} - event_type := reflect.TypeOf((*Event)(nil)).Elem() - resource_type := reflect.TypeOf((*Resource)(nil)).Elem() - for go_t, gql_t := range(server.extended_types) { - if go_t.Implements(event_type) { - valid_events[go_t] = gql_t - } else if go_t.Implements(resource_type) { - valid_resources[go_t] = gql_t - } - gql_types = append(gql_types, gql_t) - } - - gql_queries := graphql.Fields{ - "Owner": GQLQueryOwner(), - } - - for key, value := range(server.extended_queries) { - gql_queries[key] = value - } - - gql_subscriptions := graphql.Fields{ - "Updates": GQLSubscriptionUpdate(), - } - - for key, value := range(server.extended_subscriptions) { - gql_subscriptions[key] = value - } - - gql_mutations := graphql.Fields{ - "updateEvent": GQLMutationUpdateEvent(), - } - - for key, value := range(server.extended_mutations) { - gql_mutations[key] = value - } - - schemaConfig := graphql.SchemaConfig{ - Types: gql_types, - Query: graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: gql_queries, - }), - Mutation: graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: gql_mutations, - }), - Subscription: graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: gql_subscriptions, - }), - } - - schema, err := graphql.NewSchema(schemaConfig) - if err != nil{ - panic(err) - } - ctx := context.Background() - ctx = context.WithValue(ctx, "valid_events", valid_events) - ctx = context.WithValue(ctx, "valid_resources", valid_resources) - ctx = context.WithValue(ctx, "gql_server", server) - return GQLHandler(schema, ctx), GQLWSHandler(schema, ctx) -} - -var gql_query_owner *graphql.Field = nil -func GQLQueryOwner() *graphql.Field { - if gql_query_owner == nil { - gql_query_owner = &graphql.Field{ - Type: GQLInterfaceEvent(), - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLServer) - - if ok == false { - panic("Failed to get/cast gql_server from context") - } - - return server.Owner(), nil - }, - } - } - - return gql_query_owner -} diff --git a/graph.go b/graph.go index 9c25bf9..fd64154 100644 --- a/graph.go +++ b/graph.go @@ -2,31 +2,53 @@ package graphvent import ( "sync" + "reflect" "github.com/google/uuid" - "time" "os" "github.com/rs/zerolog" "fmt" - "encoding/json" + badger "github.com/dgraph-io/badger/v3" ) +type GraphContext struct { + DB * badger.DB + Log Logger +} + +func NewGraphContext(db * badger.DB, log Logger) * GraphContext { + return &GraphContext{DB: db, Log: log} +} + +// A Logger is passed around to record events happening to components enabled by SetComponents type Logger interface { - Init() error + SetComponents(components []string) error + // Log a formatted string Logf(component string, format string, items ... interface{}) + // Log a map of attributes and a format string Logm(component string, fields map[string]interface{}, format string, items ... interface{}) } -type DefaultLogger struct { - init_lock sync.Mutex - Loggers map[string]zerolog.Logger - Components []string +func NewConsoleLogger(components []string) *ConsoleLogger { + logger := &ConsoleLogger{ + loggers: map[string]zerolog.Logger{}, + components: []string{}, + } + + logger.SetComponents(components) + + return logger } -var Log DefaultLogger = DefaultLogger{Loggers: map[string]zerolog.Logger{}, Components: []string{}} +// A ConsoleLogger logs to stdout +type ConsoleLogger struct { + loggers map[string]zerolog.Logger + components_lock sync.Mutex + components []string +} -func (logger * DefaultLogger) Init(components []string) error { - logger.init_lock.Lock() - defer logger.init_lock.Unlock() +func (logger * ConsoleLogger) SetComponents(components []string) error { + logger.components_lock.Lock() + defer logger.components_lock.Unlock() component_enabled := func (component string) bool { for _, c := range(components) { @@ -37,23 +59,23 @@ func (logger * DefaultLogger) Init(components []string) error { return false } - for c, _ := range(logger.Loggers) { + for c, _ := range(logger.loggers) { if component_enabled(c) == false { - delete(logger.Loggers, c) + delete(logger.loggers, c) } } for _, c := range(components) { - _, exists := logger.Loggers[c] + _, exists := logger.loggers[c] if component_enabled(c) == true && exists == false { - logger.Loggers[c] = zerolog.New(os.Stdout).With().Timestamp().Str("component", c).Logger() + logger.loggers[c] = zerolog.New(os.Stdout).With().Timestamp().Str("component", c).Logger() } } return nil } -func (logger * DefaultLogger) Logm(component string, fields map[string]interface{}, format string, items ... interface{}) { - l, exists := logger.Loggers[component] +func (logger * ConsoleLogger) Logm(component string, fields map[string]interface{}, format string, items ... interface{}) { + l, exists := logger.loggers[component] if exists == true { log := l.Log() for key, value := range(fields) { @@ -63,37 +85,48 @@ func (logger * DefaultLogger) Logm(component string, fields map[string]interface } } -func (logger * DefaultLogger) Logf(component string, format string, items ... interface{}) { - l, exists := logger.Loggers[component] +func (logger * ConsoleLogger) Logf(component string, format string, items ... interface{}) { + l, exists := logger.loggers[component] if exists == true { l.Log().Msg(fmt.Sprintf(format, items...)) } } -// Generate a random graphql id -func randid() string{ +type NodeID string +// Generate a random id +func RandID() NodeID { uuid_str := uuid.New().String() - return uuid_str + return NodeID(uuid_str) } +type SignalDirection int +const ( + Up SignalDirection = iota + Down + Direct +) + +// GraphSignals are passed around the event tree/resource DAG and cast by Type() type GraphSignal interface { - Downwards() bool - Source() string + // How to propogate the signal + Direction() SignalDirection + Source() NodeID Type() string String() string } +// BaseSignal is the most basic type of signal, it has no additional data type BaseSignal struct { - downwards bool - source string + direction SignalDirection + source NodeID _type string } -func (signal BaseSignal) Downwards() bool { - return signal.downwards +func (signal BaseSignal) Direction() SignalDirection { + return signal.direction } -func (signal BaseSignal) Source() string { +func (signal BaseSignal) Source() NodeID { return signal.source } @@ -102,22 +135,17 @@ func (signal BaseSignal) Type() string { } func (signal BaseSignal) String() string { - return fmt.Sprintf("{downwards: %t, source: %s, type: %s}", signal.downwards, signal.source, signal._type) -} - -type TimeSignal struct { - BaseSignal - time time.Time + return fmt.Sprintf("{direction: %d, source: %s, type: %s}", signal.direction, signal.source, signal._type) } -func NewBaseSignal(source GraphNode, _type string, downwards bool) BaseSignal { - source_id := "" +func NewBaseSignal(source GraphNode, _type string, direction SignalDirection) BaseSignal { + var source_id NodeID = "" if source != nil { source_id = source.ID() } signal := BaseSignal{ - downwards: downwards, + direction: direction, source: source_id, _type: _type, } @@ -125,127 +153,198 @@ func NewBaseSignal(source GraphNode, _type string, downwards bool) BaseSignal { } func NewDownSignal(source GraphNode, _type string) BaseSignal { - return NewBaseSignal(source, _type, true) + return NewBaseSignal(source, _type, Down) } func NewSignal(source GraphNode, _type string) BaseSignal { - return NewBaseSignal(source, _type, false) -} - -type NodeState interface { - Name() string - Description() string - DelegationMap() map[string]GraphNode + return NewBaseSignal(source, _type, Up) } -type BaseNodeState struct { - name string - description string - delegation_map map[string]GraphNode +func NewDirectSignal(source GraphNode, _type string) BaseSignal { + return NewBaseSignal(source, _type, Direct) } -func (state * BaseNodeState) Name() string { - return state.name +func NewAbortSignal(source GraphNode) BaseSignal { + return NewBaseSignal(source, "abort", Down) } -func (state * BaseNodeState) Description() string { - return state.description +func NewCancelSignal(source GraphNode) BaseSignal { + return NewBaseSignal(source, "cancel", Down) } -func (state * BaseNodeState) DelegationMap() map[string]GraphNode { - return state.delegation_map +type NodeState interface { + Serialize() []byte + OriginalLockHolder(id NodeID) GraphNode + AllowedToTakeLock(id NodeID) bool + RecordLockHolder(id NodeID, lock_holder GraphNode) NodeState } - // GraphNode is the interface common to both DAG nodes and Event tree nodes +// They have a NodeState interface which is saved to the database every update type GraphNode interface { + ID() NodeID + State() NodeState - ID() string - UpdateListeners(update GraphSignal) - PropagateUpdate(update GraphSignal) + StateLock() *sync.RWMutex + + SetState(new_state NodeState) + DeserializeState([]byte) NodeState + + // Signal propagation function for listener channels + UpdateListeners(ctx * GraphContext, update GraphSignal) + // Signal propagation function for connected nodes(defined in state) + PropagateUpdate(ctx * GraphContext, update GraphSignal) + + // Register and unregister a channel to propogate updates to RegisterChannel(listener chan GraphSignal) UnregisterChannel(listener chan GraphSignal) + // Get a handle to the nodes internal signal channel SignalChannel() chan GraphSignal } -func (node * BaseNode) StateLock() *sync.Mutex { - return &node.state_lock -} - -func NewBaseNode(id string) BaseNode { +// Create a new base node with the given ID +func NewNode(ctx * GraphContext, id NodeID, state NodeState) BaseNode { node := BaseNode{ id: id, signal: make(chan GraphSignal, 512), listeners: map[chan GraphSignal]chan GraphSignal{}, + state: state, + } + + err := WriteDBState(ctx, id, state) + if err != nil { + panic(fmt.Sprintf("DB_NEW_WRITE_ERROR: %s", err)) } - Log.Logf("graph", "NEW_NODE: %s", node.ID()) + + ctx.Log.Logf("graph", "NEW_NODE: %s - %+v", id, state) return node } -// BaseNode is the most basic implementation of the GraphNode interface -// It is used to implement functions common to Events and Resources +// BaseNode is the minimum set of fields needed to implement a GraphNode, +// and provides a template for more complicated Nodes type BaseNode struct { - id string + id NodeID + state NodeState state_lock sync.RWMutex + signal chan GraphSignal + listeners_lock sync.Mutex listeners map[chan GraphSignal]chan GraphSignal } -func (node * BaseNode) SignalChannel() chan GraphSignal { - return node.signal +func (node * BaseNode) ID() NodeID { + return node.id } func (node * BaseNode) State() NodeState { return node.state } -func (node * BaseNode) ID() string { - return node.id +func (node * BaseNode) StateLock() * sync.RWMutex { + return &node.state_lock } -const listener_buffer = 100 -func GetUpdateChannel(node * BaseNode) chan GraphSignal { - new_listener := make(chan GraphSignal, listener_buffer) - node.RegisterChannel(new_listener) - return new_listener +func (node * BaseNode) DeserializeState([]byte) NodeState { + return nil } -func (node * BaseNode) RegisterChannel(listener chan GraphSignal) { - node.listeners_lock.Lock() - _, exists := node.listeners[listener] - if exists == false { - node.listeners[listener] = listener +func WriteDBState(ctx * GraphContext, id NodeID, state NodeState) error { + ctx.Log.Logf("db", "DB_WRITE: %s - %+v", id, state) + + var serialized_state []byte = nil + if state != nil { + serialized_state = state.Serialize() + } else { + serialized_state = []byte{} } - node.listeners_lock.Unlock() + + err := ctx.DB.Update(func(txn *badger.Txn) error { + err := txn.Set([]byte(id), serialized_state) + return err + }) + + return err } -func (node * BaseNode) UnregisterChannel(listener chan GraphSignal) { - node.listeners_lock.Lock() - _, exists := node.listeners[listener] - if exists == false { - panic("Attempting to unregister non-registered listener") - } else { - delete(node.listeners, listener) +func (node * BaseNode) SetState(new_state NodeState) { + node.state = new_state +} + +func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)(interface{}, error)) (interface{}, error) { + for _, node := range(nodes) { + node.StateLock().RLock() } - node.listeners_lock.Unlock() + + states := make([]NodeState, len(nodes)) + for i, node := range(nodes) { + states[i] = node.State() + } + + val, err := states_fn(states) + + for _, node := range(nodes) { + node.StateLock().RUnlock() + } + + return val, err } -func (node * BaseNode) PropagateUpdate(update GraphSignal) { +func UpdateStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)([]NodeState, interface{}, error)) (interface{}, error) { + for _, node := range(nodes) { + node.StateLock().Lock() + } + + states := make([]NodeState, len(nodes)) + for i, node := range(nodes) { + states[i] = node.State() + } + + new_states, val, err := states_fn(states) + + if new_states != nil { + if len(new_states) != len(nodes) { + panic(fmt.Sprintf("NODE_NEW_STATE_LEN_MISMATCH: %d/%d", len(new_states), len(nodes))) + } + + for i, new_state := range(new_states) { + if new_state != nil { + old_state_type := reflect.TypeOf(states[i]) + new_state_type := reflect.TypeOf(new_state) + + if old_state_type != new_state_type { + panic(fmt.Sprintf("NODE_STATE_MISMATCH: old - %+v, new - %+v", old_state_type, new_state_type)) + } + + err := WriteDBState(ctx, nodes[i].ID(), new_state) + if err != nil { + panic(fmt.Sprintf("DB_WRITE_ERROR: %s", err)) + } + + nodes[i].SetState(new_state) + } + } + } + + for _, node := range(nodes) { + node.StateLock().Unlock() + } + + return val, err } -func (node * BaseNode) UpdateListeners(update GraphSignal) { - node.ListenersLock.Lock() - defer node.ListenersLock.Unlock() - closed := []chan GraphSignal +func (node * BaseNode) UpdateListeners(ctx * GraphContext, update GraphSignal) { + node.listeners_lock.Lock() + defer node.listeners_lock.Unlock() + closed := []chan GraphSignal{} - for _, listener := range node.Listeners() { - Log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.ID(), listener) + for _, listener := range node.listeners { + ctx.Log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.ID(), listener) select { - case listener <- signal: + case listener <- update: default: - Log.Logf("listeners", "CLOSED_LISTENER %s: %p", node.ID(), listener) + ctx.Log.Logf("listeners", "CLOSED_LISTENER %s: %p", node.ID(), listener) go func(node GraphNode, listener chan GraphSignal) { listener <- NewSignal(node, "listener_closed") close(listener) @@ -259,34 +358,49 @@ func (node * BaseNode) UpdateListeners(update GraphSignal) { } } -func SendUpdate(node GraphNode, signal GraphSignal) { - node_name := "nil" - if node != nil { - node_name = node.Name() - } - Log.Logf("update", "UPDATE %s <- %s: %+v", node_name, signal.Source(), signal) - node.ListenersLock.Lock() - defer node.ListenersLock.Unlock() - closed := []chan GraphSignal +func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal) { +} - for _, listener := range node.Listeners() { - Log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.ID(), listener) - select { - case listener <- signal: - default: - Log.Logf("listeners", "CLOSED_LISTENER %s: %p", node.ID(), listener) - go func(node GraphNode, listener chan GraphSignal) { - listener <- NewSignal(node, "listener_closed") - close(listener) - }(node, listener) - closed = append(closed, listener) - } +func (node * BaseNode) RegisterChannel(listener chan GraphSignal) { + node.listeners_lock.Lock() + _, exists := node.listeners[listener] + if exists == false { + node.listeners[listener] = listener } + node.listeners_lock.Unlock() +} - for _, listener := range(closed) { +func (node * BaseNode) UnregisterChannel(listener chan GraphSignal) { + node.listeners_lock.Lock() + _, exists := node.listeners[listener] + if exists == false { + panic("Attempting to unregister non-registered listener") + } else { delete(node.listeners, listener) } + node.listeners_lock.Unlock() +} + +func (node * BaseNode) SignalChannel() chan GraphSignal { + return node.signal +} + +// Create a new GraphSinal channel with a buffer of size buffer and register it to a node +func GetUpdateChannel(node * BaseNode, buffer int) chan GraphSignal { + new_listener := make(chan GraphSignal, buffer) + node.RegisterChannel(new_listener) + return new_listener +} + +// Propogate a signal starting at a node +func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) { + if node == nil { + panic("Cannot start an update from no node") + } + + ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal) - node.PropagateUpdate(signal) + node.UpdateListeners(ctx, signal) + node.PropagateUpdate(ctx, signal) } diff --git a/graph_test.go b/graph_test.go index 35bfc2c..6faed47 100644 --- a/graph_test.go +++ b/graph_test.go @@ -2,23 +2,23 @@ package graphvent import ( "testing" - "time" "fmt" - "os" + "time" "runtime/pprof" + "os" badger "github.com/dgraph-io/badger/v3" ) type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForValue(listener chan GraphSignal, signal_type string, source GraphNode, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForValue(ctx * GraphContext, listener chan GraphSignal, signal_type string, source GraphNode, timeout time.Duration, str string) GraphSignal { timeout_channel := time.After(timeout) for true { select { case signal := <- listener: if signal.Type() == signal_type { - Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) + ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) if signal.Source() == source.ID() { return signal } @@ -32,18 +32,6 @@ func (t * GraphTester) WaitForValue(listener chan GraphSignal, signal_type strin return nil } -func (t * GraphTester) CheckForValue(listener chan GraphSignal, str string) GraphSignal { - timeout := time.After(listner_timeout) - select { - case signal := <- listener: - return signal - case <-timeout: - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - t.Fatal(str) - return nil - } -} - func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { timeout := time.After(listner_timeout) select { @@ -54,462 +42,26 @@ func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { } } -func TestNewEventWithResource(t *testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger1")) - if err != nil { - t.Fatal(err) - } - name := "Test Resource" - description := "A resource for testing" - children := []Resource{} - - test_resource, _ := NewResource(name, description, children) - root_event, err := NewEvent(db, "root_event", "", []Resource{test_resource}) - if err != nil { - t.Fatal(err) - } - - res := FindResource(root_event, test_resource.ID()) - if res == nil { - t.Fatal("Failed to find Resource in EventManager after adding") - } - - if res.Name() != name || res.Description() != description { - t.Fatal("Name/description of returned resource did not match added resource") - } -} - -func TestDoubleResourceAdd(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger2")) - if err != nil { - t.Fatal(err) - } - test_resource, _ := NewResource("", "", []Resource{}) - _, err = NewEvent(db, "", "", []Resource{test_resource, test_resource}) - - if err == nil { - t.Fatal("NewEvent didn't return an error") - } -} - -func TestTieredResource(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger3")) - if err != nil { - t.Fatal(err) - } - - r1, _ := NewResource("r1", "", []Resource{}) - r2, err := NewResource("r2", "", []Resource{r1}) - if err != nil { - t.Fatal(err) - } - _, err = NewEvent(db, "", "", []Resource{r2}) - - if err != nil { - t.Fatal("Failed to create event with tiered resources") - } -} - -func TestResourceUpdate(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger4")) - if err != nil { - t.Fatal(err) - } - - r1, err := NewResource("r1", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - r2, err := NewResource("r2", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - r3, err := NewResource("r3", "", []Resource{r1, r2}) - if err != nil { - t.Fatal(err) - } - r4, err := NewResource("r4", "", []Resource{r3}) - if err != nil { - t.Fatal(err) - } - - _, err = NewEvent(db, "", "", []Resource{r3, r4}) - if err != nil { - t.Fatal("Failed to add initial tiered resources for test") - } - - r1_l := r1.UpdateChannel() - r2_l := r2.UpdateChannel() - r3_l := r3.UpdateChannel() - r4_l := r4.UpdateChannel() - - // Calling Update() on the parent with no other parents should only notify node listeners - SendUpdate(r3, NewSignal(nil, "test")) - (*GraphTester)(t).CheckForNone(r1_l, "Update on r1 after updating r3") - (*GraphTester)(t).CheckForNone(r2_l, "Update on r2 after updating r3") - (*GraphTester)(t).CheckForValue(r3_l, "No update on r3 after updating r3") - (*GraphTester)(t).CheckForValue(r4_l, "No update on r4 after updating r3") - - // Calling Update() on a child should notify listeners of the parent and child, but not siblings - SendUpdate(r2, NewSignal(nil, "test")) - (*GraphTester)(t).CheckForNone(r1_l, "Update on r1 after updating r2") - (*GraphTester)(t).CheckForValue(r2_l, "No update on r2 after updating r2") - (*GraphTester)(t).CheckForValue(r3_l, "No update on r3 after updating r2") - (*GraphTester)(t).CheckForValue(r4_l, "No update on r4 after updating r2") - - // Calling Update() on a child should notify listeners of the parent and child, but not siblings - SendUpdate(r1, NewSignal(nil, "test")) - (*GraphTester)(t).CheckForValue(r1_l, "No update on r1 after updating r1") - (*GraphTester)(t).CheckForNone(r2_l, "Update on r2 after updating r1") - (*GraphTester)(t).CheckForValue(r3_l, "No update on r3 after updating r1") - (*GraphTester)(t).CheckForValue(r4_l, "No update on r4 after updating r1") -} - -func TestAddEvent(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger5")) - if err != nil { - t.Fatal(err) - } - r1, _ := NewResource("r1", "", []Resource{}) - r2, _ := NewResource("r2", "", []Resource{r1}) - root_event, _ := NewEvent(db, "", "", []Resource{r2}) - - name := "Test Event" - description := "A test event" - resources := []Resource{r2} - new_event, _ := NewEvent(db, name, description, resources) - - err = LinkEvent(root_event, new_event, nil) - if err != nil { - t.Fatalf("Failed to add new_event to root_event: %s", err) - } - - res := FindChild(root_event, new_event.ID()) - if res == nil { - t.Fatalf("Failed to find new_event in event_manager: %s", err) - } - - if res.Name() != name || res.Description() != description { - t.Fatal("Event found in event_manager didn't match added") - } - - res_required := res.Resources() - if len(res_required) < 1 { - t.Fatal("Event found in event_manager didn't match added") - } else if res_required[0].ID() != r2.ID() { - t.Fatal("Event found in event_manager didn't match added") - } -} - -func TestLockResource(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger6")) - if err != nil { - t.Fatal(err) - } - r1, err := NewResource("r1", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - r2, err := NewResource("r2", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - r3, err := NewResource("r3", "", []Resource{r1, r2}) - if err != nil { - t.Fatal(err) - } - r4, err := NewResource("r4", "", []Resource{r1, r2}) - if err != nil { - t.Fatal(err) - } - root_event, err := NewEvent(db, "", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - test_event, err := NewEvent(db, "", "", []Resource{}) - if err != nil { - t.Fatal(err) - } - - r1_l := r1.UpdateChannel() - rel := root_event.UpdateChannel() - - err = LockResource(r3, root_event) - if err != nil { - t.Fatal("Failed to lock r3") - } - SendUpdate(r3, NewDownSignal(r3, "locked")) - - (*GraphTester)(t).WaitForValue(r1_l, "locked", r3, time.Second, "Wasn't notified of r1 lock on r1 after r3 lock") - (*GraphTester)(t).WaitForValue(rel, "locked", r3, time.Second, "Wasn't notified of r1 lock on rel after r3 lock") - - err = LockResource(r3, root_event) - if err == nil { - t.Fatal("Locked r3 after locking r3") - } - - err = LockResource(r4, root_event) - if err == nil { - t.Fatal("Locked r4 after locking r3") - } - - err = LockResource(r1, root_event) - if err == nil { - t.Fatal("Locked r1 after locking r3") - } - - err = UnlockResource(r3, test_event) - if err == nil { - t.Fatal("Unlocked r3 with event that didn't lock it") - } - - err = UnlockResource(r3, root_event) - if err != nil { - t.Fatal("Failed to unlock r3") - } - SendUpdate(r3, NewDownSignal(r3, "unlocked")) - (*GraphTester)(t).WaitForValue(r1_l, "unlocked", r3, time.Second * 2, "Wasn't notified of r1 unlock on r1 after r3 unlock") - - err = LockResource(r4, root_event) - if err != nil { - t.Fatal("Failed to lock r4 after unlocking r3") - } - SendUpdate(r4, NewDownSignal(r4, "locked")) - (*GraphTester)(t).WaitForValue(r1_l, "locked", r4, time.Second * 2, "Wasn't notified of r1 lock on r1 after r4 lock") - (*GraphTester)(t).WaitForValue(rel, "locked", r4, time.Second * 2, "Wasn't notified of r1 lock on r1 after r4 lock") - - err = UnlockResource(r4, root_event) - if err != nil { - t.Fatal("Failed to unlock r4") - } - SendUpdate(r4, NewDownSignal(r4, "unlocked")) - (*GraphTester)(t).WaitForValue(r1_l, "unlocked", r4, time.Second * 2, "Wasn't notified of r1 unlock on r1 after r4 lock") -} - -func TestAddToEventQueue(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger7")) - if err != nil { - t.Fatal(err) - } - queue, _ := NewEventQueue("q", "", []Resource{}) - event_1, _ := NewEvent(db, "1", "", []Resource{}) - event_2, _ := NewEvent(db, "2", "", []Resource{}) - - err = LinkEvent(queue, event_1, nil) - if err == nil { - t.Fatal("suceeded in added nil info to queue") - } - - err = LinkEvent(queue, event_1, &EventQueueInfo{priority: 0}) - if err != nil { - t.Fatal("failed to add valid event + info to queue") - } - - err = LinkEvent(queue, event_2, &EventQueueInfo{priority: 1}) - if err != nil { - t.Fatal("failed to add valid event + info to queue") - } -} - -func TestStartBaseEvent(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger8")) - if err != nil { - t.Fatal(err) - } - event_1, _ := NewEvent(db, "TestStartBaseEvent event_1", "", []Resource{}) - r := event_1.DoneResource() - - e_l := event_1.UpdateChannel() - r_l := r.UpdateChannel() - (*GraphTester)(t).CheckForNone(e_l, "Update on event_1 before starting") - (*GraphTester)(t).CheckForNone(r_l, "Update on r_1 before starting") - - if r.Owner().ID() != event_1.ID() { - t.Fatal("r is not owned by event_1") - } - - err = LockResources(event_1) - if err != nil { - t.Fatal(err) - } - - err = RunEvent(event_1) - if err != nil { - t.Fatal(err) - } - // Check that the update channels for the event and resource have updates - (*GraphTester)(t).WaitForValue(e_l, "event_start", event_1, 1*time.Second, "No event_start on e_l") - (*GraphTester)(t).WaitForValue(e_l, "event_done", event_1, 1*time.Second, "No event_start on e_l") - (*GraphTester)(t).WaitForValue(r_l, "unlocked", event_1, 1*time.Second, "No unlocked on r_l") - - if r.Owner() != nil { - t.Fatal("r still owned after event completed") - } -} - -func TestAbortEventQueue(t * testing.T) { - r1, _ := NewResource("r1", "", []Resource{}) - root_event, _ := NewEventQueue("root_event", "", []Resource{}) - r := root_event.DoneResource() - - LockResource(r1, root_event) - - // Now that the event is constructed with a queue and 3 basic events - // start the queue and check that all the events are executed - go func() { - time.Sleep(100 * time.Millisecond) - abort_signal := NewDownSignal(nil, "abort") - SendUpdate(root_event, abort_signal) - }() - - err := LockResources(root_event) - if err != nil { - t.Fatal(err) - } - err = RunEvent(root_event) - if err == nil { - t.Fatal("root_event completed without error") - } - - if r.Owner() == nil { - t.Fatal("root event was finished after starting") - } -} - -func TestDelegateLock(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger")) - if err != nil { - t.Fatal(err) - } - test_resource, _ := NewResource("test_resource", "", []Resource{}) - root_event, _ := NewEventQueue("root_event", "", []Resource{test_resource}) - test_event, _ := NewEvent(db, "test_event", "", []Resource{test_resource}) - err = LinkEvent(root_event, test_event, NewEventQueueInfo(1)) +func logTestContext(t * testing.T, components []string) * GraphContext { + db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) if err != nil { t.Fatal(err) } - err = LockResources(root_event) - if err != nil { - t.Fatal(err) - } - - test_listener := test_event.UpdateChannel() - - go func() { - (*GraphTester)(t).WaitForValue(test_listener, "event_done", test_event, 250 * time.Millisecond, "No event_done for test_event") - if test_resource.Owner().ID() != root_event.ID() { - t.Fatal("Lock failed to pass back to root_event") - } - abort_signal := NewDownSignal(nil, "cancel") - SendUpdate(root_event, abort_signal) - }() - - err = RunEvent(root_event) - if err != nil { - t.Fatal(err) - } + return NewGraphContext(db, NewConsoleLogger(components)) } -func TestStartWithoutLocking(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger9")) +func testContext(t * testing.T) * GraphContext { + db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true)) if err != nil { t.Fatal(err) } - test_resource, _ := NewResource("test_resource", "", []Resource{}) - root_event, _ := NewEvent(db, "root_event", "", []Resource{test_resource}) - err = RunEvent(root_event) - if err == nil { - t.Fatal("Event ran without error without locking resources") - } + return NewGraphContext(db, NewConsoleLogger([]string{})) } -func TestStartEventQueue(t * testing.T) { - db, err := badger.Open(badger.DefaultOptions("/tmp/badger10")) - if err != nil { - t.Fatal(err) - } - root_event, _ := NewEventQueue("root_event", "", []Resource{}) - r := root_event.DoneResource() - rel := root_event.UpdateChannel(); - res_1, _ := NewResource("test_resource_1", "", []Resource{}) - res_2, _ := NewResource("test_resource_2", "", []Resource{}) - - - e1, _ := NewEvent(db, "e1", "", []Resource{res_1, res_2}) - e1_l := e1.UpdateChannel() - e1_r := e1.DoneResource() - e1_info := NewEventQueueInfo(1) - err = LinkEvent(root_event, e1, e1_info) - if err != nil { - t.Fatal("Failed to add e1 to root_event") - } - (*GraphTester)(t).WaitForValue(rel, "child_added", root_event, time.Second, "No update on root_event after adding e1") - - e2, _ := NewEvent(db, "e2", "", []Resource{res_1}) - e2_l := e2.UpdateChannel() - e2_r := e2.DoneResource() - e2_info := NewEventQueueInfo(2) - err = LinkEvent(root_event, e2, e2_info) - if err != nil { - t.Fatal("Failed to add e2 to root_event") - } - (*GraphTester)(t).WaitForValue(rel, "child_added", root_event, time.Second, "No update on root_event after adding e2") - - e3, _ := NewEvent(db, "e3", "", []Resource{res_2}) - e3_l := e3.UpdateChannel() - e3_r := e3.DoneResource() - e3_info := NewEventQueueInfo(3) - err = LinkEvent(root_event, e3, e3_info) - if err != nil { - t.Fatal("Failed to add e3 to root_event") - } - (*GraphTester)(t).WaitForValue(rel, "child_added", root_event, time.Second, "No update on root_event after adding e3") - - // Abort the event after 5 seconds just in case - go func() { - time.Sleep(5 * time.Second) - if r.Owner() != nil { - abort_signal := NewDownSignal(nil, "abort") - SendUpdate(root_event, abort_signal) - } - }() - - // Now that a root_event is constructed with a queue and 3 basic events - // start the queue and check that all the events are executed - go func() { - (*GraphTester)(t).WaitForValue(e1_l, "event_done", e1, time.Second, "No event_done for e3") - (*GraphTester)(t).WaitForValue(e2_l, "event_done", e2, time.Second, "No event_done for e3") - (*GraphTester)(t).WaitForValue(e3_l, "event_done", e3, time.Second, "No event_done for e3") - signal := NewDownSignal(nil, "cancel") - SendUpdate(root_event, signal) - }() - - err = LockResources(root_event) +func fatalErr(t * testing.T, err error) { if err != nil { t.Fatal(err) } - - err = RunEvent(root_event) - if err != nil { - t.Fatal(err) - } - - if r.Owner() != nil { - t.Fatal("root event was not finished after starting") - } - - if e1_r.Owner() != nil { - t.Fatal(fmt.Sprintf("e1 was not completed: %s", e1_r.Owner())) - } - - if e2_r.Owner() != nil { - t.Fatal(fmt.Sprintf("e2 was not completed")) - } - - if e3_r.Owner() != nil { - t.Fatal("e3 was not completed") - } } - diff --git a/resource.go b/resource.go index 745f778..e4f7585 100644 --- a/resource.go +++ b/resource.go @@ -2,264 +2,327 @@ package graphvent import ( "fmt" - "sync" - "errors" ) -// Resources propagate update up to multiple parents, and not downwards -// (subscriber to team won't get update to alliance, but subscriber to alliance will get update to team) -func (resource * BaseResource) PropagateUpdate(signal GraphSignal) { - - if signal.Downwards() == false { - // Child->Parent, resource updates parent resources - resource.connection_lock.Lock() - defer resource.connection_lock.Unlock() - for _, parent := range resource.Parents() { - SendUpdate(parent, signal) - } - } else { - // Parent->Child, resource updates lock holder - resource.lock_holder_lock.Lock() - defer resource.lock_holder_lock.Unlock() - if resource.lock_holder != nil { - SendUpdate(resource.lock_holder, signal) +// Link a resource with a child +func LinkResource(ctx * GraphContext, resource Resource, child Resource) error { + if resource == nil || child == nil { + return fmt.Errorf("Will not connect nil to DAG") + } + _, err := UpdateStates(ctx, []GraphNode{resource, child}, func(states []NodeState) ([]NodeState, interface{}, error) { + resource_state := states[0].(ResourceState) + child_state := states[1].(ResourceState) + + if checkIfChild(ctx, resource_state, resource.ID(), child_state, child.ID()) == true { + return nil, nil, fmt.Errorf("RESOURCE_LINK_ERR: %s is a parent of %s so cannot link as child", child.ID(), resource.ID()) } - resource.connection_lock.Lock() - defer resource.connection_lock.Unlock() - for _, child := range(resource.children) { - SendUpdate(child, signal) + resource_state.children = append(resource_state.children, child) + child_state.parents = append(child_state.parents, resource) + return []NodeState{resource_state, child_state}, nil, nil + }) + return err +} + +// Link multiple children to a resource +func LinkResources(ctx * GraphContext, resource Resource, children []Resource) error { + if resource == nil || children == nil { + return fmt.Errorf("Invalid input") + } + + found := map[NodeID]bool{} + child_nodes := make([]GraphNode, len(children)) + for i, child := range(children) { + if child == nil { + return fmt.Errorf("Will not connect nil to DAG") + } + _, exists := found[child.ID()] + if exists == true { + return fmt.Errorf("Will not connect the same child twice") } + found[child.ID()] = true + child_nodes[i] = child } -} -// Resource is the interface that DAG nodes are made from -// A resource needs to be able to represent Logical entities and connections to physical entities. -// A resource lock could be aborted at any time if this connection is broken, if that happens the event locking it must be aborted -// The device connection should be maintained as much as possible(requiring some reconnection behaviour in the background) -type Resource interface { - GraphNode - Owner() GraphNode - Children() []Resource - Parents() []Resource + _, err := UpdateStates(ctx, append([]GraphNode{resource}, child_nodes...), func(states []NodeState) ([]NodeState, interface{}, error) { + resource_state := states[0].(ResourceState) - AddParent(parent Resource) - AddChild(child Resource) - LockConnections() - UnlockConnections() + new_states := make([]ResourceState, len(states)) + for i, state := range(states) { + new_states[i] = state.(ResourceState) + } - SetOwner(owner GraphNode) - LockState() - UnlockState() + for i, state := range(states[1:]) { + child_state := state.(ResourceState) - String() string + if checkIfChild(ctx, resource_state, resource.ID(), child_state, children[i].ID()) == true { + return nil, nil, fmt.Errorf("RESOURCES_LINK_ERR: %s is a parent of %s so cannot link as child", children[i].ID() , resource.ID()) + } - lock(node GraphNode) error - unlock(node GraphNode) error -} + new_states[0].children = append(new_states[0].children, children[i]) + new_states[i+1].parents = append(new_states[i+1].parents, resource) + } + ret_states := make([]NodeState, len(states)) + for i, state := range(new_states) { + ret_states[i] = state + } + return ret_states, nil, nil + }) -func (resource * BaseResource) String() string { - return resource.Name() + return err } -// Recurse up cur's parents to ensure r is not present -func checkIfParent(r Resource, cur Resource) bool { - if r == nil || cur == nil { - panic("Cannot recurse DAG with nil") - } +type ResourceState struct { + name string + owner GraphNode + children []Resource + parents []Resource +} - if r.ID() == cur.ID() { - return true - } +func (state ResourceState) Serialize() []byte { + return []byte(state.name) +} - cur.LockConnections() - defer cur.UnlockConnections() - for _, p := range(cur.Parents()) { - if checkIfParent(r, p) == true { - return true - } - } +// Locks cannot be passed between resources, so the answer to +// "who used to own this lock held by a resource" is always "nobody" +func (state ResourceState) OriginalLockHolder(id NodeID) GraphNode { + return nil +} +// Nothing can take a lock from a resource +func (state ResourceState) AllowedToTakeLock(id NodeID) bool { return false } -// Recurse doen cur's children to ensure r is not present -func checkIfChild(r Resource, cur Resource) bool { - if r == nil || cur == nil { - panic("Cannot recurse DAG with nil") +func (state ResourceState) RecordLockHolder(id NodeID, lock_holder GraphNode) NodeState { + if lock_holder != nil { + panic("Attempted to delegate a lock to a resource") } - if r.ID() == cur.ID() { - return true - } + return state +} - cur.LockConnections() - defer cur.UnlockConnections() - for _, c := range(cur.Children()) { - if checkIfChild(r, c) == true { - return true - } +func NewResourceState(name string) ResourceState { + return ResourceState{ + name: name, + owner: nil, + children: []Resource{}, + parents: []Resource{}, } - - return false } -func UnlockResource(resource Resource, node GraphNode) error { - var err error = nil - resource.LockState() - defer resource.UnlockState() - if resource.Owner() == nil { - return errors.New("Resource already unlocked") - } +// Resource represents a Node which can be locked by another node, +// and needs to own all it's childrens locks before being locked. +// Resource connections form a directed acyclic graph +// Resources do not allow any other nodes to take locks from them +type Resource interface { + GraphNode - if resource.Owner().ID() != node.ID() { - return errors.New("Resource not locked by parent, unlock failed") - } + // Called when locking the node to allow for custom lock behaviour + Lock(node GraphNode, state NodeState) (NodeState, error) + // Called when unlocking the node to allow for custom lock behaviour + Unlock(node GraphNode, state NodeState) (NodeState, error) +} - var lock_err error = nil - for _, child := range resource.Children() { - err := UnlockResource(child, node) - if err != nil { - lock_err = err - break +// Resources propagate update up to multiple parents, and not downwards +// (subscriber to team won't get update to alliance, but subscriber to alliance will get update to team) +func (resource * BaseResource) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { + UseStates(ctx, []GraphNode{resource}, func(states []NodeState) (interface{}, error){ + resource_state := states[0].(ResourceState) + if signal.Direction() == Up { + // Child->Parent, resource updates parent resources + for _, parent := range resource_state.parents { + SendUpdate(ctx, parent, signal) + } + } else if signal.Direction() == Down { + // Parent->Child, resource updates lock holder + if resource_state.owner != nil { + SendUpdate(ctx, resource_state.owner, signal) + } + + for _, child := range(resource_state.children) { + SendUpdate(ctx, child, signal) + } + } else if signal.Direction() == Direct { + } else { + panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) } - } - - if lock_err != nil { - return fmt.Errorf("Resource failed to unlock: %s", lock_err) - } - - resource.SetOwner(node.Delegator(resource.ID())) + return nil, nil + }) +} - err = resource.unlock(node) - if err != nil { - return errors.New("Failed to unlock resource") +func checkIfChild(ctx * GraphContext, r ResourceState, r_id NodeID, cur ResourceState, cur_id NodeID) bool { + if r_id == cur_id { + return true } - return nil -} + for _, c := range(cur.children) { + val, _ := UseStates(ctx, []GraphNode{c}, func(states []NodeState) (interface{}, error) { + child_state := states[0].(ResourceState) + return checkIfChild(ctx, cur, cur_id, child_state, c.ID()), nil + }) -func isAllowedToTakeLock(node GraphNode, current_owner GraphNode) bool { - for _, allowed := range(current_owner.Allowed()) { - if allowed.ID() == node.ID() { + is_child := val.(bool) + if is_child { return true } } + return false } -func LockResource(resource Resource, node GraphNode) error { - resource.LockState() - defer resource.UnlockState() +func UnlockResource(ctx * GraphContext, resource Resource, node GraphNode, node_state NodeState) (NodeState, error) { + if node == nil || resource == nil{ + panic("Cannot unlock without a specified node and resource") + } + _, err := UpdateStates(ctx, []GraphNode{resource}, func(states []NodeState) ([]NodeState, interface{}, error) { + if resource.ID() == node.ID() { + if node_state != nil { + panic("node_state must be nil if unlocking resource from itself") + } + node_state = states[0] + } + resource_state := states[0].(ResourceState) - if resource.Owner() != nil { - // Check if node is allowed to take a lock from resource.Owner() - if isAllowedToTakeLock(node, resource.Owner()) == false { - return fmt.Errorf("%s is not allowed to take a lock from %s, allowed: %+v", node.Name(), resource.Owner().Name(), resource.Owner().Allowed()) + if resource_state.owner == nil { + return nil, nil, fmt.Errorf("Resource already unlocked") } - } - err := resource.lock(node) - if err != nil { - return fmt.Errorf("Failed to lock resource: %s", err) - } + if resource_state.owner.ID() != node.ID() { + return nil, nil, fmt.Errorf("Resource %s not locked by %s", resource.ID(), node.ID()) + } - var lock_err error = nil - locked_resources := []Resource{} - for _, child := range resource.Children() { - err := LockResource(child, node) - if err != nil{ - lock_err = err - break + var lock_err error = nil + for _, child := range(resource_state.children) { + var err error = nil + node_state, err = UnlockResource(ctx, child, node, node_state) + if err != nil { + lock_err = err + break + } } - locked_resources = append(locked_resources, child) - } - if lock_err != nil { - return fmt.Errorf("Resource failed to lock: %s", lock_err) - } + if lock_err != nil { + return nil, nil, fmt.Errorf("Resource %s failed to unlock: %e", resource.ID(), lock_err) + } - Log.Logf("resource", "Locked %s", resource.Name()) - node.TakeLock(resource) - resource.SetOwner(node) + resource_state.owner = node_state.OriginalLockHolder(resource.ID()) + unlock_state, err := resource.Unlock(node, resource_state) + resource_state = unlock_state.(ResourceState) + if err != nil { + return nil, nil, fmt.Errorf("Resource %s failed custom Unlock: %e", resource.ID(), err) + } - return nil -} + if resource_state.owner == nil { + ctx.Log.Logf("resource", "RESOURCE_UNLOCK: %s unlocked %s", node.ID(), resource.ID()) + } else { + ctx.Log.Logf("resource", "RESOURCE_UNLOCK: %s passed lock of %s back to %s", node.ID(), resource.ID(), resource_state.owner.ID()) + } -// BaseResource is the most basic resource that can exist in the DAG -// It holds a single state variable, which contains a pointer to the event that is locking it -type BaseResource struct { - BaseNode - parents []Resource - children []Resource - connection_lock sync.Mutex - lock_holder GraphNode - lock_holder_lock sync.Mutex - state_lock sync.Mutex -} + return []NodeState{resource_state}, nil, nil + }) -func (resource * BaseResource) SetOwner(owner GraphNode) { - resource.lock_holder_lock.Lock() - resource.lock_holder = owner - resource.lock_holder_lock.Unlock() -} + if err != nil { + return nil, err + } -func (resource * BaseResource) LockState() { - resource.state_lock.Lock() + return node_state, nil } -func (resource * BaseResource) UnlockState() { - resource.state_lock.Unlock() -} +// TODO: State +func LockResource(ctx * GraphContext, resource Resource, node GraphNode, node_state NodeState) (NodeState, error) { + if node == nil || resource == nil { + panic("Cannot lock without a specified node and resource") + } -func (resource * BaseResource) Owner() GraphNode { - return resource.lock_holder -} + _, err := UpdateStates(ctx, []GraphNode{resource}, func(states []NodeState) ([]NodeState, interface{}, error) { + if resource.ID() == node.ID() { + if node_state != nil { + panic("node_state must be nil if locking resource from itself") + } + node_state = states[0] + } + resource_state := states[0].(ResourceState) + if resource_state.owner != nil { + var lock_pass_allowed bool = false + + if resource_state.owner.ID() == resource.ID() { + lock_pass_allowed = resource_state.AllowedToTakeLock(node.ID()) + } else { + tmp, _ := UseStates(ctx, []GraphNode{resource_state.owner}, func(states []NodeState)(interface{}, error){ + return states[0].AllowedToTakeLock(node.ID()), nil + }) + lock_pass_allowed = tmp.(bool) + } + + + if lock_pass_allowed == false { + return nil, nil, fmt.Errorf("%s is not allowed to take a lock from %s", node.ID(), resource_state.owner.ID()) + } + } -//BaseResources don't check anything special when locking/unlocking -func (resource * BaseResource) lock(node GraphNode) error { - return nil -} + lock_state, err := resource.Lock(node, resource_state) + if err != nil { + return nil, nil, fmt.Errorf("Failed to lock resource: %e", err) + } -func (resource * BaseResource) unlock(node GraphNode) error { - return nil -} + resource_state = lock_state.(ResourceState) + + var lock_err error = nil + locked_resources := []Resource{} + for _, child := range(resource_state.children) { + node_state, err = LockResource(ctx, child, node, node_state) + if err != nil { + lock_err = err + break + } + locked_resources = append(locked_resources, child) + } -func (resource * BaseResource) Children() []Resource { - return resource.children -} + if lock_err != nil { + for _, locked_resource := range(locked_resources) { + node_state, err = UnlockResource(ctx, locked_resource, node, node_state) + if err != nil { + panic(err) + } + } + return nil, nil, fmt.Errorf("Resource failed to lock: %e", lock_err) + } -func (resource * BaseResource) Parents() []Resource { - return resource.parents -} + old_owner := resource_state.owner + resource_state.owner = node + node_state = node_state.RecordLockHolder(node.ID(), old_owner) -func (resource * BaseResource) LockConnections() { - resource.connection_lock.Lock() -} + if old_owner == nil { + ctx.Log.Logf("resource", "RESOURCE_LOCK: %s locked %s", node.ID(), resource.ID()) + } else { + ctx.Log.Logf("resource", "RESOURCE_LOCK: %s took lock of %s from %s", node.ID(), resource.ID(), old_owner.ID()) + } -func (resource * BaseResource) UnlockConnections() { - resource.connection_lock.Unlock() -} + return []NodeState{resource_state}, nil, nil + }) + if err != nil { + return nil, err + } -func (resource * BaseResource) AddParent(parent Resource) { - resource.parents = append(resource.parents, parent) + return node_state, nil } -func (resource * BaseResource) AddChild(child Resource) { - resource.children = append(resource.children, child) +// BaseResources represent simple resources in the DAG that can be used to create a hierarchy of locks that store names +type BaseResource struct { + BaseNode } -func NewBaseResource(name string, description string) BaseResource { - resource := BaseResource{ - BaseNode: NewBaseNode(name, description, randid()), - parents: []Resource{}, - children: []Resource{}, - } +//BaseResources don't check anything special when locking/unlocking +func (resource * BaseResource) Lock(node GraphNode, state NodeState) (NodeState, error) { + return state, nil +} - return resource +func (resource * BaseResource) Unlock(node GraphNode, state NodeState) (NodeState, error) { + return state, nil } -func FindResource(root Event, id string) Resource { +/*func FindResource(root Event, id string) Resource { if root == nil || id == ""{ panic("invalid input") } @@ -276,48 +339,17 @@ func FindResource(root Event, id string) Resource { } } return nil -} +}*/ -func LinkResource(resource Resource, child Resource) error { - if child == nil || resource == nil { - return fmt.Errorf("Will not connect nil to resource DAG") - } else if child.ID() == resource.ID() { - return fmt.Errorf("Will not connect resource to itself") +func NewResource(ctx * GraphContext, name string, children []Resource) (* BaseResource, error) { + resource := &BaseResource{ + BaseNode: NewNode(ctx, RandID(), NewResourceState(name)), } - if checkIfChild(resource, child) { - return fmt.Errorf("%s is a child of %s, cannot add as parent", resource.Name(), child.Name()) - } - - for _, p := range(resource.Parents()) { - if checkIfParent(child, p) { - return fmt.Errorf("Will not add %s as a parent of itself", child.Name()) - } - } - - child.AddParent(resource) - resource.AddChild(child) - return nil -} - -func LinkResources(resource Resource, children []Resource) error { - for _, c := range(children) { - err := LinkResource(resource, c) - if err != nil { - return err - } - } - return nil -} - -func NewResource(name string, description string, children []Resource) (* BaseResource, error) { - resource := NewBaseResource(name, description) - resource_ptr := &resource - - err := LinkResources(resource_ptr, children) + err := LinkResources(ctx, resource, children) if err != nil { return nil, err } - return resource_ptr, nil + return resource, nil } diff --git a/resource_test.go b/resource_test.go new file mode 100644 index 0000000..98923cc --- /dev/null +++ b/resource_test.go @@ -0,0 +1,226 @@ +package graphvent + +import ( + "testing" + "fmt" +) + +func TestNewResource(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + _, err = NewResource(ctx, "Test resource 2", []Resource{r1}) + fatalErr(t, err) +} + +func TestRepeatedChildResource(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + _, err = NewResource(ctx, "Test resource 2", []Resource{r1, r1}) + if err == nil { + t.Fatal("Added the same resource as a child twice to the same resource") + } +} + +func TestResourceSelfLock(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + _, err = LockResource(ctx, r1, r1, nil) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner_id := states[0].(ResourceState).owner.ID() + if owner_id != r1.ID() { + return nil, fmt.Errorf("r1 is owned by %s instead of self", owner_id) + } + return nil, nil + }) + fatalErr(t, err) + + _, err = UnlockResource(ctx, r1, r1, nil) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner := states[0].(ResourceState).owner + if owner != nil { + return nil, fmt.Errorf("r1 is not unowned after unlock: %s", owner.ID()) + } + return nil, nil + }) + + fatalErr(t, err) +} + +func TestResourceSelfLockTiered(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + r2, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + r3, err := NewResource(ctx, "Test resource 3", []Resource{r1, r2}) + fatalErr(t, err) + + _, err = LockResource(ctx, r3, r3, nil) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1, r2}, func(states []NodeState) (interface{}, error) { + owner_1_id := states[0].(ResourceState).owner.ID() + if owner_1_id != r3.ID() { + return nil, fmt.Errorf("r1 is owned by %s instead of r3", owner_1_id) + } + + owner_2_id := states[1].(ResourceState).owner.ID() + if owner_2_id != r3.ID() { + return nil, fmt.Errorf("r2 is owned by %s instead of r3", owner_2_id) + } + return nil, nil + }) + fatalErr(t, err) + + _, err = UnlockResource(ctx, r3, r3, nil) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states []NodeState) (interface{}, error) { + owner_1 := states[0].(ResourceState).owner + if owner_1 != nil { + return nil, fmt.Errorf("r1 is not unowned after unlocking: %s", owner_1.ID()) + } + + owner_2 := states[1].(ResourceState).owner + if owner_2 != nil { + return nil, fmt.Errorf("r2 is not unowned after unlocking: %s", owner_2.ID()) + } + + owner_3 := states[2].(ResourceState).owner + if owner_3 != nil { + return nil, fmt.Errorf("r3 is not unowned after unlocking: %s", owner_3.ID()) + } + return nil, nil + }) + + fatalErr(t, err) +} + +func TestResourceLockOther(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + r2, err := NewResource(ctx, "Test resource 2", []Resource{}) + fatalErr(t, err) + + _, err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, interface{}, error) { + new_state, err := LockResource(ctx, r1, r2, states[0]) + fatalErr(t, err) + return []NodeState{new_state}, nil, nil + }) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner_id := states[0].(ResourceState).owner.ID() + if owner_id != r2.ID() { + return nil, fmt.Errorf("r1 is owned by %s instead of r2", owner_id) + } + + return nil, nil + }) + fatalErr(t, err) + + _, err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, interface{}, error) { + new_state, err := UnlockResource(ctx, r1, r2, states[0]) + fatalErr(t, err) + return []NodeState{new_state}, nil, nil + }) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner := states[0].(ResourceState).owner + if owner != nil { + return nil, fmt.Errorf("r1 is owned by %s instead of r2", owner.ID()) + } + + return nil, nil + }) + fatalErr(t, err) + +} + +func TestResourceLockSimpleConflict(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + r2, err := NewResource(ctx, "Test resource 2", []Resource{}) + fatalErr(t, err) + + _, err = LockResource(ctx, r1, r1, nil) + fatalErr(t, err) + + _, err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, interface{}, error) { + new_state, err := LockResource(ctx, r1, r2, states[0]) + if err == nil { + t.Fatal("r2 took r1's lock from itself") + } + + return []NodeState{new_state}, nil, nil + }) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner_id := states[0].(ResourceState).owner.ID() + if owner_id != r1.ID() { + return nil, fmt.Errorf("r1 is owned by %s instead of r1", owner_id) + } + + return nil, nil + }) + fatalErr(t, err) + + _, err = UnlockResource(ctx, r1, r1, nil) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (interface{}, error) { + owner := states[0].(ResourceState).owner + if owner != nil { + return nil, fmt.Errorf("r1 is owned by %s instead of r1", owner.ID()) + } + + return nil, nil + }) + fatalErr(t, err) + +} + +func TestResourceLockTieredConflict(t * testing.T) { + ctx := testContext(t) + + r1, err := NewResource(ctx, "Test resource 1", []Resource{}) + fatalErr(t, err) + + r2, err := NewResource(ctx, "Test resource 2", []Resource{r1}) + fatalErr(t, err) + + r3, err := NewResource(ctx, "Test resource 3", []Resource{r1}) + fatalErr(t, err) + + _, err = LockResource(ctx, r2, r2, nil) + fatalErr(t, err) + + _, err = LockResource(ctx, r3, r3, nil) + if err == nil { + t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked") + } +}