# Software for the Autonomous Robotic Observation and Behavioral Analysis system
#
# Agent high-level control
#
# Copyright 2025 Tomas Roucek 
#
# Commercial use of the software requires written consent of the copyright holders. 
#
# For Open Research and Educational use, the following applies:
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#!/usr/bin/env python

import numpy as np
import cv2
from geometry_msgs.msg import Point


class InteractionUtils:
    """Geometry, path-planning and state-check helpers for queen interaction.

    Image-space quantities are expressed in pixels. The configuration values
    ``queen_head_distance``, ``path_search_distance`` and ``path_resolution``
    are multiplied by the image width before they are used.
    """

    # Diameter of the obstacle-inflation kernel, as a fraction of the smaller
    # image side.
    _OBSTACLE_INFLATION_RATIO = 0.20
    # Pixel stride of the candidate grid scanned around the queen.
    _SEARCH_STEP_PX = 10
    # Lower bound on the number of segments of a planned approach path.
    _MIN_PATH_STEPS = 10
    # The Z tolerance is this multiple of the XY tolerance.
    # NOTE(review): presumably Z is reported in a coarser unit - confirm.
    _Z_TOLERANCE_SCALE = 100.0

    def __init__(self, config):
        """Initialize the InteractionUtils class with configuration parameters.

        Args:
            config: Dictionary containing configuration parameters. Every key
                read below must be present; a missing key raises KeyError.
        """
        # Queen detection parameters
        self.min_queen_size = config['min_queen_size']
        self.max_queen_size = config['max_queen_size']
        self.stable_position_threshold = config['stable_position_threshold']
        self.stable_position_measurements = config['stable_position_measurements']
        self.image_center_threshold = config['image_center_threshold']
        self.queen_head_angle_offset = config['queen_head_angle_offset']

        # Path planning parameters
        self.queen_head_distance = config['queen_head_distance']
        self.path_search_distance = config['path_search_distance']
        self.path_resolution = config['path_resolution']
        self.heatmap_threshold = config['heatmap_threshold']

        # Agent parameters
        self.agent_retracted_position = config['agent_retracted_position']
        self.agent_extended_position = config['agent_extended_position']
        self.agent_position_tolerance = config['agent_position_tolerance']

    def calculate_queen_head_position(self, queen_position, image_width, image_height, queen_angle=None):
        """Calculate the queen's head position based on queen position and angle.

        The head is placed ``queen_head_distance * image_width`` pixels away
        from the queen along ``angle``. Both axes are scaled by the image
        width (not the height), so the offset length does not depend on the
        direction.

        Args:
            queen_position: (u, v) position of queen in image coordinates (pixels)
            image_width: Width of the image
            image_height: Height of the image
            queen_angle: Optional angle of the queen in radians; when None,
                ``queen_head_angle_offset`` is used as the angle.

        Returns:
            tuple: (head_x_img, head_y_img) position in image coordinates,
            clamped to the image bounds.
        """
        queen_x_img = int(queen_position[0])  # u coordinate
        queen_y_img = int(queen_position[1])  # v coordinate

        # Use provided angle or default offset
        angle = queen_angle if queen_angle is not None else self.queen_head_angle_offset

        # Offset of the head relative to the queen, in pixels
        reach = self.queen_head_distance * image_width
        head_x_img = int(queen_x_img + reach * np.cos(angle))
        head_y_img = int(queen_y_img + reach * np.sin(angle))

        # Ensure coordinates are within image bounds
        head_x_img = max(0, min(image_width - 1, head_x_img))
        head_y_img = max(0, min(image_height - 1, head_y_img))

        return head_x_img, head_y_img

    @staticmethod
    def _nearest_free_cell(occupancy, center, target, radius, step):
        """Return the free grid cell around ``center`` that is closest to ``target``.

        Args:
            occupancy: 2-D array; a value of 0 marks a free pixel
            center: (x, y) pixel the square search window is centred on
            target: (x, y) pixel the distance is measured to
            radius: Half-size of the search window in pixels
            step: Stride of the sampled grid in pixels

        Returns:
            (x, y) tuple of Python ints, or None when no sampled cell is free.
            Ties are resolved in scan order (rows first, then columns).
        """
        height, width = occupancy.shape[:2]
        center_x, center_y = center
        target_x, target_y = target

        rows = np.arange(max(0, center_y - radius), min(height, center_y + radius), step)
        cols = np.arange(max(0, center_x - radius), min(width, center_x + radius), step)
        if rows.size == 0 or cols.size == 0:
            return None

        # Sample the occupancy map on the coarse grid only.
        free = np.argwhere(occupancy[np.ix_(rows, cols)] == 0)
        if free.size == 0:
            return None

        cand_y = rows[free[:, 0]]
        cand_x = cols[free[:, 1]]
        # Squared distance preserves the ordering of the Euclidean distance.
        dist_sq = (cand_x - target_x) ** 2 + (cand_y - target_y) ** 2
        best = int(np.argmin(dist_sq))  # argmin keeps the first of equal minima
        return int(cand_x[best]), int(cand_y[best])

    def find_lowering_position(self, heatmap, queen_position, queen_angle=None):
        """Find a suitable position to lower the agent.

        The heatmap is thresholded into an obstacle mask, the obstacles are
        inflated by a safety margin, and the free grid cell nearest to the
        queen's head within the search window is selected.

        Args:
            heatmap: OpenCV image containing bee heatmap
            queen_position: (u, v) position of queen in image coordinates (pixels)
            queen_angle: Optional angle of the queen in radians

        Returns:
            tuple: (position, viz_img). ``position`` is (x, y, 0.0) in pixel
            coordinates, or None if no suitable position was found.
            ``viz_img`` is a BGR visualization of the inflated obstacle map,
            or None when ``heatmap`` is None.
        """
        if heatmap is None:
            return None, None

        height, width = heatmap.shape[:2]

        # Threshold the heatmap
        _, thresholded = cv2.threshold(heatmap, self.heatmap_threshold, 255, cv2.THRESH_BINARY)
        # Convert thresholded to single layer if it's RGB
        if len(thresholded.shape) > 2:
            thresholded = cv2.cvtColor(thresholded, cv2.COLOR_BGR2GRAY)

        # Inflate obstacles to ensure a safety margin (20% of the smaller image
        # side). The kernel must be at least 1 px wide, otherwise the
        # structuring element is invalid for very small images.
        kernel_size = max(1, int(self._OBSTACLE_INFLATION_RATIO * min(width, height)))
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        inflated = cv2.dilate(thresholded, kernel)

        # Create visualization
        viz_img = cv2.cvtColor(inflated, cv2.COLOR_GRAY2BGR)

        # Use queen position directly in image coordinates
        queen_x_img = int(queen_position[0])  # u coordinate
        queen_y_img = int(queen_position[1])  # v coordinate

        # Calculate queen's head position using the centralized method
        queen_head_x_img, queen_head_y_img = self.calculate_queen_head_position(
            queen_position, width, height, queen_angle)

        # Draw queen position and head position
        cv2.circle(viz_img, (queen_x_img, queen_y_img), 10, (0, 0, 255), -1)
        cv2.circle(viz_img, (queen_head_x_img, queen_head_y_img), 8, (255, 255, 0), -1)

        # Pick the free cell (inflated map == 0) closest to the queen's head
        search_radius = int(self.path_search_distance * width)
        best_xy = self._nearest_free_cell(
            inflated,
            (queen_x_img, queen_y_img),
            (queen_head_x_img, queen_head_y_img),
            search_radius,
            self._SEARCH_STEP_PX)

        if best_xy is None:
            return None, viz_img

        # Draw chosen position and its connection to the head
        chosen_x, chosen_y = best_xy
        cv2.circle(viz_img, (chosen_x, chosen_y), 8, (0, 255, 0), -1)
        cv2.line(viz_img, (chosen_x, chosen_y), (queen_head_x_img, queen_head_y_img), (0, 255, 0), 2)

        return (chosen_x, chosen_y, 0.0), viz_img  # Return x,y and default z=0

    def plan_path_to_queen(self, heatmap, start_pos_uv, queen_position_uv, queen_z_distance, queen_angle=None):
        """Plan a path from the lowering position to the queen's head in image coordinates.

        The path is a straight line sampled at evenly spaced points. Points
        that fall on occupied heatmap pixels are reported with a printed
        warning but are still included in the path.

        Args:
            heatmap: OpenCV image containing bee heatmap (used for collision checking)
            start_pos_uv: (u, v) starting position in image coordinates (pixels)
            queen_position_uv: (u, v) position of queen in image coordinates (pixels)
            queen_z_distance: Queen's Z distance in metric coordinates (for Z height)
            queen_angle: Optional angle of the queen in radians

        Returns:
            List of (u, v, z) positions forming the path, where u,v are in
            pixels and z is ``queen_z_distance``. Empty when ``heatmap`` is None.
        """
        if heatmap is None:
            return []

        height, width = heatmap.shape[:2]

        # Calculate queen's head position in image coordinates using existing method
        queen_head_u, queen_head_v = self.calculate_queen_head_position(
            queen_position_uv, width, height, queen_angle)

        # Calculate number of steps based on pixel distance
        distance_pixels = np.sqrt((queen_head_u - start_pos_uv[0])**2 +
                                  (queen_head_v - start_pos_uv[1])**2)

        # Path resolution scaled by image width gives the step length in
        # pixels. A non-positive step length cannot be inverted, so fall back
        # to the minimum number of steps instead of dividing by zero.
        step_length_px = self.path_resolution * width
        if step_length_px > 0:
            steps_per_pixel = 1.0 / step_length_px
            num_steps = max(self._MIN_PATH_STEPS, int(distance_pixels * steps_per_pixel))
        else:
            num_steps = self._MIN_PATH_STEPS

        # Generate straight path in image coordinates
        path = []
        for i in range(num_steps + 1):
            t = i / num_steps
            u = int(start_pos_uv[0] * (1 - t) + queen_head_u * t)
            v = int(start_pos_uv[1] * (1 - t) + queen_head_v * t)
            z = queen_z_distance  # Keep Z constant during approach

            # Ensure coordinates are within image bounds
            u = max(0, min(width - 1, u))
            v = max(0, min(height - 1, v))

            # Check for collisions in heatmap (channels are summed for colour images)
            if np.sum(heatmap[v, u]) >= self.heatmap_threshold:
                print(f"Warning: Path point ({u}, {v}) goes through occupied area")

            path.append((u, v, z))

        return path

    def visualize_planned_path(self, heatmap, queen_position_uv, lower_position_uv, planned_path):
        """Create visualization of the planned path in image coordinates.

        Args:
            heatmap: OpenCV image containing bee heatmap
            queen_position_uv: (u, v) position of queen in image coordinates
            lower_position_uv: (u, v, z) position to lower the agent in image
                coordinates, or None to skip drawing it
            planned_path: List of (u, v, z) positions forming the path in image coordinates

        Returns:
            OpenCV image with visualization, or None when ``heatmap`` is None
            or ``planned_path`` is empty.
        """
        if heatmap is None or not planned_path:
            return None

        # Draw on a copy so the caller's heatmap is left untouched
        viz_img = heatmap.copy()
        if len(viz_img.shape) == 2:  # Convert to BGR if grayscale
            viz_img = cv2.cvtColor(viz_img, cv2.COLOR_GRAY2BGR)

        # Draw queen position
        cv2.circle(viz_img, (int(queen_position_uv[0]), int(queen_position_uv[1])),
                   10, (0, 0, 255), -1)

        # Draw lowering position
        if lower_position_uv is not None:
            cv2.circle(viz_img, (int(lower_position_uv[0]), int(lower_position_uv[1])),
                       8, (255, 0, 0), -1)

        # Draw planned path as a polyline through consecutive points
        for i in range(1, len(planned_path)):
            u1, v1 = int(planned_path[i-1][0]), int(planned_path[i-1][1])
            u2, v2 = int(planned_path[i][0]), int(planned_path[i][1])
            cv2.line(viz_img, (u1, v1), (u2, v2), (0, 255, 0), 2)

        return viz_img

    def check_queen_stability(self, position_history):
        """Check if the queen's position is stable.

        The position is stable when at least ``stable_position_measurements``
        samples are available and no two consecutive samples in the whole
        history are ``stable_position_threshold`` or more apart.

        Args:
            position_history: Indexable sequence of (x, y) positions

        Returns:
            bool: True if position is stable, False otherwise
        """
        if len(position_history) < self.stable_position_measurements:
            return False

        # Index-based access keeps this usable with sequences that do not
        # support slicing.
        max_displacement = 0
        for i in range(1, len(position_history)):
            dx = position_history[i][0] - position_history[i-1][0]
            dy = position_history[i][1] - position_history[i-1][1]
            displacement = np.sqrt(dx**2 + dy**2)
            max_displacement = max(max_displacement, displacement)

        return max_displacement < self.stable_position_threshold

    def check_queen_centered(self, queen_u, queen_v, image_width, image_height):
        """Check if the queen is centered in the image.

        The offset from the image center is normalized by half the image
        size on each axis and compared against ``image_center_threshold``.

        Args:
            queen_u: Queen's U coordinate in image
            queen_v: Queen's V coordinate in image
            image_width: Width of the image
            image_height: Height of the image

        Returns:
            bool: True if queen is centered, False otherwise (including when
            the image has a non-positive width or height).
        """
        # A degenerate image has no meaningful center and would divide by zero.
        if image_width <= 0 or image_height <= 0:
            return False

        half_width = image_width / 2
        half_height = image_height / 2

        dx = abs(queen_u - half_width) / half_width
        dy = abs(queen_v - half_height) / half_height

        return dx < self.image_center_threshold and dy < self.image_center_threshold

    def check_agent_position(self, current_position, target_position):
        """Check if the agent has reached its target position.

        Args:
            current_position: Current agent position
            target_position: Target agent position

        Returns:
            bool: True if the absolute difference is below
            ``agent_position_tolerance``, False otherwise
        """
        return abs(current_position - target_position) < self.agent_position_tolerance

    def check_agent_state(self, current_fill_percentage, target_percentage, tolerance=5):
        """Check if the agent has reached its target state.

        Args:
            current_fill_percentage: Current agent fill percentage
            target_percentage: Target agent percentage
            tolerance: Acceptable percentage difference (default: 5%); the
                comparison is strict, so a difference equal to the tolerance
                does not count as reached

        Returns:
            bool: True if target state reached, False otherwise
        """
        return abs(current_fill_percentage - target_percentage) < tolerance

    def check_position_reached(self, current_x, current_y, target_x, target_y, inpos_x, inpos_y, current_z=None, target_z=None, tolerance=0.002):
        """Check if the target position has been reached.

        Args:
            current_x: Current X position
            current_y: Current Y position
            target_x: Target X position
            target_y: Target Y position
            inpos_x: X axis in-position flag
            inpos_y: Y axis in-position flag
            current_z: Current Z position (optional)
            target_z: Target Z position (optional)
            tolerance: Position tolerance (default: 0.002); Z is checked
                against 100 times this value

        Returns:
            bool: True if position reached, False otherwise
        """
        # Check XY distance
        dx = abs(current_x - target_x)
        dy = abs(current_y - target_y)
        distance_ok = dx < tolerance and dy < tolerance
        # bool() so that a falsy flag such as 0 or None is not returned as-is
        xy_reached = bool(inpos_x and inpos_y and distance_ok)

        # Z is only checked when both the current and the target value are given
        if current_z is None or target_z is None:
            return xy_reached

        z_reached = abs(current_z - target_z) < tolerance * self._Z_TOLERANCE_SCALE
        return xy_reached and bool(z_reached)
