diff --git a/packet.go b/packet.go index d3d830f..d58edec 100644 --- a/packet.go +++ b/packet.go @@ -245,7 +245,7 @@ func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) { func NewSessionClose(session *Session) []byte { packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH) - packet[0] = COMMAND_LENGTH + packet[0] = byte(SESSION_CLOSE) copy(packet[1:], session.ID[:]) hmac := sha512.Sum512(append(session.ID[:], session.secret...)) diff --git a/packet_test.go b/packet_test.go index d37b979..6bd91ad 100644 --- a/packet_test.go +++ b/packet_test.go @@ -25,6 +25,9 @@ func TestSessionOpen(t *testing.T) { client_so, client_ecdh, err := NewSessionOpen(client_key) fatalErr(t, err) + if client_so[0] != byte(SESSION_OPEN) { + t.Fatalf("Session open command byte mismatch(%x != %x)", client_so[0], SESSION_OPEN) + } server_so, server_ecdh, err := NewSessionOpen(server_key) fatalErr(t, err) @@ -71,7 +74,14 @@ func TestSessionConnect(t *testing.T) { test_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080") fatalErr(t, err) + session_id := ID[SessionID](secret) session_connect := NewSessionConnect(test_addr, secret) + if session_connect[0] != byte(SESSION_CONNECT) { + t.Fatalf("Session open command byte mismatch(%x != %x)", session_connect[0], SESSION_CONNECT) + } else if slices.Compare(session_connect[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], session_id[:]) != 0 { + t.Fatal("Session open ID mismatch") + } + parsed_addr, err := ParseSessionConnect(session_connect[COMMAND_LENGTH + ID_LENGTH:], secret) fatalErr(t, err) @@ -100,10 +110,16 @@ func TestSessionData(t *testing.T) { fatalErr(t, err) message := []byte("hello") - server_hello, err := NewSessionData(&server_session, message) + session_data, err := NewSessionData(&server_session, message) fatalErr(t, err) - parsed_message, err := ParseSessionData(&client_session, server_hello[COMMAND_LENGTH+ID_LENGTH:]) + if session_data[0] != byte(SESSION_DATA) { + t.Fatalf("Session data command byte mismatch(%x != %x)", session_data[0], SESSION_DATA) + } else if slices.Compare(session_data[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], server_session.ID[:]) != 0 { + t.Fatal("Session data ID mismatch") + } + + parsed_message, err := ParseSessionData(&client_session, session_data[COMMAND_LENGTH+ID_LENGTH:]) fatalErr(t, err) if slices.Compare(message, parsed_message) != 0 { @@ -131,5 +147,12 @@ func TestSessionClose(t *testing.T) { fatalErr(t, err) session_close := NewSessionClose(&client_session) + + if session_close[0] != byte(SESSION_CLOSE) { + t.Fatalf("Session close command byte mismatch(%x != %x)", session_close[0], SESSION_CLOSE) + } else if slices.Compare(session_close[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], server_session.ID[:]) != 0 { + t.Fatal("Session close ID mismatch") + } + fatalErr(t, ParseSessionClose(&server_session, session_close[COMMAND_LENGTH+ID_LENGTH:])) }