package main

import (
	"crypto/rand"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	. "speistel/pkg"
	"strconv"
	"time"
)

const (
	// defaultKey is the hex key handed to EncryptData when SPEISTEL_KEY is
	// unset. NOTE(review): this is an obvious placeholder value; deployments
	// presumably must override it via SPEISTEL_KEY.
	defaultKey = "0000000000000001000000020000000300000004000000050000000600000007"

	// defaultMaxMessageSize is the largest request payload, in bytes, that is
	// accepted when SPEISTEL_MAX_MESSAGE_SIZE is unset or not an integer (10 MiB).
	defaultMaxMessageSize = 10 << 20

	// defaultTimeout bounds each individual read and write when
	// SPEISTEL_TIMEOUT is unset or cannot be parsed as a Go duration.
	defaultTimeout = 30 * time.Second
)

// getenv returns the value of the environment variable key, or def when the
// variable is not present at all. A variable that is set to the empty string
// yields "" rather than def.
func getenv(key string, def string) string {
	value, present := os.LookupEnv(key)
	if !present {
		return def
	}
	return value
}

// getenvInt reads the environment variable key as a base-10 integer
// (strconv.Atoi). It returns def when the variable is absent or when its
// value does not parse; malformed values are ignored silently.
func getenvInt(key string, def int) int {
	raw, present := os.LookupEnv(key)
	if !present {
		return def
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return def
	}
	return parsed
}

// getenvDuration reads the environment variable key using time.ParseDuration
// syntax (e.g. "30s", "1m30s"). It returns def when the variable is absent or
// its value does not parse. A bare number without a unit (other than "0")
// does not parse, so it also falls back to def.
func getenvDuration(key string, def time.Duration) time.Duration {
	raw, present := os.LookupEnv(key)
	if !present {
		return def
	}
	parsed, err := time.ParseDuration(raw)
	if err != nil {
		return def
	}
	return parsed
}

// newConnID returns a short identifier used to tag this connection's log
// lines: 8 lowercase hex characters drawn from crypto/rand. If the random
// source fails, it falls back to the current Unix time in nanoseconds,
// formatted as hex.
func newConnID() string {
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err == nil {
		return hex.EncodeToString(raw[:])
	}
	return fmt.Sprintf("%x", time.Now().UnixNano())
}

// readWithTimeout fills buf completely from r (via io.ReadFull) and gives up
// once timeout has elapsed.
//
// A timeout <= 0 disables the limit: the call blocks until io.ReadFull
// returns. Previously a zero or negative timeout raced an already-expired
// timer against the read and failed at random.
//
// io.Reader offers no cancellation, so on timeout the background read is
// abandoned, not stopped. Its goroutine may stay blocked in r and may still
// write into buf afterwards. Callers must not reuse buf after a timeout error.
// (main exits right after any read error, which reclaims both.)
//
// The returned error is io.ReadFull's error (nil, io.EOF, io.ErrUnexpectedEOF,
// or the reader's own error) or a "read timeout after ..." error.
func readWithTimeout(r io.Reader, buf []byte, timeout time.Duration) error {
	if timeout <= 0 {
		_, err := io.ReadFull(r, buf)
		return err
	}

	// Buffered so an abandoned reader goroutine can still deliver its
	// result and exit instead of blocking forever on the send.
	done := make(chan error, 1)
	go func() {
		_, err := io.ReadFull(r, buf)
		done <- err
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop() // release the timer promptly when the read finishes first

	select {
	case err := <-done:
		return err
	case <-timer.C:
		return fmt.Errorf("read timeout after %s", timeout)
	}
}

// writeWithTimeout performs a single w.Write(buf) and gives up once timeout
// has elapsed.
//
// A timeout <= 0 disables the limit: the call blocks until Write returns.
// Previously a zero or negative timeout raced an already-expired timer
// against the write and failed at random.
//
// io.Writer offers no cancellation, so on timeout the background write is
// abandoned, not stopped. It may still complete, or partially complete, later.
// Callers must not modify buf after a timeout error. (main exits right after
// any write error.)
//
// The returned error is Write's error (the io.Writer contract requires a
// non-nil error for short writes) or a "write timeout after ..." error.
func writeWithTimeout(w io.Writer, buf []byte, timeout time.Duration) error {
	if timeout <= 0 {
		_, err := w.Write(buf)
		return err
	}

	// Buffered so an abandoned writer goroutine can still deliver its
	// result and exit instead of blocking forever on the send.
	done := make(chan error, 1)
	go func() {
		_, err := w.Write(buf)
		done <- err
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop() // release the timer promptly when the write finishes first

	select {
	case err := <-done:
		return err
	case <-timer.C:
		return fmt.Errorf("write timeout after %s", timeout)
	}
}

// fprintf formats according to format and writes the result to os.Stderr.
// Diagnostics are best-effort: if the write to stderr fails, there is nowhere
// left to report that, so the error is deliberately discarded.
func fprintf(format string, a ...interface{}) {
	line := fmt.Sprintf(format, a...)
	_, _ = os.Stderr.WriteString(line)
}

// main serves a single encryption request over stdin/stdout. The
// SOCAT_PEERADDR / SOCAT_PEERPORT variables it logs suggest it is run by
// socat once per connection (NOTE(review): confirm deployment).
//
// Wire format, both directions: a 4-byte big-endian length followed by that
// many bytes. The request payload is the plaintext; the reply payload is the
// result of EncryptData. Diagnostics go to stderr; any failure exits with
// status 1 without writing a reply.
//
// Environment:
//   - SPEISTEL_KEY: key string passed to EncryptData (default defaultKey)
//   - SPEISTEL_MAX_MESSAGE_SIZE: largest accepted request payload in bytes
//   - SPEISTEL_TIMEOUT: timeout for each read and write (Go duration syntax)
func main() {
	// 1. read the settings
	keyHex := getenv("SPEISTEL_KEY", defaultKey)
	maxMessageSize := getenvInt("SPEISTEL_MAX_MESSAGE_SIZE", defaultMaxMessageSize)
	timeout := getenvDuration("SPEISTEL_TIMEOUT", defaultTimeout)

	clientAddr := os.Getenv("SOCAT_PEERADDR")
	clientPort := os.Getenv("SOCAT_PEERPORT")
	connID := newConnID()

	start := time.Now()
	fprintf("[%s] Connection started from %s:%s\n", connID, clientAddr, clientPort)

	// 2. read the message: 4-byte big-endian length prefix, then the payload
	lenBuf := make([]byte, 4)
	if err := readWithTimeout(os.Stdin, lenBuf, timeout); err != nil {
		fprintf("[%s] [!] failed to read length: %v\n", connID, err)
		os.Exit(1)
	}

	n := binary.BigEndian.Uint32(lenBuf)
	// Compare in int64: where int is 32 bits, int(n) wraps negative for
	// n >= 2^31, which would let an attacker-chosen length bypass the limit
	// and trigger a huge allocation below.
	if int64(n) > int64(maxMessageSize) {
		fprintf("[%s] [!] message is too large: %d > %d\n", connID, n, maxMessageSize)
		os.Exit(1)
	}

	// read the next n bytes (n is now known to be within maxMessageSize)
	data := make([]byte, n)
	if err := readWithTimeout(os.Stdin, data, timeout); err != nil {
		fprintf("[%s] [!] failed to read data: %v\n", connID, err)
		os.Exit(1)
	}
	fprintf("[%s] received %d bytes, encrypting...\n", connID, len(data))

	// 3. encrypt the data
	ciphertext, err := EncryptData(data, keyHex)
	if err != nil {
		fprintf("[%s] [!] failed to encrypt data: %v\n", connID, err)
		os.Exit(1)
	}

	// 4. send back the ciphertext with the same length-prefixed framing.
	// The prefix is 32 bits wide; refuse to send a length that uint32()
	// would silently truncate, since that would mis-frame the reply.
	const maxFrameLen = 1<<32 - 1
	if uint64(len(ciphertext)) > maxFrameLen {
		fprintf("[%s] [!] ciphertext is too large to frame: %d bytes\n", connID, len(ciphertext))
		os.Exit(1)
	}
	outLen := make([]byte, 4)
	binary.BigEndian.PutUint32(outLen, uint32(len(ciphertext)))
	if err := writeWithTimeout(os.Stdout, outLen, timeout); err != nil {
		fprintf("[%s] [!] failed to write length: %v\n", connID, err)
		os.Exit(1)
	}
	if err := writeWithTimeout(os.Stdout, ciphertext, timeout); err != nil {
		fprintf("[%s] [!] failed to write ciphertext: %v\n", connID, err)
		os.Exit(1)
	}

	fprintf("[%s] sent %d bytes, connection closed (duration %s)\n", connID, len(ciphertext), time.Since(start))
}
