from base64 import b64encode, b64decode
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

import os


# Text encoding used to turn passwords and plaintext into bytes, and to turn the
# base64 token (and decrypted bytes) back into str.
ENCODING = "utf-8"


def get_encryption_key(iv: bytes, password: str) -> bytes:
    """Derive a 32-byte (AES-256) key from *password* with PBKDF2-HMAC-SHA256.

    In this module the salt is the random 16-byte value that is also used as
    the CBC IV, so each encrypted message is keyed independently.

    Args:
        iv: Random bytes used as the PBKDF2 salt.
        password: Secret passphrase; encoded with ``ENCODING`` before hashing.

    Returns:
        The derived 32-byte key.
    """
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=iv,
        # NOTE(review): 100k iterations is below current OWASP guidance for
        # PBKDF2-HMAC-SHA256; raising it would break existing ciphertexts.
        iterations=100_000,
        backend=default_backend(),
    )
    secret = password.encode(ENCODING)
    return kdf.derive(secret)

def encrypt_message_text(message_text: str, password: str) -> str:
    """Encrypt *message_text* with a key derived from *password*.

    A fresh random 16-byte value is drawn per call; it is used both as the
    PBKDF2 salt and as the AES-CBC IV. The plaintext is PKCS#7-padded to a
    multiple of 16 bytes.

    Args:
        message_text: Text to encrypt; encoded with ``ENCODING``.
        password: Passphrase used to derive the AES-256 key.

    Returns:
        Base64 text of ``iv (16 bytes) || ciphertext``.

    NOTE(review): the output carries no MAC, so tampering is not detected on
    decryption. AES-GCM or encrypt-then-MAC would be needed for integrity.
    """
    block_size = 16
    salt_and_iv = os.urandom(block_size)

    aes = Cipher(
        algorithm=algorithms.AES(get_encryption_key(salt_and_iv, password)),
        mode=modes.CBC(salt_and_iv),
        backend=default_backend(),
    ).encryptor()

    plaintext = message_text.encode(ENCODING)
    # PKCS#7: always append 1..16 bytes, each equal to the pad count.
    filler = block_size - (len(plaintext) % block_size)
    plaintext += bytes((filler,)) * filler

    ciphertext = aes.update(plaintext)
    ciphertext += aes.finalize()

    payload = salt_and_iv + ciphertext
    return b64encode(payload).decode(ENCODING)

def decrypt_message_text(message_text: str, password: str) -> str:
    """Decrypt a token produced by ``encrypt_message_text``.

    The token is base64 of ``iv (16 bytes) || AES-256-CBC ciphertext``. The
    IV also serves as the PBKDF2 salt for deriving the key from *password*.

    Args:
        message_text: Base64 token as returned by ``encrypt_message_text``.
        password: Passphrase that was used to encrypt the message.

    Returns:
        The original plaintext string.

    Raises:
        ValueError: If the token is not valid base64 (``binascii.Error`` is a
            ``ValueError``), is too short to hold an IV plus at least one
            ciphertext block, has a ciphertext length that is not a multiple
            of 16, carries invalid PKCS#7 padding (the usual symptom of a
            wrong password or corrupted data), or does not decode as
            ``ENCODING`` text.
    """
    block_size = 16  # AES block size in bytes; also the IV/salt length.

    encrypted_data_with_iv = b64decode(message_text)

    iv = encrypted_data_with_iv[:block_size]
    encrypted_data = encrypted_data_with_iv[block_size:]

    # Reject malformed input before paying for the PBKDF2 key derivation.
    if len(iv) != block_size or not encrypted_data:
        raise ValueError("encrypted message is too short")
    if len(encrypted_data) % block_size:
        raise ValueError(
            "encrypted message length is not a multiple of the block size"
        )

    key = get_encryption_key(iv, password)

    decryptor = Cipher(
        algorithm=algorithms.AES(key),
        mode=modes.CBC(iv),
        backend=default_backend(),
    ).decryptor()

    decrypted_padded_message = decryptor.update(encrypted_data) + decryptor.finalize()

    # Validate PKCS#7 padding rather than trusting the last byte. A pad byte
    # of 0 would make ``[:-0]`` return b"", and one larger than the block
    # size would strip real plaintext. Either way a wrong password or a
    # corrupted token would silently yield wrong output.
    # NOTE(review): since the ciphertext is unauthenticated, distinguishable
    # errors here can still act as a padding oracle; adding a MAC (or
    # switching to AES-GCM) is the real fix.
    pad_length = decrypted_padded_message[-1]
    if not 1 <= pad_length <= block_size or (
        decrypted_padded_message[-pad_length:] != bytes([pad_length]) * pad_length
    ):
        raise ValueError("invalid padding: wrong password or corrupted message")

    return decrypted_padded_message[:-pad_length].decode(ENCODING)
