echo-go/main.go

246 lines
4.6 KiB
Go

package echo_go
import (
"bytes"
"crypto/aes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net"
"strings"
"git.vh7.uk/jakew/echo-go/crypto"
)
// randomHex reads n bytes from the cryptographically secure random source and
// returns them hex-encoded, so the resulting string is 2*n characters long.
// An error from the random source is returned with an empty string.
func randomHex(n int) (string, error) {
	raw := make([]byte, n)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}

	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded), nil
}
// SendPlain sends a message to the server without encrypting it.
//
// metadata is JSON-encoded and stored as a string in the message's Metadata
// field; the message, tagged with the client's UserId, is then JSON-encoded
// and written to c.Con with a single Write call. The first encoding or write
// error encountered is returned.
func (c *Client) SendPlain(msgType MessageType, data *string, subType *string, metadata []string) error {
	encodedMeta, err := json.Marshal(metadata)
	if err != nil {
		return err
	}

	payload, err := json.Marshal(&RawMessage{
		UserId:      c.UserId,
		MessageType: msgType,
		SubType:     subType,
		Data:        data,
		Metadata:    string(encodedMeta),
	})
	if err != nil {
		return err
	}

	if _, err := c.Con.Write(payload); err != nil {
		return err
	}
	return nil
}
// Send sends a message to the server encrypted with the client's session key.
//
// The message is assembled and JSON-encoded exactly as in SendPlain, then
// passed through crypto.EncryptAesCbc using an AES cipher keyed with the bytes
// of c.SessionKey, and the result is written to c.Con. The first encoding,
// cipher-construction, encryption or write error encountered is returned.
func (c *Client) Send(msgType MessageType, data *string, subType *string, metadata []string) error {
	encodedMeta, err := json.Marshal(metadata)
	if err != nil {
		return err
	}

	plaintext, err := json.Marshal(&RawMessage{
		UserId:      c.UserId,
		MessageType: msgType,
		SubType:     subType,
		Data:        data,
		Metadata:    string(encodedMeta),
	})
	if err != nil {
		return err
	}

	block, err := aes.NewCipher([]byte(c.SessionKey))
	if err != nil {
		return err
	}

	ciphertext, err := crypto.EncryptAesCbc(block, plaintext)
	if err != nil {
		return err
	}

	if _, err := c.Con.Write(ciphertext); err != nil {
		return err
	}
	return nil
}
// rawReceive reads from c.Con until the accumulated data ends with
// messageEnding, then returns everything read so far as a string.
//
// Only the bytes actually reported by each Read are appended, so NUL bytes
// that are part of the stream are preserved rather than trimmed. If a Read
// fails, the data received up to and including that call is returned together
// with the error. An empty messageEnding matches immediately and yields "".
func (c *Client) rawReceive(messageEnding string) (string, error) {
	ending := []byte(messageEnding)
	// One scratch buffer is reused for every Read instead of allocating per
	// iteration.
	chunk := make([]byte, 20480)

	var received bytes.Buffer
	for !bytes.HasSuffix(received.Bytes(), ending) {
		n, err := c.Con.Read(chunk)
		// A Read may deliver n > 0 bytes alongside an error; keep those bytes
		// before looking at err.
		received.Write(chunk[:n])
		if err != nil {
			return received.String(), err
		}
	}
	return received.String(), nil
}
// ReceivePlain reads unencrypted data from the connection and decodes it into
// messages.
//
// Data is read with rawReceive until it ends in "}", then parsed as a sequence
// of back-to-back JSON values, each decoded into a RawMessage. Parsing with a
// JSON decoder (rather than splitting the text on "}") keeps messages intact
// when a string value or nested object contains a closing brace. Whitespace
// between messages is ignored.
//
// It returns a non-nil (possibly empty) slice on success, or nil and the first
// read or decode error.
func (c *Client) ReceivePlain() ([]RawMessage, error) {
	received, err := c.rawReceive("}")
	if err != nil {
		return nil, err
	}

	messages := []RawMessage{}
	dec := json.NewDecoder(strings.NewReader(received))
	// More reports whether another value starts in the remaining input; it is
	// false once only whitespace (or nothing) is left.
	for dec.More() {
		var msg RawMessage
		if err := dec.Decode(&msg); err != nil {
			return nil, err
		}
		messages = append(messages, msg)
	}
	return messages, nil
}
// Receive reads encrypted data from the connection and decodes it into
// messages.
//
// Data is read with rawReceive until it ends in "]" and is then split on "]";
// every non-blank piece (with the "]" restored) is decrypted through
// crypto.DecryptAesCbc using an AES cipher keyed with the bytes of
// c.SessionKey, and the plaintext is JSON-decoded into a RawMessage.
//
// It returns a non-nil (possibly empty) slice on success, or nil and the first
// read, cipher, decryption or decode error.
func (c *Client) Receive() ([]RawMessage, error) {
	raw, err := c.rawReceive("]")
	if err != nil {
		return nil, err
	}

	key := []byte(c.SessionKey)
	messages := []RawMessage{}
	for _, piece := range strings.Split(raw, "]") {
		if len(strings.TrimSpace(piece)) == 0 {
			continue
		}

		block, err := aes.NewCipher(key)
		if err != nil {
			return nil, err
		}

		plaintext, err := crypto.DecryptAesCbc(block, []byte(piece+"]"))
		if err != nil {
			return nil, err
		}

		var msg RawMessage
		if err := json.Unmarshal(plaintext, &msg); err != nil {
			return nil, err
		}
		messages = append(messages, msg)
	}
	return messages, nil
}
// HandshakeLoop performs the connection handshake with the server and blocks
// until it is accepted or fails.
//
// It sends a plain ReqServerInfo request and then processes responses:
//
//   - ResServerInfo: the session key is encrypted with crypto.RsaEncrypt using
//     the response data (presumably the server's RSA public key — confirm
//     against the crypto package) and sent as a plain ReqClientSecret. All
//     subsequent reads use the encrypted Receive path.
//   - ResClientSecret: an encrypted ReqConnection is sent whose data is the
//     JSON array [username, password, clientVersion, ""].
//   - ResConnectionAccepted: the handshake is complete; nil is returned.
//   - ResConnectionDenied: an error carrying the server's reason is returned.
//
// Any other message type, any send/receive failure, a failed RSA encryption,
// or a ResServerInfo without data aborts the handshake with an error.
func (c *Client) HandshakeLoop(clientVersion string, password string) error {
	log.Println("sending server info request")
	if err := c.SendPlain(ReqServerInfo, nil, nil, nil); err != nil {
		return err
	}

	encrypted := false
	for {
		var (
			msgs []RawMessage
			err  error
		)
		if encrypted {
			msgs, err = c.Receive()
		} else {
			msgs, err = c.ReceivePlain()
		}
		if err != nil {
			return err
		}

		for _, msg := range msgs {
			switch msg.MessageType {
			case ResServerInfo:
				// Data is a pointer; a response without it would otherwise
				// panic on dereference.
				if msg.Data == nil {
					return fmt.Errorf("server info response carried no data")
				}
				ciphertext, err := crypto.RsaEncrypt([]byte(*msg.Data), []byte(c.SessionKey))
				if err != nil {
					// Never send a secret that failed to encrypt.
					return err
				}
				if err := c.SendPlain(ReqClientSecret, &ciphertext, nil, nil); err != nil {
					return err
				}
				encrypted = true
			case ResClientSecret:
				data, err := json.Marshal([]string{
					c.Username,
					password,
					clientVersion,
					"",
				})
				if err != nil {
					return err
				}
				dataStr := string(data)
				if err := c.Send(ReqConnection, &dataStr, nil, nil); err != nil {
					return err
				}
			case ResConnectionAccepted:
				log.Println("handshake accepted")
				return nil
			case ResConnectionDenied:
				reason := "no reason given"
				if msg.Data != nil {
					reason = *msg.Data
				}
				return fmt.Errorf("handshake failed - %v", reason)
			default:
				return fmt.Errorf("unexpected handshake message type %v", msg.MessageType)
			}
		}
	}
}
// Disconnect notifies the server with an encrypted ReqDisconnect message and
// then closes the connection.
//
// Both steps are best effort: the connection is being torn down regardless, so
// errors from sending the notification and from closing are deliberately
// discarded.
func (c *Client) Disconnect() {
	defer func() { _ = c.Con.Close() }()

	log.Println("gracefully disconnecting")
	_ = c.Send(ReqDisconnect, nil, nil, nil)
}
// New creates a client for username and opens a TCP connection to addr.
//
// A fresh user ID (32 random bytes, hex-encoded) and session key (8 random
// bytes, hex-encoded to 16 characters) are generated before dialing. The
// returned client has not performed any handshake yet. An error is returned if
// random data cannot be generated or the connection cannot be established.
func New(addr string, username string) (*Client, error) {
	id, err := randomHex(32)
	if err != nil {
		return nil, err
	}

	key, err := randomHex(8)
	if err != nil {
		return nil, err
	}

	log.Printf("connecting to %v", addr)
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}

	return &Client{
		Con:        conn,
		UserId:     id,
		SessionKey: key,
		Username:   username,
	}, nil
}