#!/usr/bin/env python3

import argparse
import os
import re
import shlex
import shutil
import subprocess
import sys
import tempfile
import tarfile
import zipfile
from pathlib import Path

# Command spawned inside the extracted directory when --exec is not given.
DEFAULT_SHELL = "bash"
# Value used for the tmpfs "size=" mount option when --size is not given.
DEFAULT_TMPFS_SIZE = "8G"
# Accepted --size values: a positive integer without leading zeros, optionally
# followed by a single K, M or G unit letter in either case.
SIZE_PATTERN = re.compile(r"^[1-9][0-9]*[KMGkmg]?$")


def download_to(url: str, dest_dir: Path) -> Path:
    """Download *url* into *dest_dir* and return the path of the saved file.

    The file name is the last component of the URL's path, so a query string
    or fragment never becomes part of the name.  When the path has no usable
    last component the name ``download`` is used instead.

    The data is streamed into a ``.part`` file next to the destination and
    renamed only after the transfer completed, so a failed download never
    leaves a truncated archive under the final name.

    Args:
        url: HTTP(S) URL of the file to fetch.
        dest_dir: Existing directory the file is written into.

    Returns:
        Path of the downloaded file inside *dest_dir*.

    Raises:
        SystemExit: If ``requests`` is not installed, the request fails, the
            server answers with an error status, or the file cannot be
            written.
    """
    # Imported lazily so local archives work without these modules.
    from urllib.parse import urlsplit

    try:
        import requests
    except ImportError:
        sys.exit("Error: The 'requests' package is required to download archives.")

    local_filename = urlsplit(url).path.rstrip('/').rsplit('/', 1)[-1]
    if local_filename in ("", ".", ".."):
        local_filename = "download"
    dest_path = dest_dir / local_filename
    part_path = dest_dir / (local_filename + ".part")

    print(f"Downloading '{url}' to '{dest_path}'...")
    try:
        # Without a timeout a stalled server would block the script forever.
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(part_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(part_path, dest_path)
    except Exception as e:
        # Boundary handler: every failure here ends the program with a
        # message.  Removing the partial file is best effort only.
        try:
            part_path.unlink()
        except OSError:
            pass
        sys.exit(f"Error: Failed to download '{url}': {e}")

    return dest_path


def check_dependencies(use_tmpfs: bool, is_zip: bool, is_tar: bool):
    """Exit with an error unless every external tool this run needs is on PATH.

    Args:
        use_tmpfs: When true, ``sudo`` is required.
        is_zip: When true, ``unzip`` is required.
        is_tar: When true, ``tar`` is required.

    Raises:
        SystemExit: If at least one required tool cannot be found; the
            message lists the missing tools.
    """
    requirements = (
        (use_tmpfs, "sudo"),
        (is_zip, "unzip"),
        (is_tar, "tar"),
    )

    not_found = []
    for needed, tool in requirements:
        if needed and shutil.which(tool) is None:
            not_found.append(tool)

    if not not_found:
        return
    sys.exit(f"Error: Missing required system dependencies: {', '.join(not_found)}")


def is_dir_empty(path: Path) -> bool:
    """Return True when the directory *path* contains no entries at all."""
    for _entry in path.iterdir():
        # One entry is enough to know the answer.
        return False
    return True


def _shared_top_dir_strip_count(names) -> int:
    """Return 1 if all *names* sit under one top-level directory, else 0."""
    top_level = None
    top_is_dir = False
    for name in names:
        if not name:
            continue
        current_top, slash, _ = name.partition('/')
        if top_level is None:
            top_level = current_top
        elif current_top != top_level:
            return 0
        # A '/' after the first component proves that component is a
        # directory and not a lone file stored at the archive root.
        if slash:
            top_is_dir = True
    return 1 if top_level and top_is_dir else 0


def get_strip_count(archive_path: Path, is_tar: bool, is_zip: bool) -> int:
    """Return how many leading path components to strip when extracting.

    The result is 1 when every entry of the archive lives under one shared
    top-level directory, and 0 otherwise: several top-level entries, an empty
    listing, or a single file stored directly at the archive root.  Stripping
    a component from such a lone file would extract nothing.

    Args:
        archive_path: Archive to inspect.
        is_tar: Inspect the archive by listing it with the ``tar`` command.
        is_zip: Inspect the archive with ``zipfile`` (used when *is_tar* is
            false).

    Returns:
        1 or 0 as described above.

    Raises:
        SystemExit: If neither *is_tar* nor *is_zip* is true, or the tar
            listing cannot be read.
    """
    if is_tar:
        cmd = ["tar", "-tf", str(archive_path)]
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1) as proc:
            if proc.stdout is None:
                sys.exit("Error: Failed to list archive contents.")
            # Only the line terminator is removed so that names differing in
            # leading or trailing spaces are not treated as equal.
            names = (line.rstrip('\n') for line in proc.stdout)
            count = _shared_top_dir_strip_count(names)
            # The scan stops at the first entry outside the shared top level;
            # end a still-running listing instead of reading it to the end.
            if proc.poll() is None:
                proc.terminate()
            proc.wait()
            return count

    if is_zip:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            return _shared_top_dir_strip_count(zf.namelist())

    sys.exit("Error: Unsupported archive type.")


def extract_archive(archive_path: Path, dest_dir: Path, strip_components: int = 0):
    """Extract *archive_path* into the existing directory *dest_dir*.

    Tar archives are unpacked with the ``tar`` command and zip archives with
    ``unzip``.  The type is detected from the file content, with tar checked
    first.

    Args:
        archive_path: Archive to extract.
        dest_dir: Existing directory that receives the contents.
        strip_components: Number of leading path components to drop.  For
            tar it is passed on as ``--strip-components``.  For zip, which
            has no such option, a level is dropped only while it consists of
            exactly one real directory, so no content is ever discarded.

    Raises:
        SystemExit: If the file is neither a tar nor a zip archive.
        subprocess.CalledProcessError: If ``tar`` or ``unzip`` fails.
    """
    if tarfile.is_tarfile(archive_path):
        cmd = ["tar", "-xf", str(archive_path), "-C", str(dest_dir)]
        if strip_components > 0:
            cmd.append(f"--strip-components={strip_components}")
        subprocess.run(cmd, check=True)
        return

    if not zipfile.is_zipfile(archive_path):
        # Doing nothing here would let the caller carry on with an empty
        # directory as if the extraction had worked.
        sys.exit("Error: Unsupported archive type.")

    if strip_components <= 0:
        subprocess.run(["unzip", "-q", str(archive_path), "-d", str(dest_dir)], check=True)
        return

    # Stage inside dest_dir so the final moves stay on one filesystem and no
    # space is needed in the system temp directory.
    with tempfile.TemporaryDirectory(dir=str(dest_dir)) as staging:
        subprocess.run(["unzip", "-q", str(archive_path), "-d", staging], check=True)
        root = Path(staging)
        for _ in range(strip_components):
            children = list(root.iterdir())
            sole_dir = (
                len(children) == 1
                and children[0].is_dir()
                and not children[0].is_symlink()
            )
            if not sole_dir:
                break
            root = children[0]
        for item in root.iterdir():
            shutil.move(str(item), str(dest_dir))


def main():
    """Extract an archive into a directory and run a command inside it.

    Steps, in order:

    1. Parse and validate the command line (size format, command string).
    2. Download the archive when it is given as an http(s) URL, otherwise
       resolve it as a local path.
    3. Detect the archive type, check the required external tools and
       work out whether a shared top-level directory can be stripped.
    4. Prepare the target directory, which must be new or empty.
    5. Optionally mount a tmpfs on the target, extract, and run the command
       with the target as working directory.
    6. Unless disabled, unmount the tmpfs and remove the target directory if
       this run created it.

    Raises:
        SystemExit: On invalid arguments, a missing or unsupported archive,
            an unusable target directory, a failing mount or extraction, or
            a command that cannot be started.
    """
    parser = argparse.ArgumentParser(description="Extract an archive to a directory and spawn a shell.")
    parser.add_argument("archive", type=str, help="Path to the tarball or zip file")
    parser.add_argument("--exec", "-e", dest="cmd", default=DEFAULT_SHELL,
                        help=f"Command to spawn (default: '{DEFAULT_SHELL}')")
    parser.add_argument("--target", "-t", type=Path, help="Target directory for extraction.")
    parser.add_argument("--disable-tmpfs", "-d", action="store_true", help="Disable tmpfs mounting, extract directly to dir")
    parser.add_argument("--size", "-s", default=DEFAULT_TMPFS_SIZE,
                        help=f"Size of the tmpfs if used (default: {DEFAULT_TMPFS_SIZE})")
    parser.add_argument("--no-cleanup", "-n", action="store_true", help="Disable cleanup (unmount and remove dir)")

    args = parser.parse_args()

    # Validate the cheap arguments first so a typo is reported before
    # anything is downloaded, created or mounted.  fullmatch also rejects a
    # value ending in a newline, which match() with a trailing '$' accepts.
    if not SIZE_PATTERN.fullmatch(args.size):
        sys.exit(f"Error: Invalid size format '{args.size}'. Expected format like '4G', '500M'.")

    try:
        parsed_cmd = shlex.split(args.cmd)
    except ValueError as e:
        sys.exit(f"Error: Cannot parse command '{args.cmd}': {e}")
    if not parsed_cmd:
        sys.exit("Error: The command to spawn must not be empty.")

    archive = args.archive.strip()
    if archive.startswith(('https://', 'http://')):
        archive = download_to(archive, Path.cwd())
    else:
        archive = Path(args.archive).resolve()

    if not archive.is_file():
        sys.exit(f"Error: Archive '{archive}' does not exist or is not a file.")

    is_tar = tarfile.is_tarfile(archive)
    is_zip = zipfile.is_zipfile(archive)
    # Reject unknown formats before the target directory is created, so a
    # bad input does not leave an empty directory behind.
    if not (is_tar or is_zip):
        sys.exit("Error: Unsupported archive type.")

    use_tmpfs = not args.disable_tmpfs
    do_cleanup = not args.no_cleanup

    check_dependencies(use_tmpfs, is_zip, is_tar)

    strip_components = get_strip_count(archive, is_tar, is_zip)

    if args.target:
        target_dir = args.target.resolve()
    else:
        # Path.stem removes only the last suffix: 'foo.tar.gz' -> 'foo.tar'.
        target_dir = archive.parent / archive.stem

    dir_existed_before = target_dir.exists()
    if dir_existed_before:
        if not target_dir.is_dir():
            sys.exit(f"Error: Target '{target_dir}' exists and is not a directory.")
        if not is_dir_empty(target_dir):
            sys.exit(f"Error: Target '{target_dir}' exists and is not empty.")
    else:
        target_dir.mkdir(parents=True, exist_ok=True)

    # Set only after a successful mount, so cleanup never tries to unmount
    # something this run did not mount.
    mounted = False
    try:
        if use_tmpfs:
            print(f"Mounting tmpfs at '{target_dir}' with size {args.size}...")
            uid, gid = os.getuid(), os.getgid()
            mount_opts = f"size={args.size},uid={uid},gid={gid},mode=0700"
            subprocess.run(["sudo", "mount", "-t", "tmpfs", "-o", mount_opts, "tmpfs", str(target_dir)], check=True)
            mounted = True

        print(f"Extracting '{archive}' to '{target_dir}'...")
        extract_archive(archive, target_dir, strip_components)

        print(f"Spawning '{args.cmd}' in {target_dir}...")
        try:
            # The command's own exit status is deliberately not checked.
            subprocess.run(parsed_cmd, cwd=str(target_dir))
        except OSError as e:
            sys.exit(f"Error: Failed to run '{args.cmd}': {e}")

    except subprocess.CalledProcessError as e:
        failed = ' '.join(str(part) for part in e.cmd)
        sys.exit(f"Error: Command failed with exit status {e.returncode}: {failed}")

    finally:
        if do_cleanup:
            print("Cleaning up...")
            safe_to_remove = True
            if mounted:
                # umount's own stderr is shown so the reason is visible.
                umount = subprocess.run(["sudo", "umount", str(target_dir)])
                if umount.returncode != 0:
                    # Deleting a still-mounted tree would wipe the tmpfs
                    # contents and then fail on the mount point itself.
                    safe_to_remove = False
                    print(f"Warning: Could not unmount '{target_dir}'; leaving it in place.",
                          file=sys.stderr)
            if safe_to_remove and not dir_existed_before:
                try:
                    shutil.rmtree(target_dir)
                except PermissionError:
                    if shutil.which("sudo"):
                        subprocess.run(["sudo", "rm", "-rf", str(target_dir)], stderr=subprocess.DEVNULL)
        else:
            print(f"Cleanup disabled. Extracted contents are left at: {target_dir}")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
