#!/usr/bin/env python3

# Description:
#   Change color theme for various applications based on an image, color, or user selection.
#   This is done by running application-specific scripts located in the config directory.
#
# Requirements:
#   - colorthief (python3 package) # too lazy to implement color extraction myself :D

import os
import sys
import argparse
import subprocess
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

# Maximum number of apply-color scripts main() runs concurrently (thread pool size).
MAX_WORKERS = 8

# Color palettes keyed by palette name. Each palette maps a flavor name to its
# color as a hex string "rrggbb" (no leading '#').
PALETTES = {
    "catppuccin-mocha": {
        "rosewater": "f5e0dc",
        "flamingo": "f2cdcd",
        "pink": "f5c2e7",
        "mauve": "cba6f7",
        "red": "f38ba8",
        "maroon": "eba0ac",
        "peach": "fab387",
        "yellow": "f9e2af",
        "green": "a6e3a1",
        "teal": "94e2d5",
        "sky": "89dceb",
        "sapphire": "74c7ec",
        "blue": "89b4fa",
        "lavender": "b4befe",
    },
}

# Base directory holding the per-application apply-color scripts below.
CONFIG_DIR = Path("~/.config").expanduser()

# Application name -> apply-color scripts to run for it.
# An application may have multiple scripts (e.g. due to config-switch)
# main() invokes each script as: <script> <palette_name> <flavor> <rrggbb>
SCRIPTS = {
    "fastfetch": [CONFIG_DIR / "fastfetch" / "apply-color"],
    "kvantum": [CONFIG_DIR / "Kvantum" / "apply-color"],
    "nwg-look": [CONFIG_DIR / "nwg-look" / "apply-color"],
    "niri": [CONFIG_DIR / "niri" / "apply-color"],
    "oh-my-posh": [
        CONFIG_DIR / "fish" / "apply-color-omp"
    ],  # borrowing fish's directory
    "starship": [
        CONFIG_DIR / "fish" / "apply-color-starship"
    ],  # borrowing fish's directory
    "quickshell": [CONFIG_DIR / "quickshell" / "apply-color"],
    "wlogout": [CONFIG_DIR / "wlogout" / "apply-color"],
    "yazi": [CONFIG_DIR / "yazi" / "apply-color"],
}
# or simply `find [-L] <CONFIG_DIR> -type f -name 'apply-color*'` to get all available scripts,
# but I need the exact application names anyway, so hardcoding does make some sense

# Number of scripts that exited with status 0. run_script increments it from
# thread-pool worker threads, so every update must hold success_count_lock.
success_count = 0
success_count_lock = Lock()


def hex2rgb(hex_color: str) -> tuple[int, int, int]:
    """Convert a hex color string to an (r, g, b) tuple of ints in 0-255.

    Accepts ``rrggbb`` or ``#rrggbb`` (case-insensitive, surrounding
    whitespace ignored), as well as the 3-digit shorthand ``rgb`` / ``#rgb``.

    Raises:
        ValueError: if the string is not a valid 3- or 6-digit hex color.
    """
    digits = hex_color.strip().removeprefix("#")
    if len(digits) == 3:
        # CSS-style shorthand: "abc" -> "aabbcc"
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(..., 16) alone would accept signs/underscores,
    # and slicing would silently ignore trailing extra characters.
    if len(digits) != 6 or not all(ch in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)


def clamp(x, minimum, maximum) -> float:
    """Return ``x`` limited to the closed interval [minimum, maximum].

    Anything above ``maximum`` becomes ``maximum``; anything below
    ``minimum`` becomes ``minimum``; otherwise ``x`` is returned unchanged.
    """
    capped = maximum if maximum < x else x
    return capped if capped > minimum else minimum


def rgb2hsv(rr: int, gg: int, bb: int) -> tuple[float, float, float]:
    """Convert 0-255 RGB channels to HSV.

    Each channel is scaled to [0, 1] and clamped before conversion.

    Returns:
        ``(h, s, v)``: hue in degrees in [0, 360), saturation and value as
        percentages in [0, 100]. When all channels are equal the hue is 0.
    """
    r, g, b = (clamp(channel / 255.0, 0.0, 1.0) for channel in (rr, gg, bb))
    hi = max(r, g, b)
    lo = min(r, g, b)
    spread = hi - lo

    if hi == lo:
        hue = 0.0
    elif hi == r:
        hue = (60 * ((g - b) / spread) + 360) % 360
    elif hi == g:
        hue = (60 * ((b - r) / spread) + 120) % 360
    else:  # the maximum is the blue channel
        hue = (60 * ((r - g) / spread) + 240) % 360

    saturation = 0.0 if hi == 0 else (spread / hi) * 100
    value = hi * 100
    return hue, saturation, value


def extract_color(image_path: str) -> str:
    """Choose one representative color from an image.

    ColorThief builds a palette (``color_count=5``, ``quality=10``). Palette
    entries that are too dark (HSV value < 20) or near-white (value > 95 with
    saturation < 5) are skipped. Among the rest, the entry with the highest
    ``2 * saturation + value`` score wins; on ties the earliest entry wins.
    If no entry survives the filter, ColorThief's ``get_color`` result is
    used instead.

    Returns:
        The chosen color as a ``#rrggbb`` string with lowercase hex digits.
    """
    # Deferred import so colorthief is only required when an image is given.
    from colorthief import ColorThief

    thief = ColorThief(image_path)

    scored = []
    for rgb in thief.get_palette(color_count=5, quality=10):
        _, sat, val = rgb2hsv(*rgb)
        too_dark = val < 20
        washed_out = val > 95 and sat < 5
        if too_dark or washed_out:
            continue
        # Favor saturation over brightness.
        scored.append((sat * 2.0 + val, rgb))

    if scored:
        # max() keeps the first of equal maxima, same as a strict ">" scan.
        chosen = max(scored, key=lambda entry: entry[0])[1]
    else:
        chosen = thief.get_color(quality=10)

    return "#{:02x}{:02x}{:02x}".format(*chosen)


def match_color(color: str, palette: dict[str, str]) -> str:
    """Return the name of the palette flavor closest to ``color``.

    Args:
        color: A hex color, ``rrggbb`` or ``#rrggbb`` (case-insensitive,
            surrounding whitespace ignored).
        palette: Mapping of flavor name to ``rrggbb`` hex value.

    Returns:
        The key of ``palette`` with the smallest weighted distance to
        ``color``. The distance is a redmean-weighted RGB distance plus a
        hue-difference penalty whose weight depends on the target's
        saturation.

    Raises:
        ValueError: if ``color`` is not a 6-digit hex color, or ``palette``
            is empty.
    """
    color = color.lower().strip().removeprefix("#")
    if len(color) != 6 or not set(color) <= set("0123456789abcdef"):
        raise ValueError(f"Invalid color: {color!r} (expected #rrggbb)")
    if not palette:
        raise ValueError("Cannot match a color against an empty palette")

    target_rgb = hex2rgb(color)
    target_h, target_s, _ = rgb2hsv(*target_rgb)

    # Also used for user-supplied --color values, so don't call it "extracted".
    nearly_gray = target_s < 5
    if nearly_gray:
        print(
            f"Warning: Color #{color} is nearly grayscale. Matching might be inaccurate."
        )

    # Hue penalty weight, fixed for this target:
    # - nearly gray: hue is meaningless (rgb2hsv reports 0 degrees when all
    #   channels are equal), so penalizing hue would only bias the match toward
    #   flavors whose hue is near 0 degrees (red). Ignore hue entirely.
    # - saturated (> 20%): weight hue up.
    # - otherwise: a light hue penalty.
    if nearly_gray:
        hue_weight = 0.0
    elif target_s > 20:
        hue_weight = 2.0
    else:
        hue_weight = 0.5

    def weighted_distance(hex_val: str) -> float:
        """Redmean-weighted RGB distance plus a scaled, wrap-aware hue difference."""
        p_rgb = hex2rgb(hex_val)
        p_h = rgb2hsv(*p_rgb)[0]

        rmean = (target_rgb[0] + p_rgb[0]) / 2
        dr, dg, db = (t - p for t, p in zip(target_rgb, p_rgb))
        rgb_distance = (
            (2 + rmean / 256) * dr**2 + 4 * dg**2 + (2 + (255 - rmean) / 256) * db**2
        ) ** 0.5

        # Hue is circular: 350 and 10 degrees are 20 degrees apart, not 340.
        hue_diff = abs(target_h - p_h)
        hue_diff = min(hue_diff, 360 - hue_diff)

        return rgb_distance + hue_diff * hue_weight * 3

    closest_flavor = min(palette, key=lambda name: weighted_distance(palette[name]))
    print(f"Matched color #{color} to {closest_flavor} (#{palette[closest_flavor]})")
    return closest_flavor


def pick_flavor_interactive(palette: dict[str, str]) -> str:
    """Prompt the user to choose a flavor from ``palette`` by number.

    When both stdin and stdout are terminals, lists the flavors, adding a
    colored swatch when the terminal looks truecolor-capable. It then
    re-prompts until a valid number is entered. Otherwise it prints
    "No flavor specified." and exits.

    Returns:
        The chosen flavor name (a key of ``palette``).

    Raises:
        SystemExit: status 1 when not interactive or on end-of-input,
            status 130 when interrupted with Ctrl-C.
    """

    def is_interactive() -> bool:
        return sys.stdin.isatty() and sys.stdout.isatty()

    def is_truecolor() -> bool:
        colorterm = os.environ.get("COLORTERM", "")
        term = os.environ.get("TERM", "")

        # Permissive heuristic: a *-256color TERM does not guarantee 24-bit
        # support, but it is treated as truecolor-capable here.
        return (
            "truecolor" in colorterm
            or "24bit" in colorterm
            or term.endswith("-256color")
        )

    if not is_interactive():
        print("No flavor specified.")
        sys.exit(1)

    flavors = list(palette)
    use_truecolor = is_truecolor()
    print("Available flavors:")
    for i, flavor in enumerate(flavors, 1):
        if use_truecolor:
            r, g, b = hex2rgb(palette[flavor])
            print(
                f"\033[38;2;{r};{g};{b}m█ {i}. {flavor}: #{palette[flavor]}\033[0m"
            )
        else:
            print(f"{i}. {flavor}")

    while True:
        try:
            choice = input("Pick a flavor by number: ").strip()
        except EOFError:
            # stdin closed (e.g. Ctrl-D): no answer can ever arrive.
            print()
            sys.exit(1)
        except KeyboardInterrupt:
            print()
            sys.exit(130)
        # isdecimal(), not isdigit(): isdigit() accepts characters such as
        # "²" that int() rejects with ValueError.
        if choice.isdecimal() and 1 <= int(choice) <= len(flavors):
            return flavors[int(choice) - 1]
        print("Invalid choice. Try again.")


def run_script(script_path: Path, args: list[str]):
    """Run one apply-color script and report the outcome on stdout.

    The script is executed directly (no shell) as ``[script_path, *args]``.
    Missing, non-regular-file, or non-executable scripts are skipped with a
    warning. On exit status 0 a check mark is printed and the module-level
    ``success_count`` is incremented while holding ``success_count_lock``,
    because main() runs this function from a thread pool. On a non-zero
    exit, the exit code and the script's stderr are printed. Errors are
    reported on stdout rather than raised.
    """
    global success_count

    script_str = str(script_path)
    if not script_path.exists():
        print(f"Warning: Script not found: {script_str}")
        return
    # A directory would pass the exists/X_OK checks but cannot be executed.
    if not script_path.is_file():
        print(f"Warning: Script is not a regular file: {script_str}")
        return
    if not os.access(script_path, os.X_OK):
        print(f"Warning: Script not executable: {script_str}")
        return

    try:
        # errors="replace": script output in an unexpected encoding must not
        # turn a successful run into a decoding exception.
        result = subprocess.run(
            [script_str, *args], capture_output=True, text=True, errors="replace"
        )
    except Exception as e:  # best-effort: report and let other scripts proceed
        print(f"Exception running {script_path}: {e}")
        return

    if result.returncode != 0:
        # Include the exit code so the message is useful even with empty stderr.
        print(
            f"Error running {script_path} (exit code {result.returncode}):\n"
            f"{result.stderr.strip()}"
        )
        return

    print(f"✓ {script_path}")
    with success_count_lock:
        success_count += 1


def main():
    """Command-line entry point: resolve a flavor, then run the per-app scripts.

    The flavor comes from, in order of precedence: --flavor, --color (matched
    to the nearest palette entry), --image (dominant color extracted, then
    matched), or an interactive prompt. Positional ``apps`` arguments narrow
    the run: a plain name includes only the named apps, and a ``!name``
    excludes that app. Each selected app's scripts are run on a thread pool
    with the arguments ``palette_name flavor rrggbb``. Exits with status 1 on
    invalid input (unknown flavor/app, malformed color, missing image).
    """
    parser = argparse.ArgumentParser(
        description="Change color theme for various applications."
    )
    parser.add_argument("-i", "--image", type=str, help="Path to the image")
    parser.add_argument("-f", "--flavor", type=str, help="Flavor to apply")
    parser.add_argument(
        "-c", "--color", type=str, help="Color to match from the palette"
    )
    parser.add_argument(
        "apps",
        nargs="*",
        help="'app1 !app2' to include(only) / exclude(all but) specific applications. "
        "Available apps: " + ", ".join(SCRIPTS.keys()),
    )

    arguments = parser.parse_args()

    # for future use, probably
    def parse_palette_name() -> str:
        return "catppuccin-mocha"

    def parse_flavor(palette: dict[str, str]) -> str:
        """Resolve the flavor from --flavor, --color, --image, or a prompt."""
        if arguments.flavor:
            if arguments.flavor not in palette:
                print(
                    f"Unknown flavor: {arguments.flavor}. Available flavors: {', '.join(palette.keys())}"
                )
                sys.exit(1)
            return arguments.flavor
        if arguments.color:
            try:
                flavor = match_color(arguments.color, palette)
            except ValueError as e:
                # A malformed --color is a user error: report it, no traceback.
                print(f"Invalid color {arguments.color!r}: {e}")
                sys.exit(1)
            print(f"Matched color: {flavor}")
            return flavor
        if arguments.image:
            # is_file() rather than exists(): a directory cannot be read as an image.
            if not Path(arguments.image).is_file():
                print(
                    f"Image file {arguments.image} does not exist or is not a regular file."
                )
                sys.exit(1)
            color = extract_color(arguments.image)
            print(f"Extracted color {color} from image {arguments.image}")
            flavor = match_color(color, palette)
            print(f"Matched color: {flavor}")
            return flavor
        return pick_flavor_interactive(palette)

    def parse_apps() -> tuple[set[str], set[str]]:
        """Split positional args into included names and '!'-prefixed excluded names."""
        includes: set[str] = set()
        excludes: set[str] = set()
        for arg in arguments.apps:
            if arg.startswith("!"):
                excludes.add(arg[1:])
            else:
                includes.add(arg)
        return includes, excludes

    def ensure_known(names: set[str]) -> None:
        """Exit with status 1 if any of ``names`` is not a key of SCRIPTS."""
        for app in sorted(names):
            if app not in SCRIPTS:
                print(
                    f"Unknown application: {app}. Available applications: {', '.join(SCRIPTS.keys())}"
                )
                sys.exit(1)

    palette_name = parse_palette_name()
    palette = PALETTES[palette_name]

    # Validate app names before resolving the flavor, so a typo is reported
    # before any interactive prompt or image processing. Excluded names are
    # checked too, so a mistyped '!name' is not silently a no-op.
    includes, excludes = parse_apps()
    ensure_known(includes)
    ensure_known(excludes)

    flavor = parse_flavor(palette)

    if includes:
        print(f"Including only: {', '.join(includes)}")
        apps = set(includes)
    else:
        apps = set(SCRIPTS.keys())

    if excludes:
        print(f"Excluding: {', '.join(excludes)}")
        apps -= excludes

    print(f"Applying flavor '{flavor}' for {len(apps)} applications.")

    script_args = [palette_name, flavor, palette[flavor]]
    scripts = [script for app in apps for script in SCRIPTS[app]]

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(run_script, script, script_args) for script in scripts]

    # run_script reports its own failures; .result() re-raises anything
    # unexpected that escaped it instead of silently dropping it.
    for future in futures:
        future.result()

    print(f"Done: {success_count}/{len(scripts)} scripts succeeded.")

    # subprocess.run(
    #     [
    #         "notify-send",
    #         "-a",
    #         "change-colortheme",
    #         "Colortheme Changed",
    #         f"Palette: {palette_name}\nFlavor: {flavor}\nApplied to {success_count} apps",
    #     ]
    # )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
