#!/usr/bin/env python3
"""
AI Background Remover - Image Processing Script
Uses rembg (U2Net-family models; the CLI defaults to u2netp) for high-quality background removal
"""

import sys
import os
import json
import argparse
from pathlib import Path
from io import BytesIO

# Force CPU-only ONNX to avoid CUDA library errors.
# These are set before the rembg import below so they are already in the
# process environment when onnxruntime gets loaded.
# NOTE(review): confirm the installed onnxruntime / rembg versions actually
# read ORT_DISABLE_CUDA and ONNXRUNTIME_PROVIDERS; if they do not, CPU-only
# execution has to be requested explicitly (e.g. via session providers).
os.environ["ORT_DISABLE_CUDA"] = "1"
os.environ["ONNXRUNTIME_PROVIDERS"] = "CPUExecutionProvider"

try:
    from rembg import remove, new_session
    from PIL import Image, ImageFilter
    import numpy as np
except ImportError as e:
    print(json.dumps({"error": f"Missing dependency: {str(e)}"}))
    sys.exit(1)

def _resolve_save_path(output_path: str, keep_exts: tuple, default_ext: str) -> str:
    """Return the path an image will actually be written to.

    Args:
        output_path: Path requested by the caller.
        keep_exts: Lower-case extensions (including the dot) accepted as-is.
        default_ext: Extension used when ``output_path`` has none of ``keep_exts``.

    Returns:
        ``output_path`` unchanged when its extension matches one of
        ``keep_exts`` (case-insensitively); otherwise ``output_path`` with its
        final extension replaced (or appended, if it has none) by ``default_ext``.
    """
    # os.path.splitext only inspects the last path component, so dots in
    # directory names (e.g. "renders.v2/out") are never mistaken for an
    # extension -- unlike a plain str.rsplit('.', 1).
    root, ext = os.path.splitext(output_path)
    if ext.lower() in keep_exts:
        return output_path
    return root + default_ext


def _parse_hex_color(value: str) -> tuple:
    """Parse ``RGB`` / ``RRGGBB`` hex (leading ``#`` optional) into an (r, g, b) tuple.

    Strings of any other length fall back to white ``(255, 255, 255)``.

    Raises:
        ValueError: if a 3- or 6-digit string contains non-hex characters.
    """
    digits = value.lstrip('#')
    if len(digits) == 3:
        # Shorthand notation: "f0a" means "ff00aa".
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6:
        return (255, 255, 255)
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def _save_png(img, output_path: str) -> str:
    """Save *img* (alpha preserved) as an optimized PNG and return the path written."""
    save_path = _resolve_save_path(output_path, ('.png',), '.png')
    img.save(save_path, "PNG", optimize=True)
    return save_path


def _save_flattened_jpeg(background, foreground, output_path: str) -> str:
    """Composite *foreground* onto *background*, save as JPEG, return the path written.

    The foreground's alpha band is the paste mask. *background* is modified in
    place and then converted to RGB, because JPEG cannot store alpha.
    """
    background.paste(foreground, mask=foreground.split()[3])
    save_path = _resolve_save_path(output_path, ('.jpg', '.jpeg'), '.jpg')
    background.convert("RGB").save(save_path, "JPEG", quality=95, optimize=True)
    return save_path


def remove_background_image(
    input_path: str,
    output_path: str,
    bg_type: str = "transparent",  # transparent, white, black, custom, custom_image
    bg_color: str = None,          # hex color like #FF0000
    bg_image_path: str = None,     # path to custom background image
    output_quality: str = "standard",  # standard or 4k
    model: str = "u2net"
) -> dict:
    """
    Remove background from an image and apply the requested background.

    Args:
        input_path: Path to input image.
        output_path: Requested output path. Its extension is adjusted to the
            format actually written: transparent output is PNG (a ``.png``
            extension is kept, otherwise replaced by ``.png``); flattened
            output is JPEG (``.jpg``/``.jpeg`` kept, otherwise ``.jpg``).
            Extension matching is case-insensitive.
        bg_type: 'transparent', 'white', 'black', 'custom' or 'custom_image'.
            'custom' without ``bg_color``, 'custom_image' without an existing
            ``bg_image_path``, and any unrecognised value all fall back to a
            transparent PNG.
        bg_color: Hex colour ``#RGB`` / ``#RRGGBB`` for bg_type='custom';
            other lengths fall back to white.
        bg_image_path: Background image for bg_type='custom_image'. It is
            resized to the output size (stretched; aspect ratio not kept).
        output_quality: 'standard' keeps the original size; '4k' upscales to
            fit within 3840x2160 (aspect ratio kept) and never downscales.
        model: rembg model name passed to ``new_session``.

    Returns:
        On success: dict with ``success=True``, ``output_path`` (path actually
        written), ``original_size`` / ``output_size`` as ``[width, height]``,
        ``file_size`` in bytes and the requested ``bg_type``.
        On any exception: dict with ``success=False``, ``error`` (the
        exception message) and ``input_path``.
    """
    try:
        # Decode inside a `with` block so the input file handle is closed
        # immediately instead of lingering until garbage collection.
        with Image.open(input_path) as src:
            input_img = src.convert("RGBA")
        original_size = input_img.size

        session = new_session(model)
        # Background removal; transparency is carried in the alpha band.
        output_img = remove(input_img, session=session)

        if output_quality == "4k":
            # Fit inside a 3840x2160 box preserving aspect ratio. Images that
            # already reach the box in either dimension (ratio <= 1) are left
            # at their original size.
            ratio = min(3840 / original_size[0], 2160 / original_size[1])
            if ratio > 1:
                upscaled = (int(original_size[0] * ratio), int(original_size[1] * ratio))
                output_img = output_img.resize(upscaled, Image.LANCZOS)

        final_size = output_img.size
        solid_colors = {"white": (255, 255, 255), "black": (0, 0, 0)}

        if bg_type in solid_colors:
            background = Image.new("RGBA", final_size, solid_colors[bg_type] + (255,))
            save_path = _save_flattened_jpeg(background, output_img, output_path)
        elif bg_type == "custom" and bg_color:
            background = Image.new("RGBA", final_size, _parse_hex_color(bg_color) + (255,))
            save_path = _save_flattened_jpeg(background, output_img, output_path)
        elif bg_type == "custom_image" and bg_image_path and os.path.exists(bg_image_path):
            with Image.open(bg_image_path) as bg_src:
                background = bg_src.convert("RGBA").resize(final_size, Image.LANCZOS)
            save_path = _save_flattened_jpeg(background, output_img, output_path)
        else:
            # 'transparent', plus the fallback for incomplete or unknown
            # background requests: keep the alpha channel and write a PNG.
            save_path = _save_png(output_img, output_path)

        return {
            "success": True,
            "output_path": save_path,
            "original_size": list(original_size),
            "output_size": list(final_size),
            "file_size": os.path.getsize(save_path),
            "bg_type": bg_type
        }

    except Exception as e:
        # CLI boundary: report every failure as a JSON-serialisable dict.
        return {
            "success": False,
            "error": str(e),
            "input_path": input_path
        }


def main():
    """Command-line entry point.

    Parses the CLI flags, runs ``remove_background_image``, and prints its
    result dict as a single JSON object on stdout. When the result reports
    failure the process exits with status 1. This matches the script's
    missing-dependency handling (JSON on stdout, exit code 1), so callers can
    rely on the exit status as well as the payload.
    """
    parser = argparse.ArgumentParser(description='AI Background Remover - Image')
    parser.add_argument('--input', required=True, help='Input image path')
    parser.add_argument('--output', required=True, help='Output image path')
    parser.add_argument('--bg-type', default='transparent',
                        choices=['transparent', 'white', 'black', 'custom', 'custom_image'],
                        help='Background type')
    parser.add_argument('--bg-color', default=None, help='Custom background color (hex)')
    parser.add_argument('--bg-image', default=None, help='Custom background image path')
    parser.add_argument('--quality', default='standard', choices=['standard', '4k'],
                        help='Output quality')
    parser.add_argument('--model', default='u2netp', help='rembg model')

    args = parser.parse_args()

    result = remove_background_image(
        input_path=args.input,
        output_path=args.output,
        bg_type=args.bg_type,
        bg_color=args.bg_color,
        bg_image_path=args.bg_image,
        output_quality=args.quality,
        # An explicitly empty --model still falls back to the CLI default.
        model=args.model or 'u2netp'
    )

    print(json.dumps(result))
    # Previously the process exited 0 even when processing failed.
    if not result.get("success"):
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
