from base64 import b64encode, b64decode
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

import os


# Text encoding used to turn passwords and plaintext into bytes, to decode
# recovered plaintext, and to turn the base64 token bytes into a str.
ENCODING: str = "utf-8"


def get_encryption_key(iv: bytes, password: str) -> bytes:
    """Derive a 32-byte AES key from *password* using PBKDF2-HMAC-SHA256.

    The per-message IV is reused as the PBKDF2 salt, so the same IV and
    password always reproduce the same key. The iteration count (100_000)
    and key length are part of the token format: changing either would make
    previously encrypted messages undecryptable.

    Args:
        iv: Random bytes stored alongside the ciphertext; used as the salt.
        password: The user's password, encoded with ``ENCODING``.

    Returns:
        The derived 32-byte key.
    """
    kdf = PBKDF2HMAC(
        algorithm = hashes.SHA256(),
        length = 32,
        salt = iv,
        iterations = 100_000,
        backend = default_backend()
    )
    password_bytes = password.encode(ENCODING)
    return kdf.derive(password_bytes)

def encrypt_message_text(message_text: str, password: str) -> str:
    """Encrypt *message_text* with AES-CBC under a key derived from *password*.

    A fresh random 16-byte value serves both as the CBC IV and as the PBKDF2
    salt. The plaintext is PKCS#7-padded to a whole number of 16-byte blocks.

    Args:
        message_text: Text to encrypt; encoded with ``ENCODING``.
        password: Password from which the AES key is derived.

    Returns:
        Base64 text of ``iv || ciphertext``, the format read back by
        ``decrypt_message_text``.
    """
    block_size = 16
    iv_salt = os.urandom(block_size)

    aes_cbc = Cipher(
        algorithm = algorithms.AES(get_encryption_key(iv_salt, password)),
        mode = modes.CBC(iv_salt),
        backend = default_backend()
    )
    encryptor = aes_cbc.encryptor()

    plaintext = message_text.encode(ENCODING)
    # PKCS#7: always append 1..16 bytes, each equal to the pad count; an
    # already block-aligned message gets a full extra block of padding.
    fill = block_size - len(plaintext) % block_size
    plaintext += bytes((fill,)) * fill

    ciphertext = encryptor.update(plaintext) + encryptor.finalize()
    return b64encode(iv_salt + ciphertext).decode(ENCODING)

def decrypt_message_text(message_text: str, password: str) -> str:
    """Decrypt a base64 token produced by ``encrypt_message_text``.

    The token decodes to a 16-byte IV followed by AES-CBC ciphertext; the IV
    is also the PBKDF2 salt used to re-derive the key from *password*.

    Args:
        message_text: Base64 text of ``iv || ciphertext``.
        password: The password used when encrypting.

    Returns:
        The recovered plaintext string.

    Raises:
        ValueError: If the base64 is malformed (``binascii.Error``), the
            token is too short or its ciphertext is not block-aligned, the
            PKCS#7 padding is invalid (usually a wrong password or corrupted
            data), or the plaintext is not valid ``ENCODING``
            (``UnicodeDecodeError``).
    """
    # NOTE(review): there is no MAC, so tokens are malleable, and surfacing
    # padding errors to an attacker can act as a padding oracle. Fixing that
    # needs an authenticated mode (e.g. AES-GCM) and a new token format.
    block_size = 16  # AES block size in bytes; also the IV length.

    encrypted_data_with_iv: bytes = b64decode(message_text)

    iv: bytes = encrypted_data_with_iv[:block_size]
    encrypted_data: bytes = encrypted_data_with_iv[block_size:]

    # A valid token holds the IV plus at least one whole ciphertext block.
    # Without this check an empty ciphertext decrypts to b"" and the
    # padding lookup below fails with IndexError.
    if not encrypted_data or len(encrypted_data) % block_size:
        raise ValueError("Encrypted message is truncated or malformed.")

    key: bytes = get_encryption_key(iv, password)

    cipher: Cipher = Cipher(
        algorithm = algorithms.AES(key),
        mode = modes.CBC(iv),
        backend = default_backend()
    )

    decryptor = cipher.decryptor()

    decrypted_padded_message: bytes = decryptor.update(encrypted_data) + decryptor.finalize()

    # Validate PKCS#7 padding: the last byte is the pad length (1..16) and
    # every pad byte must equal it. Without this, a final byte of 0 makes
    # [:-0] return b"" and other bad values silently truncate the output,
    # so a wrong password could "succeed".
    pad_length: int = decrypted_padded_message[-1]
    if (
        not 1 <= pad_length <= block_size
        or decrypted_padded_message[-pad_length:] != bytes([pad_length]) * pad_length
    ):
        raise ValueError("Invalid padding: wrong password or corrupted message.")

    decrypted_message: str = decrypted_padded_message[:-pad_length].decode(ENCODING)

    return decrypted_message
