#!/usr/bin/env python3
"""
AI Background Remover - Video Processing Script
Processes videos frame by frame using a rembg model.
"""

import sys
import os
import json
import argparse
import tempfile
import shutil
from pathlib import Path

# Force CPU-only ONNX to avoid CUDA library errors.
# These are set before rembg (and, through it, onnxruntime) is imported below,
# presumably so they are already in place when onnxruntime initialises.
# NOTE(review): confirm the installed onnxruntime/rembg actually read
# ORT_DISABLE_CUDA / ONNXRUNTIME_PROVIDERS; if they do not, the CPU provider
# would have to be requested explicitly when creating the rembg session.
os.environ["ORT_DISABLE_CUDA"] = "1"
os.environ["ONNXRUNTIME_PROVIDERS"] = "CPUExecutionProvider"

# Third-party dependencies. If any is missing, report it as a JSON object on
# stdout (the same channel main() uses for its result JSON) and exit with
# status 1 so the caller sees both a machine-readable error and a failure code.
try:
    from rembg import remove, new_session
    from PIL import Image
    import cv2
    import numpy as np
except ImportError as e:
    print(json.dumps({"error": f"Missing dependency: {str(e)}"}))
    sys.exit(1)


def hex_to_rgb(hex_color: str):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``RRGGBB`` or the shorthand ``RGB`` form, with optional leading
    ``#`` characters and surrounding whitespace (e.g. ``"#ff8800"``,
    ``" f80 "``). Anything else (wrong length, non-hex characters, empty
    string or ``None``) falls back to white ``(255, 255, 255)``.

    Args:
        hex_color: The color string to parse.

    Returns:
        A 3-tuple of ints, each in ``range(256)``.
    """
    white = (255, 255, 255)
    digits = (hex_color or "").strip().lstrip('#')
    # int(..., 16) alone is too lenient: it accepts a sign ("-10000" would
    # give (-1, 0, 0)), underscores and non-ASCII digits. Validate explicitly.
    if not digits or any(c not in "0123456789abcdefABCDEF" for c in digits):
        return white
    if len(digits) == 3:
        # Shorthand: each digit is doubled ("f80" -> "ff8800").
        digits = "".join(c * 2 for c in digits)
    if len(digits) == 6:
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return white


def process_frame(frame_bgr, session, bg_type, bg_color_rgb, bg_image_resized):
    """Remove the background from one video frame and composite it onto a new one.

    Args:
        frame_bgr: Frame as returned by OpenCV (BGR channel order).
        session: rembg session created with ``new_session``.
        bg_type: One of "transparent", "white", "black", "custom",
            "custom_image". Video output has no alpha channel, so
            "transparent", and any unrecognised type, composite onto black.
        bg_color_rgb: ``(r, g, b)`` tuple used when ``bg_type == "custom"``;
            if falsy, black is used instead.
        bg_image_resized: PIL image used when ``bg_type == "custom_image"``;
            if ``None``, black is used instead. It is resized to the frame
            size when it does not already match, so the output always has
            the frame's dimensions.

    Returns:
        The composited frame as a BGR numpy array with the same width and
        height as ``frame_bgr``.
    """
    h, w = frame_bgr.shape[:2]

    # Convert BGR (OpenCV) to RGBA (PIL) and cut out the foreground.
    rgba_frame = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)).convert("RGBA")
    foreground = remove(rgba_frame, session=session)

    # Pick the background. Every branch yields a fresh image, so pasting into
    # it never mutates the caller's shared background image.
    if bg_type == "custom_image" and bg_image_resized is not None:
        background = bg_image_resized.convert("RGBA")  # convert() returns a copy
        if background.size != (w, h):
            background = background.resize((w, h), Image.LANCZOS)
    else:
        if bg_type == "white":
            fill = (255, 255, 255, 255)
        elif bg_type == "custom" and bg_color_rgb:
            r, g, b = bg_color_rgb
            fill = (r, g, b, 255)
        else:
            # "black", "transparent" (video can't be transparent) and fallbacks.
            fill = (0, 0, 0, 255)
        background = Image.new("RGBA", (w, h), fill)

    # Blend the foreground over the background using its alpha as the mask.
    background.paste(foreground, mask=foreground.split()[3])
    return cv2.cvtColor(np.array(background.convert("RGB")), cv2.COLOR_RGB2BGR)


def remove_background_video(
    input_path: str,
    output_path: str,
    bg_type: str = "black",
    bg_color: str = None,
    bg_image_path: str = None,
    model: str = "u2netp",
    progress_callback=None
) -> dict:
    """
    Remove background from a video.

    Frames are decoded with OpenCV and downscaled so the longer side is at
    most 720 px. Each frame goes through ``process_frame``, is scaled back to
    the source size, and is written with the ``mp4v`` codec. Inputs larger
    than 50 MB or longer than 20 s are rejected.

    Args:
        input_path: Path to input video
        output_path: Path to save output video
        bg_type: Background type ("transparent", "white", "black", "custom",
            "custom_image"); video has no alpha, so "transparent" renders black
        bg_color: Custom hex color, used when bg_type == "custom"
        bg_image_path: Path to background image, used when
            bg_type == "custom_image" (black is used if the file does not exist)
        model: rembg model name passed to ``new_session``
        progress_callback: Optional callable. It is called with the same
            progress dict that is printed to stderr, every 10 frames and on
            the last frame.

    Returns:
        dict with ``success``. On success it also holds ``output_path``,
        ``original_size``, ``fps``, ``total_frames`` (frames written),
        ``file_size`` and ``bg_type``. On failure it holds ``error``, plus
        ``input_path`` for unexpected exceptions.
    """
    import math

    max_input_bytes = 50 * 1024 * 1024
    max_duration_s = 20
    target_max_dim = 720

    cap = None
    out = None
    try:
        # Open video
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            return {"success": False, "error": "Cannot open video file"}

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Some containers report 0, negative or NaN fps; NaN would slip past
        # a plain "<= 0" test, so check finiteness as well.
        if not math.isfinite(fps) or fps <= 0:
            fps = 24

        # Quick file size check (server should also enforce this)
        try:
            file_bytes = os.path.getsize(input_path)
        except OSError:
            file_bytes = None  # size unknown; the duration limit still applies
        if file_bytes is not None and file_bytes > max_input_bytes:
            return {"success": False, "error": "Input video exceeds 50 MB limit"}

        # Duration check. When the container does not report a frame count
        # (<= 0), enforce the same limit while decoding instead.
        duration = total_frames / fps if total_frames > 0 else 0
        if duration > max_duration_s:
            return {"success": False, "error": f"Video duration {round(duration, 1)}s exceeds {max_duration_s}s limit"}
        max_frames = None if total_frames > 0 else int(fps * max_duration_s)

        # Set up output video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            return {"success": False, "error": "Failed to open video writer for output (codec/format may be unsupported)"}

        # Initialize rembg session
        session = new_session(model)

        # Parse background color
        bg_color_rgb = hex_to_rgb(bg_color) if bg_type == "custom" and bg_color else None

        # Downscale frames for faster processing if very large
        longest = max(width, height)
        scale = min(1.0, target_max_dim / longest) if longest > 0 else 1.0
        small_w = max(1, int(width * scale))
        small_h = max(1, int(height * scale))

        # Load the background image straight at processing size (one resample,
        # not two), closing the file handle once it has been read.
        bg_image_small = None
        if bg_type == "custom_image" and bg_image_path and os.path.exists(bg_image_path):
            with Image.open(bg_image_path) as img:
                bg_image_small = img.convert("RGBA").resize((small_w, small_h), Image.LANCZOS)

        # Process frames
        frame_count = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if max_frames is not None and frame_count >= max_frames:
                return {"success": False, "error": f"Video duration exceeds {max_duration_s}s limit"}

            if scale < 1.0:
                frame = cv2.resize(frame, (small_w, small_h), interpolation=cv2.INTER_AREA)

            processed = process_frame(frame, session, bg_type, bg_color_rgb, bg_image_small)

            # VideoWriter expects every frame at the size it was opened with;
            # mismatched frames may be silently dropped. This also covers
            # videos whose decoded size differs from the reported size.
            if width > 0 and height > 0 and (processed.shape[1] != width or processed.shape[0] != height):
                processed = cv2.resize(processed, (width, height), interpolation=cv2.INTER_LINEAR)

            out.write(processed)
            frame_count += 1

            # Write progress to stderr for monitoring
            if frame_count % 10 == 0 or frame_count == total_frames:
                # Clamp: the reported frame count can be lower than the real one.
                progress = min(100.0, frame_count / total_frames * 100) if total_frames > 0 else 0
                progress_info = {
                    "type": "progress",
                    "frame": frame_count,
                    "total": total_frames,
                    "percent": round(progress, 1)
                }
                print(json.dumps(progress_info), file=sys.stderr)
                if progress_callback is not None:
                    progress_callback(progress_info)

        if frame_count == 0:
            return {"success": False, "error": "No frames could be decoded from input video"}

        # Finalize the output container before measuring its size.
        out.release()
        out = None
        cap.release()
        cap = None

        file_size = os.path.getsize(output_path) if os.path.exists(output_path) else 0

        return {
            "success": True,
            "output_path": output_path,
            "original_size": [width, height],
            "fps": fps,
            "total_frames": frame_count,
            "file_size": file_size,
            "bg_type": bg_type
        }

    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "input_path": input_path
        }
    finally:
        # Release OpenCV handles on every exit path, including early returns
        # and exceptions, so files are not left open or half-finalized.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()


def main():
    """Command-line entry point.

    Parses the CLI flags, runs ``remove_background_video`` with them, and
    prints the resulting dict as a single JSON object on stdout.
    """
    cli = argparse.ArgumentParser(description='AI Background Remover - Video')

    # Mandatory path arguments, registered in the same order as before.
    for flag, help_text in (
        ('--input', 'Input video path'),
        ('--output', 'Output video path'),
    ):
        cli.add_argument(flag, required=True, help=help_text)

    cli.add_argument(
        '--bg-type',
        default='black',
        choices=['transparent', 'white', 'black', 'custom', 'custom_image'],
        help='Background type',
    )
    cli.add_argument('--bg-color', default=None, help='Custom background color (hex)')
    cli.add_argument('--bg-image', default=None, help='Custom background image path')
    cli.add_argument('--model', default='u2netp', help='rembg model')

    opts = cli.parse_args()

    # An empty --model value falls back to the default model.
    chosen_model = opts.model or 'u2netp'
    outcome = remove_background_video(
        input_path=opts.input,
        output_path=opts.output,
        bg_type=opts.bg_type,
        bg_color=opts.bg_color,
        bg_image_path=opts.bg_image,
        model=chosen_model,
    )

    print(json.dumps(outcome))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
