diff --git a/lockable.go b/lockable.go index c1c8239..28f03db 100644 --- a/lockable.go +++ b/lockable.go @@ -21,8 +21,10 @@ type LockableState interface { Requirements() []Lockable AddRequirement(requirement Lockable) + RemoveRequirement(requirement Lockable) Dependencies() []Lockable AddDependency(dependency Lockable) + RemoveDependency(dependency Lockable) Owner() Lockable SetOwner(owner Lockable) } @@ -156,6 +158,72 @@ func (state * BaseLockableState) AddDependency(dependency Lockable) { state.dependencies = append(state.dependencies, dependency) } +func (state * BaseLockableState) RemoveDependency(dependency Lockable) { + idx := -1 + + for i, dep := range(state.dependencies) { + if dep.ID() == dependency.ID() { + idx = i + break + } + } + + if idx == -1 { + panic(fmt.Sprintf("%s is not a dependency of %s", dependency.ID(), state.Name())) + } + + dep_len := len(state.dependencies) + state.dependencies[idx] = state.dependencies[dep_len-1] + state.dependencies = state.dependencies[0:(dep_len-1)] +} + +func (state * BaseLockableState) RemoveRequirement(requirement Lockable) { + idx := -1 + for i, req := range(state.requirements) { + if req.ID() == requirement.ID() { + idx = i + break + } + } + + if idx == -1 { + panic(fmt.Sprintf("%s is not a requirement of %s", requirement.ID(), state.Name())) + } + + req_len := len(state.requirements) + state.requirements[idx] = state.requirements[req_len-1] + state.requirements = state.requirements[0:(req_len-1)] +} + +func UnlinkLockables(ctx * GraphContext, lockable Lockable, requirement Lockable) error { + // Check if requirement is a requirement of lockable + err := UpdateStates(ctx, []GraphNode{lockable}, func(nodes NodeMap) error{ + state := lockable.State().(LockableState) + var found GraphNode = nil + for _, req := range(state.Requirements()) { + if requirement.ID() == req.ID() { + found = req + break + } + } + + if found == nil { + return fmt.Errorf("UNLINK_LOCKABLES_ERR: %s is not a requirement of %s", requirement.ID(), lockable.ID()) + } + + err := UpdateMoreStates(ctx, []GraphNode{found}, nodes, func(nodes NodeMap) error { + req_state := found.State().(LockableState) + req_state.RemoveDependency(lockable) + state.RemoveRequirement(requirement) + return nil + }) + + return err + }) + + return err +} + func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockable) error { if lockable == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") diff --git a/lockable_test.go b/lockable_test.go index 9a3a51d..9db7dac 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -396,3 +396,15 @@ func TestLockableDBLoad(t * testing.T){ return err }) } + +func TestLockableUnlink(t * testing.T){ + ctx := logTestContext(t, []string{"lockable"}) + l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + l2, err := NewSimpleLockable(ctx, "Test Lockable 2", []Lockable{l1}) + fatalErr(t, err) + + err = UnlinkLockables(ctx, l2, l1) + fatalErr(t, err) +} diff --git a/thread.go b/thread.go index 4e1554f..80caeb0 100644 --- a/thread.go +++ b/thread.go @@ -56,6 +56,7 @@ type ThreadState interface { Child(id NodeID) Thread ChildInfo(child NodeID) ThreadInfo AddChild(child Thread, info ThreadInfo) error + RemoveChild(child Thread) Start() error Stop() error @@ -260,6 +261,50 @@ func (state * BaseThreadState) ChildInfo(child NodeID) ThreadInfo { return state.child_info[child] } +func UnlinkThreads(ctx * GraphContext, thread Thread, child Thread) error { + err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) error{ + state := thread.State().(ThreadState) + var found GraphNode = nil + for _, c := range(state.Children()) { + if child.ID() == c.ID() { + found = c + break + } + } + + if found == nil { + return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) + } + + err := UpdateMoreStates(ctx, []GraphNode{found}, nodes, func(nodes NodeMap) error { + child_state := found.State().(ThreadState) + child_state.SetParent(nil) + state.RemoveChild(child) + return nil + }) + return err + }) + return err +} + +func (state * BaseThreadState) RemoveChild(child Thread) { + idx := -1 + for i, c := range(state.children) { + if c.ID() == child.ID() { + idx = i + break + } + } + + if idx == -1 { + panic(fmt.Sprintf("%s is not a child of %s", child.ID(), state.Name())) + } + + child_len := len(state.children) + state.children[idx] = state.children[child_len-1] + state.children = state.children[0:child_len-1] +} + func (state * BaseThreadState) AddChild(child Thread, info ThreadInfo) error { if child == nil { return fmt.Errorf("Will not connect nil to the thread tree") diff --git a/thread_test.go b/thread_test.go index 929d03c..3a9d589 100644 --- a/thread_test.go +++ b/thread_test.go @@ -86,3 +86,18 @@ func TestThreadDBLoad(t * testing.T) { return err }) } + +func TestThreadUnlink(t * testing.T) { + ctx := logTestContext(t, []string{}) + t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{}, BaseThreadActions, BaseThreadHandlers) + fatalErr(t, err) + + t2, err := NewSimpleThread(ctx, "Test Thread 2", []Lockable{}, BaseThreadActions, BaseThreadHandlers) + fatalErr(t, err) + + err = LinkThreads(ctx, t1, t2, nil) + fatalErr(t, err) + + err = UnlinkThreads(ctx, t1, t2) + fatalErr(t, err) +}