#!/usr/bin/env python3
import json
import os
import queue
import shutil
import subprocess
import threading
from dataclasses import dataclass
from pathlib import Path
import tkinter as tk
from tkinter import filedialog, messagebox, ttk

# Lower-cased file extensions that the folder scan treats as video files.
VIDEO_EXTENSIONS = set(".mkv .mp4 .mov .m4v .avi .ts .m2ts .webm".split())

# Filter steps shared by both HDR -> 8-bit SDR chains, run after setparams has
# tagged the frames as BT.2020: linearize (zscale npl=100 nominal peak, cd/m2),
# switch to planar float RGB, tone-map with the Mobius curve without highlight
# desaturation, encode BT.709 limited ("tv") range, then pack as 8-bit 4:2:0.
_HDR_TO_SDR8_STEPS = (
    "zscale=t=linear:npl=100",
    "format=gbrpf32le",
    "tonemap=mobius:desat=0",
    "zscale=p=bt709:t=bt709:m=bt709:r=tv",
    "format=yuv420p",
)

# PQ (SMPTE ST 2084) source -> 8-bit BT.709 SDR.
FILTER_HDR_PQ_TO_SDR8 = ",".join(
    ("setparams=color_primaries=bt2020:color_trc=smpte2084:colorspace=bt2020nc",)
    + _HDR_TO_SDR8_STEPS
)

# HLG (ARIB STD-B67) source -> 8-bit BT.709 SDR.
FILTER_HLG_TO_SDR8 = ",".join(
    ("setparams=color_primaries=bt2020:color_trc=arib-std-b67:colorspace=bt2020nc",)
    + _HDR_TO_SDR8_STEPS
)

# 10-bit SDR source: only the pixel format is reduced to 8-bit 4:2:0.
FILTER_SDR10_TO_SDR8 = "format=yuv420p"


@dataclass()
class FileDecision:
    """One scanned video: its classification, the chosen action and its UI status."""

    path: Path  # video file found by the folder scan
    detected_type: str  # classification label, e.g. "hdr_pq", "sdr_10bit", "sdr_8bit", "unknown"
    reason: str  # short hint explaining why that label was chosen
    action: str  # "skip", "reset_only" or "convert_to_sdr8"
    status: str = "En attente"  # status text shown in the table; starts as pending


class FFmpegBatchApp:
    def __init__(self, root: tk.Tk):
        self.root = root
        self.root.title("Batch vidéo - HDR/SDR + metadata")
        self.root.geometry("1280x760")

        self.folder_var = tk.StringVar()
        self.codec_var = tk.StringVar(value="libx265")
        self.crf_var = tk.StringVar(value="18")
        self.preset_var = tk.StringVar(value="slow")
        self.skip_sdr8_var = tk.BooleanVar(value=True)
        self.force_convert_var = tk.BooleanVar(value=False)
        self.reset_only_var = tk.BooleanVar(value=False)
        self.keep_chapters_var = tk.BooleanVar(value=False)
        self.keep_backup_var = tk.BooleanVar(value=False)
        self.delete_source_var = tk.BooleanVar(value=True)
        self.copy_attachments_var = tk.BooleanVar(value=True)
        self.copy_audio_var = tk.BooleanVar(value=True)
        self.copy_subs_var = tk.BooleanVar(value=True)

        self.items: list[FileDecision] = []
        self.worker_thread = None
        self.stop_requested = False
        self.log_queue: queue.Queue = queue.Queue()

        self._build_ui()
        self._poll_log_queue()

    def _build_ui(self):
        top = ttk.Frame(self.root, padding=10)
        top.pack(fill="x")

        ttk.Label(top, text="Dossier").grid(row=0, column=0, sticky="w")
        ttk.Entry(top, textvariable=self.folder_var).grid(row=0, column=1, sticky="ew", padx=8)
        ttk.Button(top, text="Choisir", command=self.choose_folder).grid(row=0, column=2, padx=4)
        ttk.Button(top, text="Analyser", command=self.scan_files).grid(row=0, column=3, padx=4)
        ttk.Button(top, text="Lancer", command=self.start_processing).grid(row=0, column=4, padx=4)
        ttk.Button(top, text="Stop", command=self.request_stop).grid(row=0, column=5, padx=4)
        top.columnconfigure(1, weight=1)

        opts = ttk.LabelFrame(self.root, text="Options", padding=10)
        opts.pack(fill="x", padx=10, pady=(0, 10))

        ttk.Checkbutton(opts, text="Skip SDR 8-bit", variable=self.skip_sdr8_var).grid(row=0, column=0, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Force convertir quand même", variable=self.force_convert_var, command=self.recompute_actions).grid(row=0, column=1, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Reset metadata uniquement", variable=self.reset_only_var, command=self.recompute_actions).grid(row=0, column=2, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Garder les chapitres", variable=self.keep_chapters_var).grid(row=0, column=3, sticky="w", padx=6, pady=4)

        ttk.Checkbutton(opts, text="Garder backup", variable=self.keep_backup_var).grid(row=1, column=0, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Supprimer la source si OK", variable=self.delete_source_var).grid(row=1, column=1, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Copier les attachments", variable=self.copy_attachments_var).grid(row=1, column=2, sticky="w", padx=6, pady=4)
        ttk.Checkbutton(opts, text="Copier audio / sous-titres", variable=self.copy_audio_var).grid(row=1, column=3, sticky="w", padx=6, pady=4)

        ttk.Label(opts, text="Codec vidéo").grid(row=2, column=0, sticky="w", padx=6)
        ttk.Combobox(opts, textvariable=self.codec_var, values=["libx265", "libx264", "libsvtav1"], state="readonly", width=12).grid(row=2, column=1, sticky="w", padx=6)
        ttk.Label(opts, text="CRF").grid(row=2, column=2, sticky="e", padx=6)
        ttk.Entry(opts, textvariable=self.crf_var, width=8).grid(row=2, column=3, sticky="w", padx=6)
        ttk.Label(opts, text="Preset").grid(row=2, column=4, sticky="e", padx=6)
        ttk.Combobox(opts, textvariable=self.preset_var, values=["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"], state="readonly", width=12).grid(row=2, column=5, sticky="w", padx=6)

        middle = ttk.Panedwindow(self.root, orient="vertical")
        middle.pack(fill="both", expand=True, padx=10, pady=(0, 10))

        table_frame = ttk.Frame(middle)
        middle.add(table_frame, weight=3)

        cols = ("path", "detected", "action", "status", "reason")
        self.tree = ttk.Treeview(table_frame, columns=cols, show="headings")
        self.tree.heading("path", text="Fichier")
        self.tree.heading("detected", text="Type détecté")
        self.tree.heading("action", text="Action choisie")
        self.tree.heading("status", text="Statut")
        self.tree.heading("reason", text="Raison")
        self.tree.column("path", width=560)
        self.tree.column("detected", width=140)
        self.tree.column("action", width=170)
        self.tree.column("status", width=120)
        self.tree.column("reason", width=220)
        yscroll = ttk.Scrollbar(table_frame, orient="vertical", command=self.tree.yview)
        self.tree.configure(yscrollcommand=yscroll.set)
        self.tree.pack(side="left", fill="both", expand=True)
        yscroll.pack(side="right", fill="y")

        log_frame = ttk.LabelFrame(middle, text="Journal", padding=8)
        middle.add(log_frame, weight=1)
        self.log = tk.Text(log_frame, height=12, wrap="word")
        self.log.pack(fill="both", expand=True)

        bottom = ttk.Frame(self.root, padding=10)
        bottom.pack(fill="x")
        self.progress = ttk.Progressbar(bottom, mode="determinate")
        self.progress.pack(fill="x", side="left", expand=True)
        self.status_label = ttk.Label(bottom, text="Prêt")
        self.status_label.pack(side="left", padx=10)

    def choose_folder(self):
        folder = filedialog.askdirectory(title="Choisis le dossier")
        if folder:
            self.folder_var.set(folder)

    def log_msg(self, msg: str):
        self.log_queue.put(("log", msg))

    def set_status(self, msg: str):
        self.log_queue.put(("status", msg))

    def _poll_log_queue(self):
        try:
            while True:
                kind, value = self.log_queue.get_nowait()
                if kind == "log":
                    self.log.insert("end", value + "\n")
                    self.log.see("end")
                elif kind == "status":
                    self.status_label.config(text=value)
                elif kind == "refresh":
                    self.refresh_tree()
                elif kind == "progress":
                    cur, total = value
                    self.progress["maximum"] = max(total, 1)
                    self.progress["value"] = cur
        except queue.Empty:
            pass
        self.root.after(120, self._poll_log_queue)

    def scan_files(self):
        folder = self.folder_var.get().strip()
        if not folder:
            messagebox.showwarning("Dossier", "Choisis un dossier d'abord.")
            return

        root_path = Path(folder)
        if not root_path.exists():
            messagebox.showerror("Erreur", "Le dossier n'existe pas.")
            return

        self.items.clear()
        self.tree.delete(*self.tree.get_children())
        self.log.delete("1.0", "end")
        self.log_msg(f"Analyse de {root_path}")

        for path in root_path.rglob("*"):
            if path.is_file() and path.suffix.lower() in VIDEO_EXTENSIONS:
                detected_type, reason = self.classify_video(path)
                action = self.decide_action(detected_type)
                self.items.append(FileDecision(path=path, detected_type=detected_type, reason=reason, action=action))

        self.refresh_tree()
        self.set_status(f"{len(self.items)} fichier(s) détecté(s)")
        self.log_msg(f"Analyse terminée: {len(self.items)} fichier(s)")

    def recompute_actions(self):
        for item in self.items:
            item.action = self.decide_action(item.detected_type)
        self.refresh_tree()

    def refresh_tree(self):
        self.tree.delete(*self.tree.get_children())
        for item in self.items:
            self.tree.insert("", "end", values=(str(item.path), item.detected_type, item.action, item.status, item.reason))

    def request_stop(self):
        self.stop_requested = True
        self.set_status("Arrêt demandé...")
        self.log_msg("Arrêt demandé par l'utilisateur.")

    def start_processing(self):
        if not self.items:
            messagebox.showwarning("Aucun fichier", "Analyse d'abord le dossier.")
            return
        if self.worker_thread and self.worker_thread.is_alive():
            messagebox.showinfo("Traitement", "Un traitement est déjà en cours.")
            return

        self.stop_requested = False
        self.worker_thread = threading.Thread(target=self._worker_process, daemon=True)
        self.worker_thread.start()

    def _worker_process(self):
        total = len(self.items)
        done = 0
        for item in self.items:
            if self.stop_requested:
                break
            try:
                item.status = "En cours"
                self.log_queue.put(("refresh", None))
                self.set_status(f"Traitement: {item.path.name}")
                self.process_one(item)
            except Exception as e:
                item.status = f"Erreur"
                self.log_msg(f"ERREUR {item.path}: {e}")
            done += 1
            self.log_queue.put(("progress", (done, total)))
            self.log_queue.put(("refresh", None))

        self.set_status("Terminé" if not self.stop_requested else "Arrêté")
        self.log_msg("Traitement terminé." if not self.stop_requested else "Traitement arrêté.")

    def process_one(self, item: FileDecision):
        if item.action == "skip":
            item.status = "Ignoré"
            self.log_msg(f"SKIP {item.path}")
            return

        src = item.path
        tmp = src.with_name(src.stem + ".processed" + src.suffix)
        backup = src.with_name(src.name + ".bak")
        cmd = self.build_ffmpeg_command(src, tmp, item.action, item.detected_type)

        self.log_msg(" ".join(cmd))
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            if tmp.exists():
                tmp.unlink(missing_ok=True)
            item.status = "Erreur"
            self.log_msg(result.stderr.strip() or f"ffmpeg a échoué pour {src}")
            return

        if self.keep_backup_var.get():
            if backup.exists():
                backup.unlink()
            src.rename(backup)
        elif self.delete_source_var.get():
            src.unlink(missing_ok=True)

        tmp.rename(src)
        item.status = "OK"
        self.log_msg(f"OK {src}")

    def ffprobe_json(self, file_path: Path) -> dict:
        cmd = [
            "ffprobe", "-v", "error", "-print_format", "json",
            "-show_format", "-show_streams", "-show_chapters", str(file_path)
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr.strip() or f"ffprobe a échoué: {file_path}")
        return json.loads(result.stdout)

    def classify_video(self, file_path: Path) -> tuple[str, str]:
        data = self.ffprobe_json(file_path)
        v = next((s for s in data.get("streams", []) if s.get("codec_type") == "video"), None)
        if not v:
            return "unknown", "no_video_stream"

        transfer = (v.get("color_transfer") or "").lower()
        primaries = (v.get("color_primaries") or "").lower()
        matrix = (v.get("color_space") or "").lower()
        pix_fmt = (v.get("pix_fmt") or "").lower()
        bit_depth = self.parse_bit_depth(v)
        has_dovi = self.detect_dovi(v)

        if has_dovi:
            return "hdr_dovi", "dolby_vision_side_data"
        if transfer == "smpte2084":
            return "hdr_pq", "transfer=smpte2084"
        if transfer == "arib-std-b67":
            return "hdr_hlg", "transfer=arib-std-b67"
        if (primaries == "bt2020" or "bt2020" in matrix) and (bit_depth or 0) >= 10:
            return "hdr_assumed", "bt2020_and_10bit"
        if (bit_depth or 0) >= 10:
            return "sdr_10bit", f"bit_depth={bit_depth}"
        if bit_depth == 8 or ("yuv420p" in pix_fmt and "10" not in pix_fmt and "12" not in pix_fmt):
            return "sdr_8bit", f"pix_fmt={pix_fmt or 'unknown'}"
        return "unknown", "could_not_classify"

    @staticmethod
    def parse_bit_depth(stream: dict) -> int | None:
        bits = stream.get("bits_per_raw_sample")
        if bits:
            try:
                return int(bits)
            except ValueError:
                pass
        pix_fmt = (stream.get("pix_fmt") or "").lower()
        if "12" in pix_fmt:
            return 12
        if "10" in pix_fmt or "p10" in pix_fmt:
            return 10
        if pix_fmt:
            return 8
        return None

    @staticmethod
    def detect_dovi(stream: dict) -> bool:
        for side in stream.get("side_data_list", []) or []:
            side_type = (side.get("side_data_type") or "").lower()
            if "dovi" in side_type or "dolby vision" in side_type:
                return True
        return False

    def decide_action(self, detected_type: str) -> str:
        if self.reset_only_var.get():
            return "reset_only"
        if self.force_convert_var.get():
            return "convert_to_sdr8"
        if detected_type in {"hdr_pq", "hdr_hlg", "hdr_dovi", "hdr_assumed", "sdr_10bit"}:
            return "convert_to_sdr8"
        if detected_type == "sdr_8bit":
            return "skip" if self.skip_sdr8_var.get() else "reset_only"
        return "reset_only"

    def get_stream_preservation_info(self, file_path: Path):
        data = self.ffprobe_json(file_path)
        saved = []
        for out_idx, stream in enumerate(data.get("streams", [])):
            tags = stream.get("tags", {}) or {}
            disp = stream.get("disposition", {}) or {}
            saved.append({
                "out_idx": out_idx,
                "language": tags.get("language"),
                "title": tags.get("title"),
                "default": int(disp.get("default", 0)),
                "forced": int(disp.get("forced", 0)),
            })
        return saved

    def build_metadata_args(self, saved_streams):
        args = []
        for s in saved_streams:
            idx = s["out_idx"]

            flags = []
            if s["default"]:
                flags.append("default")
            if s["forced"]:
                flags.append("forced")

            disposition_value = "+".join(flags) if flags else "0"
            args += [f"-disposition:{idx}", disposition_value]

            if s["language"]:
                args += [f"-metadata:s:{idx}", f"language={s['language']}"]
            if s["title"]:
                args += [f"-metadata:s:{idx}", f"title={s['title']}"]
        return args

    def build_ffmpeg_command(self, src: Path, dst: Path, action: str, detected_type: str):
        saved = self.get_stream_preservation_info(src)
        cmd = [
            "ffmpeg", "-hide_banner", "-loglevel", "warning", "-stats", "-y",
            "-i", str(src), "-map", "0", "-map_metadata", "-1"
        ]

        cmd += ["-map_chapters", "0" if self.keep_chapters_var.get() else "-1"]

        if action == "convert_to_sdr8":
            vf = self.choose_filter(detected_type)
            cmd += [
                "-c:v", self.codec_var.get(),
                "-preset", self.preset_var.get(),
                "-crf", self.crf_var.get(),
                "-pix_fmt", "yuv420p",
                "-vf", vf,
                "-color_primaries", "bt709",
                "-color_trc", "bt709",
                "-colorspace", "bt709",
                "-color_range", "tv",
            ]
        elif action == "reset_only":
            cmd += ["-c:v", "copy"]
        else:
            raise ValueError(f"Action inconnue: {action}")

        if self.copy_audio_var.get():
            cmd += ["-c:a", "copy"]
        else:
            cmd += ["-c:a", "aac", "-b:a", "192k"]

        if self.copy_subs_var.get():
            cmd += ["-c:s", "copy"]
        else:
            cmd += ["-sn"]

        if self.copy_attachments_var.get():
            cmd += ["-c:t", "copy"]
        else:
            cmd += ["-map", "-0:t?"]

        cmd += self.build_metadata_args(saved)
        cmd.append(str(dst))
        return cmd

    @staticmethod
    def choose_filter(detected_type: str) -> str:
        if detected_type == "hdr_hlg":
            return FILTER_HLG_TO_SDR8
        if detected_type in {"hdr_pq", "hdr_dovi", "hdr_assumed"}:
            return FILTER_HDR_PQ_TO_SDR8
        return FILTER_SDR10_TO_SDR8


def main():
    """Create the Tk root window, apply the "clam" ttk theme when available, and run the app."""
    root = tk.Tk()
    try:
        style = ttk.Style(root)
        if "clam" in style.theme_names():
            style.theme_use("clam")
    except tk.TclError:
        # Theming is purely cosmetic: keep the platform default theme if Tk rejects it.
        pass
    # Keep a reference to the application object for the lifetime of the main loop.
    app = FFmpegBatchApp(root)  # noqa: F841
    root.mainloop()


# Launch the GUI only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
