#!/usr/bin/env python3

import functools
import subprocess
import sys
from collections.abc import Callable


def has_ping6() -> bool:
    """Report whether a working ``ping6`` command is available.

    Runs ``ping6 -c 1 ::1`` with its output discarded.

    Returns:
        True if the command ran and exited with status 0; False if it exited
        non-zero or could not be run at all.
    """
    probe_cmd = ['ping6', '-c', '1', '::1']
    try:
        exit_code = subprocess.run(
            probe_cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        ).returncode
    except Exception:
        # Any failure to launch the command counts as "not available".
        return False
    return exit_code == 0


@functools.cache
def _ping6_available() -> bool:
    """Return ``has_ping6()``, probing the system only once per process."""
    return has_ping6()


def ping(host: str, packet_size: int, is_ipv6: bool = False) -> bool:
    """Probe ``host`` with packets of ``packet_size`` bytes, unfragmented.

    Runs the system ping with ``-c 2 -M do -s <payload>``, where the payload
    is ``packet_size`` minus the IP and ICMP header sizes. Output of the
    command is discarded.

    Args:
        host: Host name or address handed to the ping command.
        packet_size: Total packet size to test, headers included.
        is_ipv6: Probe over IPv6 instead of IPv4.

    Returns:
        True if the ping command exited with status 0. False if it exited
        non-zero, could not be started, or ``packet_size`` is smaller than
        the headers alone.
    """
    # 40-byte IPv6 / 20-byte IPv4 header plus the 8-byte ICMP(v6) header.
    header_size = 48 if is_ipv6 else 28
    payload_size = packet_size - header_size
    if payload_size < 0:
        # The packet cannot even hold its headers; skip spawning a process
        # that would only reject the negative size.
        return False

    # Presumably a separate ping6 binary implies a plain `ping` that is
    # IPv4-only; without one, `ping` is told the address family explicitly.
    if is_ipv6:
        base_cmd = ['ping6'] if _ping6_available() else ['ping', '-6']
    else:
        base_cmd = ['ping'] if _ping6_available() else ['ping', '-4']
    cmd = [*base_cmd, '-c', '2', '-M', 'do', '-s', str(payload_size), host]

    try:
        result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except (OSError, ValueError):
        # OSError: the binary is missing or not executable.
        # ValueError: an argument cannot be passed to exec (e.g. embedded NUL).
        return False
    return result.returncode == 0


def bin_search_mtu(host: str, min_mtu: int, max_mtu: int, is_ipv6: bool = False) -> int | None:
    l = min_mtu
    h = max_mtu
    ret: int = 0

    print(f"Testing {'IPv6' if is_ipv6 else 'IPv4'} MTU for {host}...")

    if not ping(host, l, is_ipv6):
        print(f"  Cannot reach host even with MTU {l}.")
        return None

    while l <= h:
        m = (l + h) // 2
        if ping(host, m, is_ipv6):
            ret = m
            print(f"  MTU {m} OK")
            l = m + 1
        else:
            print(f"  MTU {m} Failed")
            h = m - 1

    return ret


def main():
    """Command-line entry point: report the max IPv4 and IPv6 MTU for a host.

    Usage: ``<script> <domain> [min_mtu max_mtu]``. The bounds default to
    1000 and 1500 and must be given together. Prints a message and exits
    with status 1 when the argument count is wrong, a bound is not an
    integer, or ``min_mtu`` is greater than ``max_mtu``.
    """
    usage = f"Usage: {sys.argv[0]} <domain> [min_mtu max_mtu]"
    argc = len(sys.argv)

    # Either just <domain>, or <domain> plus both bounds.
    if argc not in (2, 4):
        print(usage)
        sys.exit(1)

    domain: str = sys.argv[1]
    min_search: int = 1000
    max_search: int = 1500

    if argc == 4:
        try:
            min_search = int(sys.argv[2])
            max_search = int(sys.argv[3])
        except ValueError:
            # Report bad input as a usage error rather than a traceback.
            print("Error: min_mtu and max_mtu must be integers.")
            print(usage)
            sys.exit(1)

    if min_search > max_search:
        print(f"Error: min_mtu ({min_search}) must not exceed max_mtu ({max_search}).")
        sys.exit(1)

    for index, (family, is_ipv6) in enumerate((("IPv4", False), ("IPv6", True))):
        if index:
            print("-" * 20)

        mtu = bin_search_mtu(domain, min_search, max_search, is_ipv6=is_ipv6)
        # Truthiness on purpose: both None and 0 mean "nothing usable found".
        if mtu:
            print(f"Max {family} MTU: {mtu}")
        else:
            print(f"Could not determine {family} MTU.")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
