import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.PrintStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.*;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.Locale;
import java.util.UUID;
import javax.crypto.IllegalBlockSizeException;

/**
 * Creates signed JSON Web Tokens in JWS compact serialization ({@code header.payload.signature}).
 *
 * <p>Usable as a command-line tool (see {@link #main(String[])}) or as a small library through the public
 * {@code generateJwt} overloads. Two signing algorithms are supported: {@code RS256} ({@code SHA256withRSA})
 * and {@code ES256} ({@code SHA256withECDSA}). Private keys are supplied as Base64-encoded PKCS#8 bytes.
 */
public class JWTCreator {
    /** Default {@code aud} claim. */
    public static final String DEFAULT_AUDIENCE = "https://hypr.com";
    /** Default {@code iss} claim. */
    public static final String DEFAULT_ISSUER = "https://hypr.com";
    /** Fallback {@code sub} claim (the {@code -sub} CLI option is mandatory, so the CLI never uses it). */
    public static final String DEFAULT_SUBJECT = "API_KEY";
    /** Default token lifetime, in seconds, added to the current time to compute {@code exp}. */
    public static final long DEFAULT_MAX_VALIDITY_IN_SECS = 3600L;

    /** Default JWS signing algorithm. */
    public static final String DEFAULT_ALGORITHM = "ES256";
    /** Default {@code kid} header value: a random UUID generated once when this class is initialized. */
    public static final String DEFAULT_KEY_ID = UUID.randomUUID().toString();

    /**
     * Command-line entry point. Parses the options listed by {@code -help}, prints the signed JWT to
     * standard output and, when {@code -out} is given, also writes it to that file.
     *
     * @param args command-line options as {@code -name value} pairs
     * @throws IllegalArgumentException if a mandatory option or an option value is missing, or {@code -exp}
     *                                  is not a number
     * @throws Exception if the key cannot be decoded/loaded, signing fails, or the output file cannot be written
     */
    public static void main(String[] args) throws Exception {
        // Check for -help option first so it works even when mandatory options are absent.
        if (Arrays.asList(args).contains("-help")) {
            printUsage();
            System.exit(0);
        }

        // NOTE(review): usage is printed on every run, not only for -help — confirm this is intended.
        printUsage();

        String sub = validateArgument(args, "-sub", DEFAULT_SUBJECT, true);
        String iss = validateArgument(args, "-iss", DEFAULT_ISSUER, false);
        String aud = validateArgument(args, "-aud", DEFAULT_AUDIENCE, false);
        String kid = validateArgument(args, "-kid", DEFAULT_KEY_ID, false);
        long exp = Long.parseLong(validateArgument(args, "-exp", System.currentTimeMillis() / 1000 + DEFAULT_MAX_VALIDITY_IN_SECS, false));
        String algorithm = validateArgument(args, "-alg", DEFAULT_ALGORITHM, false);
        String privateKeyBase64 = validateArgument(args, "-keybase64", "", true);
        String outFile = validateArgument(args, "-out", "", false);

        PrivateKey privateKey = loadPrivateKeyFromBase64(privateKeyBase64, algorithm);

        // Generate the JWT
        String jwt = generateJwt(privateKey, iss, sub, aud, exp, kid, algorithm);

        System.out.println("\nGenerated JWT Token:");
        System.out.println(jwt);

        if (!outFile.isEmpty()) {
            Files.writeString(Path.of(outFile), jwt);
            System.out.println("\nWrote to " + outFile + " file");
        }
    }

    /**
     * Returns the token following the first occurrence of {@code argName} in {@code args}. If the option is
     * absent and not mandatory, returns {@code defaultValue.toString()} instead.
     *
     * @throws IllegalArgumentException if a mandatory option is absent, or the option is the last argument
     *                                  or is followed by another {@code -}-prefixed token instead of a value
     */
    private static String validateArgument(String[] args, String argName, Object defaultValue, boolean mandatory) {
        final int argIndex = Arrays.asList(args).indexOf(argName);
        if (argIndex < 0) {
            if (mandatory) {
                throw new IllegalArgumentException("\nMissing mandatory arg " + argName);
            }
            return defaultValue.toString();
        }

        final int argValueIndex = argIndex + 1;
        // A following "-xyz" token is treated as the next option, i.e. this option has no value.
        if (argValueIndex >= args.length || args[argValueIndex].startsWith("-")) {
            throw new IllegalArgumentException("\nMissing value for arg " + argName);
        }
        return args[argValueIndex];
    }

    /** Prints the command-line options accepted by {@link #main(String[])} to standard output. */
    private static void printUsage() {
        final PrintStream out = System.out;
        out.println("\nUsage: java JWTCreator [options]");
        out.println("Options:");
        out.println("  -help              Show this help message and exit");
        out.println("  -sub <subject>     MANDATORY subject claim. `Client Id` from CC UI token creation screen");
        out.println("  -keybase64 <key>   MANDATORY base64 private key. `Client key` from the CC UI token creation screen");
        out.println("  -iss <issuer>      OPTIONAL  issuer claim. Default: " + DEFAULT_ISSUER);
        out.println("  -aud <audience>    OPTIONAL  audience claim. Default: " + DEFAULT_AUDIENCE);
        out.println("  -kid <keyId>       OPTIONAL  key id placed in the JWT header. Default: random UUID");
        out.println("  -exp <expiration>  OPTIONAL  expiration time of the JWT token created. In seconds since epoch. Default: now() + 1 hour");
        out.println("  -alg <algorithm>   OPTIONAL  signing algorithm, ES256 or RS256. Default: " + DEFAULT_ALGORITHM);
        out.println("  -out <file>        OPTIONAL  output file to write the JWT token");
    }

    /**
     * Creates a JWT for {@code sub} using the default issuer, audience, key id and algorithm. The token
     * expires {@link #DEFAULT_MAX_VALIDITY_IN_SECS} seconds from now.
     *
     * @param privateKey Base64-encoded PKCS#8 private key matching {@link #DEFAULT_ALGORITHM}
     * @param sub        the {@code sub} claim
     * @return the signed JWT in compact serialization
     * @throws Exception if the key cannot be loaded or signing fails
     */
    public static String generateJwt(String privateKey, String sub) throws Exception {
        return generateJwt(
                loadPrivateKeyFromBase64(privateKey, DEFAULT_ALGORITHM),
                DEFAULT_ISSUER,
                sub,
                DEFAULT_AUDIENCE,
                System.currentTimeMillis() / 1000 + DEFAULT_MAX_VALIDITY_IN_SECS,
                DEFAULT_KEY_ID,
                DEFAULT_ALGORITHM);
    }

    /**
     * Creates a JWT with explicit claims, key id and algorithm.
     *
     * @param privateKey Base64-encoded PKCS#8 private key matching {@code algorithm}
     * @param iss        the {@code iss} claim
     * @param sub        the {@code sub} claim
     * @param aud        the {@code aud} claim
     * @param exp        the {@code exp} claim, in seconds since the epoch
     * @param kid        the {@code kid} header value
     * @param algorithm  {@code "RS256"} or {@code "ES256"} (case-sensitive)
     * @return the signed JWT in compact serialization
     * @throws Exception if the key cannot be loaded or signing fails
     */
    public static String generateJwt(String privateKey, String iss, String sub, String aud, long exp, String kid, String algorithm)
            throws Exception {
        return generateJwt(
                loadPrivateKeyFromBase64(privateKey, algorithm),
                iss,
                sub,
                aud,
                exp,
                kid,
                algorithm);
    }

    /**
     * Builds, signs and serializes the token. {@code iat} and {@code nbf} are both set to the current time,
     * and {@code jti} is a fresh random UUID.
     */
    private static String generateJwt(
            PrivateKey privateKey,
            String iss,
            String sub,
            String aud,
            long exp,
            String kid,
            String algo) throws Exception {
        final Base64.Encoder base64Url = Base64.getUrlEncoder().withoutPadding();

        // Step 1: Create the JWT header. String values are JSON-escaped so that caller-supplied text
        // cannot break the JSON or inject extra fields into the signed content.
        final String header = "{\"alg\":" + jsonString(algo) + ",\"typ\":\"JWT\",\"kid\":" + jsonString(kid) + "}";
        final String header64 = base64Url.encodeToString(header.getBytes(UTF_8));

        // Step 2: Build the JWT payload.
        final long now = System.currentTimeMillis() / 1000;
        final String jwtId = UUID.randomUUID().toString();
        // Locale.ROOT: %d is locale-sensitive and could otherwise emit non-ASCII digits for the numeric claims.
        final String payload = String.format(Locale.ROOT,
                "{ \"iss\": %s,"
                + "\"sub\": %s,"
                + "\"aud\": %s,"
                + "\"jti\": %s,"
                + "\"exp\": %d,"
                + "\"iat\": %d,"
                + "\"nbf\": %d }",
                jsonString(iss), jsonString(sub), jsonString(aud), jsonString(jwtId), exp, now, now);
        final String payload64 = base64Url.encodeToString(payload.getBytes(UTF_8));

        // Step 3: The signing input is the two encoded segments joined by '.'.
        final String unsignedToken = header64 + "." + payload64;

        // Step 4: Sign the token.
        final String signature64 = base64Url.encodeToString(signString(unsignedToken, privateKey, algo));

        // Step 5: Create the final JWT.
        return unsignedToken + "." + signature64;
    }

    /**
     * Encodes {@code value} as a JSON string literal, including the surrounding quotes. Quotes, backslashes
     * and control characters are escaped. A {@code null} value is encoded as the string {@code "null"}.
     */
    private static String jsonString(String value) {
        final String text = String.valueOf(value);
        final StringBuilder sb = new StringBuilder(text.length() + 2).append('"');
        for (int k = 0; k < text.length(); k++) {
            final char c = text.charAt(k);
            switch (c) {
                case '"' -> sb.append("\\\"");
                case '\\' -> sb.append("\\\\");
                case '\b' -> sb.append("\\b");
                case '\f' -> sb.append("\\f");
                case '\n' -> sb.append("\\n");
                case '\r' -> sb.append("\\r");
                case '\t' -> sb.append("\\t");
                default -> {
                    if (c < 0x20) {
                        // Remaining control characters must use the \\uXXXX form in JSON.
                        sb.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
                }
            }
        }
        return sb.append('"').toString();
    }

    /**
     * Decodes a standard-alphabet Base64 PKCS#8 private key. The key is loaded with the key factory that
     * matches {@code algorithm}.
     *
     * @param base64Key the PKCS#8-encoded private key, Base64-encoded
     * @param algorithm {@code "RS256"} (RSA key) or {@code "ES256"} (EC key); matched case-sensitively
     * @return the loaded private key
     * @throws IllegalArgumentException if {@code base64Key} is not valid Base64 or {@code algorithm} is unsupported
     * @throws Exception if the bytes are not a valid private key of the expected type
     */
    public static PrivateKey loadPrivateKeyFromBase64(String base64Key, String algorithm) throws Exception {
        byte[] keyBytes = Base64.getDecoder().decode(base64Key);
        PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(keyBytes);
        return (switch (algorithm) {
            case "RS256" -> KeyFactory.getInstance("RSA");
            case "ES256" -> KeyFactory.getInstance("EC");
            default -> throw new IllegalArgumentException("Unsupported algorithm for key loading: " + algorithm);
        }).generatePrivate(spec);
    }

    /**
     * Generates a fresh key pair for {@code algorithm}. {@code "RS256"} yields a 2048-bit RSA pair;
     * {@code "ES256"} yields an EC pair on {@code secp256r1}.
     *
     * @param algorithm {@code "RS256"} or {@code "ES256"} (case-sensitive)
     * @return the generated key pair
     * @throws IllegalArgumentException if {@code algorithm} is unsupported
     * @throws Exception if the JCA provider cannot create the generator
     */
    public static KeyPair generateKeyPair(String algorithm) throws Exception {
        KeyPairGenerator keyPairGenerator;

        switch (algorithm) {
            case "RS256" -> {
                keyPairGenerator = KeyPairGenerator.getInstance("RSA");
                keyPairGenerator.initialize(2048);
            }
            case "ES256" -> {
                keyPairGenerator = KeyPairGenerator.getInstance("EC");
                keyPairGenerator.initialize(new ECGenParameterSpec("secp256r1"));
            }
            default -> throw new IllegalArgumentException("Unsupported algorithm: " + algorithm);
        }

        return keyPairGenerator.generateKeyPair();
    }

    /**
     * Signs the UTF-8 bytes of {@code input} with {@code privateKey}.
     *
     * @param input      the text to sign (for a JWT, the {@code header.payload} signing input)
     * @param privateKey the signing key, matching {@code algorithm}
     * @param algorithm  {@code "RS256"} or {@code "ES256"}, matched case-insensitively
     * @return the raw signature bytes (not Base64). For ES256 this is the 64-byte {@code R || S} form used by
     *         JWS rather than the DER encoding produced by the JCA.
     * @throws IllegalArgumentException if {@code algorithm} is unsupported
     * @throws Exception if signing fails
     */
    public static byte[] signString(String input, PrivateKey privateKey, String algorithm) throws Exception {
        // Normalize once so the algorithm switch and the ES256 transcoding decision always agree.
        final String normalized = algorithm.toUpperCase(Locale.ROOT);
        final String signatureAlgorithm = switch (normalized) {
            case "RS256" -> "SHA256withRSA";
            case "ES256" -> "SHA256withECDSA";
            default -> throw new IllegalArgumentException("Unsupported algorithm: " + algorithm);
        };

        final Signature signature = Signature.getInstance(signatureAlgorithm);
        signature.initSign(privateKey, new SecureRandom());
        signature.update(input.getBytes(UTF_8));

        final byte[] rawSignature = signature.sign();
        // The JCA emits ECDSA signatures as ASN.1 DER; JWS requires fixed-length R || S (RFC 7518 section 3.4).
        return "ES256".equals(normalized) ? transcodeSignatureToConcat(rawSignature, 64) : rawSignature;
    }

    /**
     * Converts a DER-encoded ECDSA signature ({@code SEQUENCE { INTEGER r, INTEGER s }}) into the
     * {@code R || S} concatenation used by JWS.
     *
     * <p>Leading zero bytes of each integer are dropped. Each half is then left-padded with zeros to
     * {@code max(outputLength / 2, significant bytes of r, significant bytes of s)} bytes.
     *
     * @param derSignature the DER-encoded signature
     * @param outputLength the desired total length of the result (64 for ES256)
     * @return the concatenated {@code R || S} signature
     * @throws IllegalBlockSizeException if {@code derSignature} fails the DER structure checks
     */
    public static byte[] transcodeSignatureToConcat(final byte[] derSignature, final int outputLength) throws IllegalBlockSizeException {

        if (derSignature.length < 8 || derSignature[0] != 48) {
            throw new IllegalBlockSizeException("Invalid ECDSA signature format");
        }

        // The SEQUENCE length is either short form (one byte) or long form with one length byte (0x81).
        int offset;
        if (derSignature[1] > 0) {
            offset = 2;
        } else if (derSignature[1] == (byte) 0x81) {
            offset = 3;
        } else {
            throw new IllegalBlockSizeException("Invalid ECDSA signature format");
        }

        // NOTE(review): r/s lengths are read before the structural check below. Truncated or corrupt input
        // can therefore surface as ArrayIndexOutOfBoundsException rather than IllegalBlockSizeException.
        byte rLength = derSignature[offset + 1];

        // i = number of bytes of r after skipping leading zero bytes.
        int i;
        for (i = rLength; (i > 0) && (derSignature[(offset + 2 + rLength) - i] == 0); i--) {
            // do nothing
        }

        byte sLength = derSignature[offset + 2 + rLength + 1];

        // j = number of bytes of s after skipping leading zero bytes.
        int j;
        for (j = sLength; (j > 0) && (derSignature[(offset + 2 + rLength + 2 + sLength) - j] == 0); j--) {
            // do nothing
        }

        int rawLen = Math.max(i, j);
        rawLen = Math.max(rawLen, outputLength / 2);

        if ((derSignature[offset - 1] & 0xff) != derSignature.length - offset
            || (derSignature[offset - 1] & 0xff) != 2 + rLength + 2 + sLength
            || derSignature[offset] != 2
            || derSignature[offset + 2 + rLength] != 2) {
            throw new IllegalBlockSizeException("Invalid ECDSA signature format");
        }

        final byte[] concatSignature = new byte[2 * rawLen];

        // Right-align r in the first half and s in the second half (implicit zero left-padding).
        System.arraycopy(derSignature, (offset + 2 + rLength) - i, concatSignature, rawLen - i, i);
        System.arraycopy(derSignature, (offset + 2 + rLength + 2 + sLength) - j, concatSignature, 2 * rawLen - j, j);

        return concatSignature;
    }
}
