from flask import Flask, request, jsonify
import sys
import os
import json
import base64
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import requests

# Flask application object; the @app.route handler in this file registers on it
# and the __main__ guard at the bottom runs it.
app = Flask(__name__)

class KYC:
    """Client for the encrypted KYC ``generate-url`` endpoint.

    Requests are sent as an armored envelope: a fresh AES-256 key is wrapped
    with the endpoint's RSA public key (OAEP/SHA-256), the JSON body is
    encrypted with AES-GCM, and both parts are base64-encoded between
    ``BEGIN/END ENCRYPTED MESSAGE`` markers. Responses use the same layout,
    wrapped to the ephemeral public key sent in the request.
    """

    # Endpoint that receives the encrypted request.
    GENERATE_URL_ENDPOINT = 'https://api-satusehat.kemkes.go.id/kyc/v1/generate-url'

    # Public key used to wrap the AES key for outgoing requests. Built from
    # explicit lines so no indentation leaks into the PEM body or footer.
    SERVER_PUBLIC_KEY_PEM = (
        '-----BEGIN PUBLIC KEY-----\n'
        'MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxLwvebfOrPLIODIxAwFp\n'
        '4Qhksdtn7bEby5OhkQNLTdClGAbTe2tOO5Tiib9pcdruKxTodo481iGXTHR5033I\n'
        'A5X55PegFeoY95NH5Noj6UUhyTFfRuwnhtGJgv9buTeBa4pLgHakfebqzKXr0Lce\n'
        '/Ff1MnmQAdJTlvpOdVWJggsb26fD3cXyxQsbgtQYntmek2qvex/gPM9Nqa5qYrXx\n'
        '8KuGuqHIFQa5t7UUH8WcxlLVRHWOtEQ3+Y6TQr8sIpSVszfhpjh9+Cag1EgaMzk+\n'
        'HhAxMtXZgpyHffGHmPJ9eXbBO008tUzrE88fcuJ5pMF0LATO6ayXTKgZVU0WO/4e\n'
        'iQIDAQAB\n'
        '-----END PUBLIC KEY-----\n'
    )

    _BEGIN_TAG = '-----BEGIN ENCRYPTED MESSAGE-----'
    _END_TAG = '-----END ENCRYPTED MESSAGE-----'

    def __init__(self, access_token):
        """Store the bearer token sent in the ``Authorization`` header."""
        self.access_token = access_token

    def generate_key(self):
        """Generate a fresh 2048-bit RSA key pair.

        Returns:
            dict with ``'publicKey'`` (SubjectPublicKeyInfo PEM) and
            ``'privateKey'`` (unencrypted PKCS#8 PEM), both as ``str``.
        """
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )
        public_key = private_key.public_key()

        public_key_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

        private_key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )

        return {
            'publicKey': public_key_pem.decode(),
            'privateKey': private_key_pem.decode()
        }

    def import_rsa_key(self, pem):
        """Load a PEM-encoded public key given as ``str``."""
        return serialization.load_pem_public_key(
            pem.encode(),
            backend=default_backend()
        )

    def generate_symmetric_key(self):
        """Return 32 random bytes for use as an AES-256 key."""
        return os.urandom(32)

    def format_message(self, data):
        """Base64-encode ``data`` (bytes) and wrap it in the armor markers.

        Lines are separated with CRLF.
        """
        data_as_base64 = base64.b64encode(data).decode()
        return f"-----BEGIN ENCRYPTED MESSAGE-----\r\n{data_as_base64}\r\n-----END ENCRYPTED MESSAGE-----"

    def aes_encrypt(self, data, symmetric_key):
        """Encrypt ``data`` with AES-GCM under ``symmetric_key``.

        Returns:
            ``iv (12 bytes) + ciphertext + authentication tag`` as bytes.
        """
        iv = os.urandom(12)
        cipher = Cipher(algorithms.AES(symmetric_key), modes.GCM(iv), backend=default_backend())
        encryptor = cipher.encryptor()
        ciphertext = encryptor.update(data) + encryptor.finalize()
        return iv + ciphertext + encryptor.tag

    def aes_decrypt(self, encrypted_data, symmetric_key):
        """Inverse of :meth:`aes_encrypt`.

        Expects ``iv (12 bytes) + ciphertext + tag (16 bytes)``. Raises if the
        tag does not authenticate or the inputs have invalid sizes.
        """
        iv = encrypted_data[:12]
        tag = encrypted_data[-16:]
        ciphertext = encrypted_data[12:-16]
        cipher = Cipher(algorithms.AES(symmetric_key), modes.GCM(iv, tag), backend=default_backend())
        decryptor = cipher.decryptor()
        return decryptor.update(ciphertext) + decryptor.finalize()

    def encrypt_message(self, message, pub_pem):
        """Encrypt ``message`` (str) for the holder of ``pub_pem``.

        A fresh AES key is wrapped with RSA-OAEP (SHA-256) and prepended to
        the AES-GCM output; the result is returned as an armored string.
        """
        aes_key = self.generate_symmetric_key()
        server_key = self.import_rsa_key(pub_pem)
        wrapped_aes_key = server_key.encrypt(
            aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )
        encrypted_message = self.aes_encrypt(message.encode(), aes_key)
        return self.format_message(wrapped_aes_key + encrypted_message)

    def decrypt_message(self, message, private_key):
        """Decrypt an armored envelope with ``private_key`` (PEM ``str``).

        Returns:
            The decrypted text, or ``None`` when the armor markers are
            missing, the base64 body is malformed, the RSA unwrap fails, the
            AES-GCM tag does not authenticate, or the plaintext is not UTF-8.

        An unparsable ``private_key`` still raises from the PEM loader.
        """
        # Locate the markers instead of slicing by fixed offsets, so leading
        # or trailing whitespace around the envelope cannot shift the payload.
        start = message.find(self._BEGIN_TAG)
        end = message.rfind(self._END_TAG)
        if start == -1 or end < start + len(self._BEGIN_TAG):
            return None
        message_contents = message[start + len(self._BEGIN_TAG):end].strip()

        try:
            payload = base64.b64decode(message_contents)
        except ValueError:
            # binascii.Error (bad padding) is a ValueError subclass.
            return None

        key = serialization.load_pem_private_key(
            private_key.encode(),
            password=None,
            backend=default_backend()
        )

        # An RSA ciphertext is exactly as long as the modulus.
        wrapped_key_length = (key.key_size + 7) // 8
        wrapped_key = payload[:wrapped_key_length]
        encrypted_message = payload[wrapped_key_length:]

        try:
            aes_key = key.decrypt(
                wrapped_key,
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(),
                    label=None
                )
            )
        except ValueError:
            return None

        from cryptography.exceptions import InvalidTag

        try:
            return self.aes_decrypt(encrypted_message, aes_key).decode()
        except (InvalidTag, ValueError):
            # ValueError covers bad key/IV/tag sizes and UnicodeDecodeError.
            return None

    def generate_url(self, agen, nik_agen, timeout=30):
        """Request a KYC URL for the given agent.

        Args:
            agen: Sent as ``agent_name`` in the request body.
            nik_agen: Sent as ``agent_nik`` in the request body.
            timeout: Seconds passed to ``requests.post``; prevents the call
                from blocking indefinitely on an unresponsive endpoint.

        Returns:
            The decrypted response body as ``str``, or ``None`` when the
            endpoint answers with a non-200 status or the response cannot be
            decrypted. Network errors from ``requests`` propagate.
        """
        # A new key pair per call: the public half travels in the request so
        # the response can be encrypted back to this caller only.
        key_pair = self.generate_key()

        json_data = json.dumps({
            'agent_name': agen,
            'agent_nik': nik_agen,
            'public_key': key_pair['publicKey'],
        })
        encrypted_payload = self.encrypt_message(json_data, self.SERVER_PUBLIC_KEY_PEM)

        headers = {
            'Content-Type': 'text/plain',
            'Authorization': f'Bearer {self.access_token}'
        }

        response = requests.post(
            self.GENERATE_URL_ENDPOINT,
            data=encrypted_payload,
            headers=headers,
            timeout=timeout,
        )

        if response.status_code != 200:
            return None

        return self.decrypt_message(response.text, key_pair['privateKey'])

@app.route('/api/kyc/generate-url', methods=['POST'])
def generate_kyc_url():
    """Generate a KYC URL from a JSON request.

    Expected body::

        {"token": {"access_token": "..."}, "name": "...", "nik": "..."}

    Responses:
        200 with the parsed JSON returned by ``KYC.generate_url``.
        400 when the body is not a JSON object or a required field is absent.
        500 when the upstream call fails, returns nothing, or returns text
        that is not valid JSON.
    """
    # silent=True returns None for a missing/invalid JSON body rather than
    # raising, so malformed client input is reported as 400, not 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Validate required fields
    for field in ('token', 'name', 'nik'):
        if field not in data:
            return jsonify({'error': f'Missing required field: {field}'}), 400

    token = data['token']
    if not isinstance(token, dict) or 'access_token' not in token:
        return jsonify({'error': 'Missing required field: token.access_token'}), 400

    try:
        result = KYC(token['access_token']).generate_url(data['name'], data['nik'])
    except Exception as e:
        # Top-level boundary: record the traceback, then report the failure.
        app.logger.exception('KYC URL generation failed')
        return jsonify({'error': str(e)}), 500

    if result is None:
        return jsonify({'error': 'Failed to generate URL'}), 500

    # The decrypted payload is a JSON string; return it as structured JSON.
    try:
        parsed_result = json.loads(result)
    except json.JSONDecodeError:
        return jsonify({'error': 'Failed to parse response'}), 500

    return jsonify(parsed_result)

if __name__ == '__main__':
    # Direct execution only: serve on every interface, port 8980.
    app.run(port=8980, host='0.0.0.0')