package pkg

import (
	"encoding/binary"
	"encoding/hex"
	"fmt"
)

// PKCS7Pad returns a copy of data extended with PKCS#7 padding so that its
// length is a multiple of blockSize. Each pad byte holds the pad length, and
// an input that is already aligned gains a whole extra block, so the result
// is always 1..blockSize bytes longer than data.
//
// data itself is never modified: the result lives in a freshly allocated
// buffer, so spare capacity in the caller's backing array is left untouched.
//
// It panics if blockSize is outside [1, 255], because PKCS#7 stores the pad
// length in a single byte.
func PKCS7Pad(data []byte, blockSize int) []byte {
	if blockSize <= 0 || blockSize > 255 {
		panic("invalid block size")
	}
	// len(data)%blockSize is in [0, blockSize), so padLen is in [1, blockSize];
	// an aligned input therefore receives a full block of padding.
	padLen := blockSize - len(data)%blockSize

	out := make([]byte, len(data)+padLen)
	copy(out, data)
	for i := len(data); i < len(out); i++ {
		out[i] = byte(padLen)
	}
	return out
}

// PKCS7Unpad validates and strips PKCS#7 padding from data.
//
// data must be non-empty and a multiple of blockSize long. Its final byte n
// must be in [1, blockSize], and the last n bytes must all equal n. On success
// the returned slice is data[:len(data)-n], which shares data's backing array.
//
// A non-positive blockSize is rejected with an error rather than reaching the
// modulo below, where a zero divisor would panic.
func PKCS7Unpad(data []byte, blockSize int) ([]byte, error) {
	if blockSize <= 0 {
		return nil, fmt.Errorf("invalid block size")
	}
	if len(data) == 0 || len(data)%blockSize != 0 {
		return nil, fmt.Errorf("invalid padded data length")
	}
	padLen := int(data[len(data)-1])
	if padLen == 0 || padLen > blockSize {
		return nil, fmt.Errorf("invalid padding length")
	}
	for _, b := range data[len(data)-padLen:] {
		if int(b) != padLen {
			return nil, fmt.Errorf("invalid padding bytes")
		}
	}
	return data[:len(data)-padLen], nil
}

// BytesToUint64s decodes data as a sequence of big-endian 64-bit words.
//
// len(data) must be a multiple of 8; otherwise an error reporting the actual
// length is returned. data is only read, never extended or modified. An empty
// input yields an empty, non-nil slice and a nil error.
func BytesToUint64s(data []byte) ([]uint64, error) {
	if len(data)%8 != 0 {
		return nil, fmt.Errorf("invalid data length: must be multiple of 8, got %d", len(data))
	}
	blocks := make([]uint64, len(data)/8)
	for i := range blocks {
		blocks[i] = binary.BigEndian.Uint64(data[i*8 : (i+1)*8])
	}
	return blocks, nil
}

// Uint64sToBytes serializes blocks as consecutive big-endian 8-byte words.
// The result always has length 8*len(blocks); an empty input produces an
// empty, non-nil slice.
func Uint64sToBytes(blocks []uint64) []byte {
	out := make([]byte, 8*len(blocks))
	for off, idx := 0, 0; idx < len(blocks); off, idx = off+8, idx+1 {
		binary.BigEndian.PutUint64(out[off:off+8], blocks[idx])
	}
	return out
}

// ParseRoundKeys decodes keyHex, which must decode to exactly 32 bytes, into
// two sets of four big-endian 32-bit round keys. Bytes 0-15 fill roundKeys[0]
// and bytes 16-31 fill roundKeys[1], four bytes per key, in order.
//
// On failure the returned array is all zeros. Invalid hex returns the
// hex.DecodeString error as-is (unwrapped); a wrong decoded length returns a
// descriptive error.
func ParseRoundKeys(keyHex string) ([2][4]uint32, error) {
	var keys [2][4]uint32

	raw, err := hex.DecodeString(keyHex)
	switch {
	case err != nil:
		return keys, err
	case len(raw) != 32:
		return keys, fmt.Errorf("invalid key length: got %d bytes, expected 32", len(raw))
	}

	// Eight 4-byte words laid out row-major across the two key sets.
	for w := 0; w < 8; w++ {
		keys[w/4][w%4] = binary.BigEndian.Uint32(raw[4*w : 4*w+4])
	}
	return keys, nil
}

// Substitution runs each of the eight 4-bit nibbles of input through its own
// S-box. Nibble i (bits 4i..4i+3, counted from the least significant end) is
// looked up in SBoxes[i], and the looked-up value is OR-ed back in at the same
// position.
// NOTE(review): assumes every SBoxes entry fits in 4 bits; a wider entry would
// spill into the neighbouring nibble. SBoxes is defined elsewhere; confirm.
func Substitution(input uint32) uint32 {
	var out uint32
	for i := 0; i < 8; i++ {
		shift := 4 * i
		out |= uint32(SBoxes[i][(input>>shift)&0xF]) << shift
	}
	return out
}

// Permutation relocates the bits of input: every set bit i (counted from the
// least significant bit) sets bit PBox[i] of the result.
// NOTE(review): assumes PBox is a permutation of 0..31 (defined elsewhere;
// confirm). A target of 32 or more silently drops the bit, and duplicate
// targets merge bits.
func Permutation(input uint32) uint32 {
	var out uint32
	for i := 0; i < 32; i++ {
		if (input>>i)&1 != 0 {
			out |= 1 << PBox[i]
		}
	}
	return out
}

// AddRoundKey mixes a 32-bit round key into the state with bitwise XOR.
// XOR is self-inverse, so applying the same key twice restores the input.
func AddRoundKey(input, roundKey uint32) uint32 {
	mixed := roundKey ^ input
	return mixed
}

// F is the Feistel round function. It applies four substitute → permute →
// XOR-with-key steps to p, consuming roundKeys[0] through roundKeys[3] in order.
func F(p uint32, roundKeys *[4]uint32) uint32 {
	state := p
	for _, k := range roundKeys {
		state = AddRoundKey(Permutation(Substitution(state)), k)
	}
	return state
}

// EncryptBlock enciphers one 64-bit block with a two-round Feistel network.
// The high 32 bits form the left half and the low 32 bits the right half.
// Each round replaces (left, right) with (right, left ^ F(right, keys)),
// using roundKeys[0] for the first round and roundKeys[1] for the second.
func EncryptBlock(block uint64, roundKeys *[2][4]uint32) uint64 {
	left := uint32(block >> 32)
	right := uint32(block)
	for r := 0; r < 2; r++ {
		left, right = right, left^F(right, &roundKeys[r])
	}
	return uint64(left)<<32 | uint64(right)
}

// DecryptBlock inverts EncryptBlock. It undoes the two Feistel rounds in
// reverse key order (roundKeys[1], then roundKeys[0]); each step maps
// (left, right) back to (right ^ F(left, keys), left).
func DecryptBlock(block uint64, roundKeys *[2][4]uint32) uint64 {
	left := uint32(block >> 32)
	right := uint32(block)
	for r := 1; r >= 0; r-- {
		left, right = right^F(left, &roundKeys[r]), left
	}
	return uint64(left)<<32 | uint64(right)
}

// EncryptData encrypts data under the hex-encoded 32-byte key keyHex.
//
// The plaintext is PKCS#7-padded to a multiple of 8 bytes. Each 8-byte block
// is then enciphered independently with EncryptBlock, so identical plaintext
// blocks yield identical ciphertext blocks (ECB-style; no IV, no chaining).
// The ciphertext is returned as a new slice whose length is a positive
// multiple of 8.
//
// An error is returned if keyHex is rejected by ParseRoundKeys.
// NOTE(review): ECB mode leaks plaintext block equality, and there is no
// integrity protection; confirm this is acceptable for the callers.
func EncryptData(data []byte, keyHex string) ([]byte, error) {
	const blockSize = 8

	keys, err := ParseRoundKeys(keyHex)
	if err != nil {
		return nil, err
	}

	padded := PKCS7Pad(data, blockSize)
	blocks, err := BytesToUint64s(padded)
	if err != nil {
		return nil, err
	}

	for i, b := range blocks {
		blocks[i] = EncryptBlock(b, &keys)
	}
	return Uint64sToBytes(blocks), nil
}

// DecryptData reverses EncryptData. It deciphers each 8-byte block of data
// with DecryptBlock under keyHex, then strips the PKCS#7 padding.
//
// Errors are returned when:
//   - the key cannot be parsed,
//   - len(data) is not a multiple of 8, or
//   - the decrypted padding fails PKCS7Unpad's checks (this includes empty
//     input).
//
// There is no integrity check. A wrong key or altered ciphertext is detected
// only if it happens to produce invalid padding.
func DecryptData(data []byte, keyHex string) ([]byte, error) {
	const blockSize = 8

	keys, err := ParseRoundKeys(keyHex)
	if err != nil {
		return nil, err
	}

	blocks, err := BytesToUint64s(data)
	if err != nil {
		return nil, err
	}
	for i := range blocks {
		blocks[i] = DecryptBlock(blocks[i], &keys)
	}

	// PKCS7Unpad already returns (nil, err) on failure, matching this
	// function's error contract.
	return PKCS7Unpad(Uint64sToBytes(blocks), blockSize)
}
