#!/usr/bin/env python3
import argparse
import fcntl
import glob
import os
import select
import struct
import sys
import time

# Layout of the kernel's `struct input_event` as read from /dev/input/event*:
# a `struct timeval` (assumed to be two native longs, as on common Linux ABIs),
# followed by type (u16), code (u16) and value (s32). The '@' prefix selects
# native byte order, sizes and alignment.
EVENT_FORMAT = '@llHHi'
EVENT_SIZE = struct.Struct(EVENT_FORMAT).size

EV_LED = 0x11  # evdev event type for LED state changes

# evdev LED codes, keyed by the names accepted on the command line.
LED_CODES = dict(
    numlock=0x00,
    capslock=0x01,
    scrolllock=0x02,
)

def get_led_state(led_type: str) -> int:
    """Return 1 if any sysfs LED named ``*::<led_type>`` is lit, else 0.

    Every matching ``/sys/class/leds/*::<led_type>/brightness`` file is read.
    Files that cannot be read or do not contain an integer are skipped.
    When nothing matches, the result is 0.
    """
    for brightness_path in glob.glob(f"/sys/class/leds/*::{led_type}/brightness"):
        try:
            with open(brightness_path, 'r') as handle:
                lit = int(handle.read().strip()) > 0
        except (OSError, ValueError):  # IOError is an alias of OSError
            continue
        if lit:
            return 1
    return 0

def has_led_capability(event_path: str) -> bool:
    """Report whether an evdev node advertises any LED capability.

    Reads ``/sys/class/input/<node>/device/capabilities/led`` for the node
    named by the basename of *event_path*. Any value other than ``"0"``
    counts as capable. A missing or unreadable file yields False.
    """
    node = os.path.basename(event_path)
    led_caps_file = f"/sys/class/input/{node}/device/capabilities/led"
    try:
        with open(led_caps_file, 'r') as caps:
            bitmap = caps.read().strip()
    except OSError:
        return False
    return bitmap != "0"

class DeviceMonitor:
    """Watch every LED-capable evdev device and print one LED's on/off state.

    The state (0 or 1) is written to stdout once at construction, then again
    whenever it changes. EV_LED events read from the devices only act as a
    trigger. The value itself always comes from sysfs via get_led_state(), so
    several devices reporting the same change in one batch produce at most one
    output line.
    """

    # poll() timeout in milliseconds; a fully idle wait triggers a device rescan.
    POLL_TIMEOUT_MS = 1000
    # Maximum time between rescans even while events keep arriving, so a busy
    # device cannot starve hotplug detection or the periodic sysfs check.
    RESCAN_INTERVAL_S = 1.0

    def __init__(self, target_led: str):
        """Prepare a monitor for *target_led* and emit its current state.

        Raises KeyError if *target_led* is not a key of LED_CODES.
        """
        self.target_led = target_led
        self.target_led_code = LED_CODES[target_led]

        self.poller = select.poll()
        self.active_fds = {}  # open fd -> /dev/input/eventN path it came from

        self.last_state = get_led_state(self.target_led)
        self.emit_state(self.last_state)

    def emit_state(self, state: int) -> None:
        """Write *state* on its own line and flush so consumers see it at once."""
        sys.stdout.write(f"{state}\n")
        sys.stdout.flush()

    def scan_devices(self) -> None:
        """Open newly appeared LED-capable event devices and drop vanished ones.

        Exits the process with status 1 when an LED-capable device cannot be
        opened because of missing permissions. Other open errors are skipped.
        """
        fd_by_path = {path: fd for fd, path in self.active_fds.items()}
        still_present = set()

        for path in glob.glob('/dev/input/event*'):
            if not has_led_capability(path):
                continue

            known_fd = fd_by_path.get(path)
            if known_fd is not None:
                still_present.add(known_fd)
                continue

            try:
                # Raw, unbuffered, non-blocking fd: the drain loop must never
                # block the poll loop.
                fd = os.open(path, os.O_RDONLY | os.O_NONBLOCK)
            except PermissionError:
                sys.stderr.write(f"Error: Permission denied when opening {path}.\n")
                sys.exit(1)
            except OSError:
                continue
            self.poller.register(fd, select.POLLIN)
            self.active_fds[fd] = path
            still_present.add(fd)

        for fd in set(self.active_fds) - still_present:
            self.remove_device(fd)

    def remove_device(self, fd: int) -> None:
        """Stop watching *fd* and close it. Unknown fds are ignored.

        Each release step is attempted independently, so a failure in one
        step cannot leave a stale entry behind or leak the descriptor.
        """
        if fd not in self.active_fds:
            return
        del self.active_fds[fd]
        try:
            self.poller.unregister(fd)
        except KeyError:
            pass  # was not registered; still close it below
        try:
            os.close(fd)
        except OSError:
            pass  # already closed/invalid; nothing left to release

    def run(self) -> None:
        """Monitor devices until interrupted, printing the LED state on change.

        Exits with status 0 on KeyboardInterrupt. All device fds are
        unregistered and closed on the way out.
        """
        try:
            self.scan_devices()
            next_scan = time.monotonic() + self.RESCAN_INTERVAL_S
            while True:
                events = self.poller.poll(self.POLL_TIMEOUT_MS)
                needs_refresh = self._handle_events(events) if events else False

                # Rescan after an idle wait, and also at least once per interval.
                now = time.monotonic()
                if not events or now >= next_scan:
                    self.scan_devices()
                    next_scan = now + self.RESCAN_INTERVAL_S
                    needs_refresh = True

                if needs_refresh:
                    self._refresh_state()
        except KeyboardInterrupt:
            sys.exit(0)
        finally:
            self._close_all()

    def _handle_events(self, events: list) -> bool:
        """Drain readable devices from a poll() result and drop broken ones.

        *events* is the list of ``(fd, mask)`` pairs returned by poll().
        Returns True when the LED state should be re-read from sysfs: either
        an EV_LED event for the target LED was seen, or a device was removed.
        """
        saw_led_event = False
        dead_fds = set()

        for fd, mask in events:
            if mask & select.POLLIN:
                try:
                    while True:
                        data = os.read(fd, 4096)
                        if not data:
                            dead_fds.add(fd)  # EOF: the device is gone
                            break
                        if self._contains_target_led_event(data):
                            saw_led_event = True
                except BlockingIOError:
                    pass  # fully drained
                except OSError:
                    dead_fds.add(fd)  # e.g. the device was unplugged
            # Hang-up/error/invalid fds are reported on every poll() and can
            # come without POLLIN. If they are left registered, the loop spins
            # forever without ever reaching the idle rescan.
            if mask & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
                dead_fds.add(fd)

        for fd in dead_fds:
            self.remove_device(fd)

        return saw_led_event or bool(dead_fds)

    def _contains_target_led_event(self, data: bytes) -> bool:
        """Return True if *data* holds an EV_LED record for the target LED.

        *data* is one raw read from an event device. Trailing bytes that do
        not form a complete EVENT_SIZE record are ignored.
        """
        usable = len(data) - len(data) % EVENT_SIZE
        records = struct.iter_unpack(EVENT_FORMAT, memoryview(data)[:usable])
        return any(
            ev_type == EV_LED and ev_code == self.target_led_code
            for _sec, _usec, ev_type, ev_code, _value in records
        )

    def _refresh_state(self) -> None:
        """Re-read the LED from sysfs and emit it if it differs from the last value."""
        current_state = get_led_state(self.target_led)
        if current_state != self.last_state:
            self.emit_state(current_state)
            self.last_state = current_state

    def _close_all(self) -> None:
        """Unregister and close every open device."""
        for fd in list(self.active_fds):
            self.remove_device(fd)

def main():
    """Parse the command line and run the LED monitor until interrupted."""
    cli = argparse.ArgumentParser(description="Keyboard LED monitor.")
    cli.add_argument(
        '-l', '--led',
        type=str,
        default='capslock',
        # Derived from LED_CODES so the CLI and the code table cannot drift;
        # sorted() yields ['capslock', 'numlock', 'scrolllock'].
        choices=sorted(LED_CODES),
        help="Target LED to monitor",
    )
    options = cli.parse_args()
    DeviceMonitor(options.led).run()

# Run the monitor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
