Moved ecdh to it's own extension

gql_cataclysm
noah metz 2023-07-29 00:28:44 -06:00
parent f87571edcf
commit 641bd8febe
11 changed files with 463 additions and 273 deletions

@ -53,6 +53,13 @@ var (
NodeNotFoundError = errors.New("Node not found in DB") NodeNotFoundError = errors.New("Node not found in DB")
) )
type SignalLoadFunc func(*Context, []byte) (Signal, error)
type SignalInfo struct {
Load SignalLoadFunc
Type SignalType
}
// Information about a registered extension // Information about a registered extension
type ExtensionInfo struct { type ExtensionInfo struct {
// Function used to load extensions of this type from the database // Function used to load extensions of this type from the database
@ -77,6 +84,8 @@ type Context struct {
Log Logger Log Logger
// Map between database extension hashes and the registered info // Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo Extensions map[uint64]ExtensionInfo
// Map between serialized signal hashes and the registered info
Signals map[uint64]SignalInfo
// Map between database type hashes and the registered info // Map between database type hashes and the registered info
Types map[uint64]*NodeInfo Types map[uint64]*NodeInfo
// Curve used for signature operations // Curve used for signature operations
@ -118,6 +127,24 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e
return nil return nil
} }
func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFunc) error {
if load_fn == nil {
return fmt.Errorf("def has no load function")
}
type_hash := Hash(signal_type)
_, exists := ctx.Signals[type_hash]
if exists == true {
return fmt.Errorf("Cannot register signal of type %s, type already exists in context", signal_type)
}
ctx.Signals[type_hash] = SignalInfo{
Load: load_fn,
Type: signal_type,
}
return nil
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context // Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error { func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error {
if load_fn == nil { if load_fn == nil {
@ -199,6 +226,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
Log: log, Log: log,
Extensions: map[uint64]ExtensionInfo{}, Extensions: map[uint64]ExtensionInfo{},
Types: map[uint64]*NodeInfo{}, Types: map[uint64]*NodeInfo{},
Signals: map[uint64]SignalInfo{},
Nodes: map[NodeID]*Node{}, Nodes: map[NodeID]*Node{},
ECDH: ecdh.P256(), ECDH: ecdh.P256(),
ECDSA: elliptic.P256(), ECDSA: elliptic.P256(),
@ -236,6 +264,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(StopSignalType, func(ctx *Context, data []byte) (Signal, error) {
return StopSignal, nil
})
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
if err != nil { if err != nil {
return nil, err return nil, err

@ -0,0 +1,203 @@
package graphvent
import (
"fmt"
"time"
"encoding/json"
"crypto/ecdsa"
"crypto/x509"
"crypto/ecdh"
)
type ECDHState struct {
ECKey *ecdh.PrivateKey
SharedSecret []byte
}
type ECDHStateJSON struct {
ECKey []byte `json:"ec_key"`
SharedSecret []byte `json:"shared_secret"`
}
func (state *ECDHState) MarshalJSON() ([]byte, error) {
var key_bytes []byte
var err error
if state.ECKey != nil {
key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey)
if err != nil {
return nil, err
}
}
return json.Marshal(&ECDHStateJSON{
ECKey: key_bytes,
SharedSecret: state.SharedSecret,
})
}
func (state *ECDHState) UnmarshalJSON(data []byte) error {
var j ECDHStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
state.SharedSecret = j.SharedSecret
if len(j.ECKey) == 0 {
state.ECKey = nil
} else {
tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey)
if err != nil {
return err
}
ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey)
if ok == false {
return fmt.Errorf("Parsed wrong key type from DB for ECDHState")
}
state.ECKey, err = ecdsa_key.ECDH()
if err != nil {
return err
}
}
return nil
}
type ECDHMap map[NodeID]ECDHState
func (m ECDHMap) MarshalJSON() ([]byte, error) {
tmp := map[string]ECDHState{}
for id, state := range(m) {
tmp[id.String()] = state
}
return json.Marshal(tmp)
}
type ECDHExt struct {
ECDHStates ECDHMap
}
func NewECDHExt() *ECDHExt {
return &ECDHExt{
ECDHStates: ECDHMap{},
}
}
func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} {
var zero T
field_func, ok := field_funcs[name]
if ok == false {
return fmt.Errorf("%s is not a field of %s", name, zero.Type())
}
return field_func(t)
}
func (ext *ECDHExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{
"ecdh_states": func(ext *ECDHExt) interface{} {
return ext.ECDHStates
},
})
}
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %s", source, node.ID, ser)
switch signal.State {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, resp)
} else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_req"})
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
shared_secret, err := state.ECKey.ECDH(signal.ECDH)
if err == nil {
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State)
}
}
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECHD_STATE: %s->%s - %s", source, node.ID, ser)
}
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal ECDHProxySignal) {
state, exists := ext.ECDHStates[source]
if exists == false {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_req"})
} else if state.SharedSecret == nil {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_shared"})
} else {
unwrapped_signal, err := ParseECDHProxySignal(ctx, &signal, state.SharedSecret)
if err != nil {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), err.Error()})
} else {
//TODO: Figure out what I was trying to do here and fix it
ctx.Send(signal.Source, signal.Dest, unwrapped_signal)
}
}
}
func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
switch signal.Direction() {
case Direct:
switch signal.Type() {
case ECDHProxySignalType:
ecdh_signal := signal.(ECDHProxySignal)
ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal)
case ECDHStateSignalType:
ecdh_signal := signal.(StateSignal)
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
case ECDHSignalType:
ecdh_signal := signal.(ECDHSignal)
ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
default:
}
default:
}
}
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ")
}
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var ext ECDHExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
}

@ -167,7 +167,7 @@ type ResolveContext struct {
GQLContext *GQLExtContext GQLContext *GQLExtContext
Server *Node Server *Node
Ext *GQLExt Ext *GQLExt
User *Node User NodeID
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
@ -181,16 +181,11 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
} }
user, exists := gql_ext.Users[auth_id]
if exists == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
}
return &ResolveContext{ return &ResolveContext{
Context: ctx, Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
Server: server, Server: server,
User: user, User: auth_id,
}, nil }, nil
} }
@ -481,6 +476,11 @@ func NewGQLInterface(if_name string, default_name string, interfaces []*graphql.
return &gql return &gql
} }
type GQLNode struct {
ID NodeID
Type NodeType
}
// GQL Specific Context information // GQL Specific Context information
type GQLExtContext struct { type GQLExtContext struct {
// Generated GQL schema // Generated GQL schema
@ -545,7 +545,7 @@ func NewGQLExtContext() *GQLExtContext {
}) })
query.AddFieldConfig("Self", GQLQuerySelf) query.AddFieldConfig("Self", GQLQuerySelf)
query.AddFieldConfig("User", GQLQueryUser) query.AddFieldConfig("Node", GQLQueryNode)
mutation := graphql.NewObject(graphql.ObjectConfig{ mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation", Name: "Mutation",
@ -600,9 +600,7 @@ type GQLExt struct {
tls_key []byte tls_key []byte
tls_cert []byte tls_cert []byte
Listen string Listen string
Users NodeMap
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
SubscribeLock sync.Mutex SubscribeLock sync.Mutex
SubscribeListeners []chan Signal SubscribeListeners []chan Signal
} }
@ -629,7 +627,6 @@ func (ext *GQLExt) Process(context *Context, princ_id NodeID, node *Node, signal
if signal.Type() == ReadResultSignalType { if signal.Type() == ReadResultSignalType {
} }
ext.SubscribeLock.Lock() ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock() defer ext.SubscribeLock.Unlock()
@ -655,22 +652,13 @@ func (ext *GQLExt) Type() ExtType {
type GQLExtJSON struct { type GQLExtJSON struct {
Listen string `json:"listen"` Listen string `json:"listen"`
Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"`
TLSKey []byte `json:"ssl_key"` TLSKey []byte `json:"ssl_key"`
TLSCert []byte `json:"ssl_cert"` TLSCert []byte `json:"ssl_cert"`
} }
func (ext *GQLExt) Serialize() ([]byte, error) { func (ext *GQLExt) Serialize() ([]byte, error) {
ser_key, err := x509.MarshalECPrivateKey(ext.Key)
if err != nil {
return nil, err
}
return json.MarshalIndent(&GQLExtJSON{ return json.MarshalIndent(&GQLExtJSON{
Listen: ext.Listen, Listen: ext.Listen,
Key: ser_key,
ECDH: ecdh_curve_ids[ext.ECDH],
TLSKey: ext.tls_key, TLSKey: ext.tls_key,
TLSCert: ext.tls_cert, TLSCert: ext.tls_cert,
}, "", " ") }, "", " ")
@ -699,22 +687,12 @@ func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
return nil, err return nil, err
} }
ecdh_curve, ok := ecdh_curves[j.ECDH] return NewGQLExt(ctx, j.Listen, j.TLSCert, j.TLSKey), nil
if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
}
key, err := x509.ParseECPrivateKey(j.Key)
if err != nil {
return nil, err
}
return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil
} }
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) *GQLExt {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) ssl_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -755,8 +733,6 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_
return &GQLExt{ return &GQLExt{
Listen: listen, Listen: listen,
SubscribeListeners: []chan Signal{}, SubscribeListeners: []chan Signal{},
Key: key,
ECDH: ecdh_curve,
tls_cert: tls_cert, tls_cert: tls_cert,
tls_key: tls_key, tls_key: tls_key,
} }

@ -3,26 +3,31 @@ import (
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
) )
var GQLQuerySelf = &graphql.Field{ var GQLQueryNode = &graphql.Field{
Type: GQLInterfaceNode.Default, Type: GQLInterfaceNode.Interface,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx.Context.Log.Logf("gql", "FieldASTs: %+v", p.Info.FieldASTs)
// Get a list of fields that will be written
// Send the read signal
// Wait for the response, returning an error on timeout
return ctx.Server, nil return nil, nil
}, },
} }
var GQLQueryUser = &graphql.Field{ var GQLQuerySelf = &graphql.Field{
Type: GQLInterfaceNode.Default, Type: GQLInterfaceNode.Default,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, ctx, err := PrepResolve(p) _, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ctx.User, nil return nil, nil
}, },
} }

@ -5,18 +5,13 @@ import (
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
) )
func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) { func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false { if ok == false {
return nil, nil, fmt.Errorf("Bad resolve in params context") return nil, fmt.Errorf("Bad resolve in params context")
} }
node, ok := p.Source.(*Node) return resolve_context, nil
if ok == false {
return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve")
}
return node, resolve_context, nil
} }
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does // TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does
@ -71,21 +66,11 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted // TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted
func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { func GQLNodeID(p graphql.ResolveParams) (interface{}, error) {
node, _, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
return node.ID, nil
} }
func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
node, _, err := PrepResolve(p) return nil, nil
if err != nil {
return nil, err
}
return string(node.Type), nil
} }
func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) { func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) {

@ -16,7 +16,7 @@ func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
_, ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -3,10 +3,6 @@ package graphvent
import ( import (
"testing" "testing"
"time" "time"
"crypto/rand"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
) )
func TestGQLDB(t * testing.T) { func TestGQLDB(t * testing.T) {
@ -19,10 +15,7 @@ func TestGQLDB(t * testing.T) {
ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) gql_ext := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err)
gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql := NewNode(ctx, nil, GQLNodeType, 10, nil,
gql_ext, gql_ext,

@ -218,7 +218,7 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else { } else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions) result := ReadNodeFields(ctx, node, source, read_signal.Extensions)
ctx.Send(node.ID, source, NewReadResultSignal(result)) ctx.Send(node.ID, source, NewReadResultSignal(node.Type, result))
} }
} }

@ -67,9 +67,12 @@ func TestECDH(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
n1_listener := NewListenerExt(10) n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+")}) ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")})
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener) n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt()) n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt())
n3_listener := NewListenerExt(10)
n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}})
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy, n3_policy), NewECDHExt(), n3_listener)
ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID) ctx.Log.Logf("test", "N2: %s", n2.ID)
@ -91,4 +94,11 @@ func TestECDH(t *testing.T) {
return sig.State == "resp" return sig.State == "resp"
}) })
fatalErr(t, err) fatalErr(t, err)
time.Sleep(10*time.Millisecond)
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, NewDirectSignal(StopSignalType), ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err)
err = ctx.Send(n1.ID, n2.ID, ecdh_sig)
fatalErr(t, err)
} }

@ -4,10 +4,13 @@ import (
"time" "time"
"fmt" "fmt"
"encoding/json" "encoding/json"
"encoding/binary"
"crypto/sha512" "crypto/sha512"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ecdh" "crypto/ecdh"
"crypto/rand" "crypto/rand"
"crypto/aes"
"crypto/cipher"
) )
type SignalDirection int type SignalDirection int
@ -20,6 +23,8 @@ const (
ReadResultSignalType = "READ_RESULT" ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START" LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH" ECDHSignalType = "ECDH"
ECDHStateSignalType = "ECDH_STATE"
ECDHProxySignalType = "ECDH_PROXY"
Up SignalDirection = iota Up SignalDirection = iota
Down Down
@ -27,9 +32,8 @@ const (
) )
type SignalType string type SignalType string
func (signal_type SignalType) String() string { func (signal_type SignalType) String() string { return string(signal_type) }
return string(signal_type) func (signal_type SignalType) Prefix() string { return "SIGNAL: " }
}
type Signal interface { type Signal interface {
Serializable[SignalType] Serializable[SignalType]
@ -49,7 +53,6 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
if signal.Type() == signal_type { if signal.Type() == signal_type {
sig, ok := signal.(S) sig, ok := signal.(S)
if ok == true { if ok == true {
ctx.Log.Logf("test", "received: %+v", sig)
if check(sig) == true { if check(sig) == true {
return sig, nil return sig, nil
} }
@ -80,12 +83,8 @@ func (signal BaseSignal) Direction() SignalDirection {
return signal.SignalDirection return signal.SignalDirection
} }
func (signal *BaseSignal) MarshalJSON() ([]byte, error) {
return json.Marshal(signal)
}
func (signal BaseSignal) Serialize() ([]byte, error) { func (signal BaseSignal) Serialize() ([]byte, error) {
return signal.MarshalJSON() return json.Marshal(signal)
} }
func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal {
@ -115,6 +114,10 @@ type IDSignal struct {
ID NodeID `json:"id"` ID NodeID `json:"id"`
} }
func (signal IDSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
func (signal IDSignal) String() string { func (signal IDSignal) String() string {
ser, err := json.Marshal(signal) ser, err := json.Marshal(signal)
if err != nil { if err != nil {
@ -135,10 +138,18 @@ type StateSignal struct {
State string `json:"state"` State string `json:"state"`
} }
func (signal StateSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
type IDStateSignal struct { type IDStateSignal struct {
BaseSignal BaseSignal
ID NodeID `json:"id"` ID NodeID `json:"id"`
State string `json:"status"` State string `json:"state"`
}
func (signal IDStateSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
} }
func (signal IDStateSignal) String() string { func (signal IDStateSignal) String() string {
@ -157,15 +168,6 @@ func NewStatusSignal(status string, source NodeID) IDStateSignal {
} }
} }
func (signal StateSignal) Serialize() ([]byte, error) {
return json.MarshalIndent(signal, "", " ")
}
func (signal StateSignal) String() string {
ser, _ := signal.Serialize()
return string(ser)
}
func NewLinkSignal(state string) StateSignal { func NewLinkSignal(state string) StateSignal {
return StateSignal{ return StateSignal{
BaseSignal: NewDirectSignal(LinkSignalType), BaseSignal: NewDirectSignal(LinkSignalType),
@ -201,6 +203,10 @@ type ReadSignal struct {
Extensions map[ExtType][]string `json:"extensions"` Extensions map[ExtType][]string `json:"extensions"`
} }
func (signal ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
func NewReadSignal(exts map[ExtType][]string) ReadSignal { func NewReadSignal(exts map[ExtType][]string) ReadSignal {
return ReadSignal{ return ReadSignal{
BaseSignal: NewDirectSignal(ReadSignalType), BaseSignal: NewDirectSignal(ReadSignalType),
@ -210,12 +216,14 @@ func NewReadSignal(exts map[ExtType][]string) ReadSignal {
type ReadResultSignal struct { type ReadResultSignal struct {
BaseSignal BaseSignal
NodeType NodeType
Extensions map[ExtType]map[string]interface{} `json:"extensions"` Extensions map[ExtType]map[string]interface{} `json:"extensions"`
} }
func NewReadResultSignal(exts map[ExtType]map[string]interface{}) ReadResultSignal { func NewReadResultSignal(node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal {
return ReadResultSignal{ return ReadResultSignal{
BaseSignal: NewDirectSignal(ReadResultSignalType), BaseSignal: NewDirectSignal(ReadResultSignalType),
NodeType: node_type,
Extensions: exts, Extensions: exts,
} }
} }
@ -228,6 +236,28 @@ type ECDHSignal struct {
Signature []byte Signature []byte
} }
type ECDHSignalJSON struct {
StateSignal
Time time.Time `json:"time"`
ECDSA []byte `json:"ecdsa_pubkey"`
ECDH []byte `json:"ecdh_pubkey"`
Signature []byte `json:"signature"`
}
func (signal *ECDHSignal) MarshalJSON() ([]byte, error) {
return json.Marshal(&ECDHSignalJSON{
StateSignal: signal.StateSignal,
Time: signal.Time,
ECDH: signal.ECDH.Bytes(),
ECDSA: signal.ECDH.Bytes(),
Signature: signal.Signature,
})
}
func (signal ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) {
time_bytes, err := now.MarshalJSON() time_bytes, err := now.MarshalJSON()
if err != nil { if err != nil {
@ -335,3 +365,138 @@ func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error
return nil return nil
} }
type ECDHProxySignal struct {
BaseSignal
Source NodeID
Dest NodeID
IV []byte
Data []byte
}
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (ECDHProxySignal, error) {
if shared_secret == nil {
return ECDHProxySignal{}, fmt.Errorf("need shared_secret")
}
aes_key, err := aes.NewCipher(shared_secret[:32])
if err != nil {
return ECDHProxySignal{}, err
}
ser, err := SerializeSignal(signal, aes_key.BlockSize())
if err != nil {
return ECDHProxySignal{}, err
}
iv := make([]byte, aes_key.BlockSize())
n, err := rand.Reader.Read(iv)
if err != nil {
return ECDHProxySignal{}, err
} else if n != len(iv) {
return ECDHProxySignal{}, fmt.Errorf("Not enough bytes read for IV")
}
encrypter := cipher.NewCBCEncrypter(aes_key, iv)
encrypter.CryptBlocks(ser, ser)
return ECDHProxySignal{
BaseSignal: NewDirectSignal(ECDHProxySignalType),
Source: source,
Dest: dest,
IV: iv,
Data: ser,
}, nil
}
type SignalHeader struct {
Magic uint32
TypeHash uint64
Length uint64
}
const SIGNAL_SER_MAGIC uint32 = 0x753a64de
const SIGNAL_SER_HEADER_LENGTH = 20
func SerializeSignal(signal Signal, block_size int) ([]byte, error) {
signal_ser, err := signal.Serialize()
if err != nil {
return nil, err
}
pad_req := 0
if block_size > 0 {
pad := block_size - ((SIGNAL_SER_HEADER_LENGTH + len(signal_ser)) % block_size)
if pad != block_size {
pad_req = pad
}
}
header := SignalHeader{
Magic: SIGNAL_SER_MAGIC,
TypeHash: Hash(signal.Type()),
Length: uint64(len(signal_ser) + pad_req),
}
ser := make([]byte, SIGNAL_SER_HEADER_LENGTH + len(signal_ser) + pad_req)
binary.BigEndian.PutUint32(ser[0:4], header.Magic)
binary.BigEndian.PutUint64(ser[4:12], header.TypeHash)
binary.BigEndian.PutUint64(ser[12:20], header.Length)
copy(ser[SIGNAL_SER_HEADER_LENGTH:], signal_ser)
return ser, nil
}
func ParseSignal(ctx *Context, data []byte) (Signal, error) {
if len(data) < SIGNAL_SER_HEADER_LENGTH {
return nil, fmt.Errorf("data shorter than header length")
}
header := SignalHeader{
Magic: binary.BigEndian.Uint32(data[0:4]),
TypeHash: binary.BigEndian.Uint64(data[4:12]),
Length: binary.BigEndian.Uint64(data[12:20]),
}
if header.Magic != SIGNAL_SER_MAGIC {
return nil, fmt.Errorf("signal magic mismatch 0x%x", header.Magic)
}
left := len(data) - SIGNAL_SER_HEADER_LENGTH
if int(header.Length) != left {
return nil, fmt.Errorf("signal length mismatch %d/%d", header.Length, left)
}
signal_def, exists := ctx.Signals[header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", header.TypeHash)
}
signal, err := signal_def.Load(ctx, data[SIGNAL_SER_HEADER_LENGTH:])
if err != nil {
return nil, err
}
return signal, nil
}
func ParseECDHProxySignal(ctx *Context, signal *ECDHProxySignal, shared_secret []byte) (Signal, error) {
if shared_secret == nil {
return nil, fmt.Errorf("need shared_secret")
}
aes_key, err := aes.NewCipher(shared_secret[:32])
if err != nil {
return nil, err
}
decrypter := cipher.NewCBCDecrypter(aes_key, signal.IV)
decrypted := make([]byte, len(signal.Data))
decrypter.CryptBlocks(decrypted, signal.Data)
wrapped_signal, err := ParseSignal(ctx, decrypted)
if err != nil {
return nil, err
}
return wrapped_signal, nil
}

@ -1,188 +1,9 @@
package graphvent package graphvent
import ( import (
"fmt"
"time"
"encoding/json" "encoding/json"
"crypto/ecdsa"
"crypto/x509"
"crypto/ecdh"
) )
type ECDHState struct {
ECKey *ecdh.PrivateKey
SharedSecret []byte
}
type ECDHStateJSON struct {
ECKey []byte `json:"ec_key"`
SharedSecret []byte `json:"shared_secret"`
}
func (state *ECDHState) MarshalJSON() ([]byte, error) {
var key_bytes []byte
var err error
if state.ECKey != nil {
key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey)
if err != nil {
return nil, err
}
}
return json.Marshal(&ECDHStateJSON{
ECKey: key_bytes,
SharedSecret: state.SharedSecret,
})
}
func (state *ECDHState) UnmarshalJSON(data []byte) error {
var j ECDHStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
state.SharedSecret = j.SharedSecret
if len(j.ECKey) == 0 {
state.ECKey = nil
} else {
tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey)
if err != nil {
return err
}
ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey)
if ok == false {
return fmt.Errorf("Parsed wrong key type from DB for ECDHState")
}
state.ECKey, err = ecdsa_key.ECDH()
if err != nil {
return err
}
}
return nil
}
type ECDHMap map[NodeID]ECDHState
func (m ECDHMap) MarshalJSON() ([]byte, error) {
tmp := map[string]ECDHState{}
for id, state := range(m) {
tmp[id.String()] = state
}
return json.Marshal(tmp)
}
type ECDHExt struct {
ECDHStates ECDHMap
}
func NewECDHExt() *ECDHExt {
return &ECDHExt{
ECDHStates: ECDHMap{},
}
}
func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} {
var zero T
field_func, ok := field_funcs[name]
if ok == false {
return fmt.Errorf("%s is not a field of %s", name, zero.Type())
}
return field_func(t)
}
func (ext *ECDHExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{
"ecdh_states": func(ext *ECDHExt) interface{} {
return ext.ECDHStates
},
})
}
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) {
ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %+v", source, node, signal)
switch signal.State {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, resp)
} else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHSignalType), "no_req"})
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
shared_secret, err := state.ECKey.ECDH(signal.ECDH)
if err == nil {
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State)
}
}
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) {
}
func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
switch signal.Direction() {
case Direct:
switch signal.Type() {
case ECDHSignalType:
switch ecdh_signal := signal.(type) {
case ECDHSignal:
ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
case StateSignal:
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
default:
ctx.Log.Logf("ecdh", "BAD_SIGNAL_CAST: %+v", signal)
}
default:
}
default:
}
}
func (ext *ECDHExt) Type() ExtType {
return ECDHExtType
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ")
}
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var ext ECDHExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
}
type GroupExt struct { type GroupExt struct {
Members map[NodeID]string Members map[NodeID]string
} }