# Software for the Autonomous Robotic Observation and Behavioral Analysis system
#
# Agent ROS driver example
#
# Copyright 2025 Agent Tomas Roucek 
#
# Commercial use of the software requires written consent of the copyright holders. 
#
# For Open Research and Educational use, the following applies:
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import rospy
from sensor_msgs.msg import JointState
from std_msgs.msg import Header

class JointController:
    """PID position controller for a two-joint ('linear', 'rotation') mechanism.

    Measured positions come from ``/joint_states``. Target positions and optional
    per-joint speed caps come from ``xy_cmd``. The resulting velocity command is
    published on ``/desired_joint_states`` by the 50 Hz loop in :meth:`start`.
    All per-joint lists are ordered ``[linear, rotation]``.

    Args:
        increment: Speed modifier, stored as ``self.increment``. It is not read
            by any method of this class.
    """

    # Default PID gains; overridable through the ~kp / ~ki / ~kd ROS parameters.
    DEFAULT_KP = 5.0
    DEFAULT_KI = 0.1
    DEFAULT_KD = 0.5
    # Absolute position error below which a joint counts as on target and is stopped.
    POSITION_TOLERANCE = 0.001

    def __init__(self, increment):
        # True until the first /joint_states message arrives; start() idles while set.
        self.init = True

        # Joint limits from the ROS parameter server.
        # NOTE(review): these are loaded but not enforced anywhere in this class.
        self.min_limit = [rospy.get_param('~linear_min', 0.23), rospy.get_param('~rotation_min', 0.23)]
        self.max_limit = [rospy.get_param('~linear_max', 0.23), rospy.get_param('~rotation_max', 0.23)]

        self.increment = increment  # speed modifier (not used by this class)
        # Hard upper bound on the commanded speed of each joint.
        self.max_speed = [rospy.get_param('~max_speed_lin', 0.23), rospy.get_param('~max_speed_rot', 0.23)]

        # PID gains.
        self.kp = rospy.get_param('~kp', self.DEFAULT_KP)
        self.ki = rospy.get_param('~ki', self.DEFAULT_KI)
        self.kd = rospy.get_param('~kd', self.DEFAULT_KD)

        # Joint state and targets.
        self.joint_names = ['linear', 'rotation']
        self.joint_positions = [0.0, 0.0]    # latest measured positions
        self.joint_velocity = [0.0, 0.0]     # commanded velocities (controller output)
        self.desired_velocity = [0.0, 0.0]   # speed caps; 0 means "use max_speed only"
        self.desired_positions = [0.0, 0.0]  # target positions
        self.desired_max_speed = [0.0, 0.0]  # NOTE(review): not used by this class

        # PID state. last_time stays None until the first control update.
        self.integral = [0.0, 0.0]
        self.prev_error = [0.0, 0.0]
        self.last_time = None

        # Publisher/subscribers are created last so callbacks never see a half-built object.
        self.joint_state_pub = rospy.Publisher('/desired_joint_states', JointState, queue_size=10)
        rospy.Subscriber('/joint_states', JointState, self.pose_callback)
        # queue_size=1: only the newest target is kept.
        rospy.Subscriber("xy_cmd", JointState, self.position_callback, queue_size=1)

    def pose_callback(self, joint_msg):
        """Store the measured positions of both joints and clear the init flag.

        Entries 0 and 1 of ``joint_msg.position`` are taken as the linear and
        rotation joints. NOTE(review): joint names in the message are not checked.
        """
        # Swap in a new list so the control loop never reads a half-updated pair.
        self.joint_positions = [joint_msg.position[0], joint_msg.position[1]]
        self.init = False

    def position_callback(self, joint_msg):
        """Store target positions and speed caps from an ``xy_cmd`` message.

        Entries 2 and 3 of ``position`` (and of ``velocity``, when present) hold
        the linear and rotation targets. ``sensor_msgs/JointState`` allows
        ``velocity`` to be empty. In that case the caps are set to 0.0, which the
        controller treats as "limit by max_speed only".
        """
        desired_position = [joint_msg.position[2], joint_msg.position[3]]
        if joint_msg.velocity:
            desired_velocity = [joint_msg.velocity[2], joint_msg.velocity[3]]
        else:
            desired_velocity = [0.0, 0.0]
        self.desired_positions = desired_position
        self.desired_velocity = desired_velocity

    def publish_joint_state(self):
        """Publish the measured positions and commanded velocities on /desired_joint_states."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()
        # Copies make the message a snapshot that is independent of later state updates.
        joint_state_msg.name = list(self.joint_names)
        joint_state_msg.position = list(self.joint_positions)
        joint_state_msg.velocity = list(self.joint_velocity)
        self.joint_state_pub.publish(joint_state_msg)

    def update_joint_state(self):
        """Advance the PID controller by the ROS time elapsed since the previous call.

        The first call only records the time base and seeds ``prev_error`` with the
        current errors. That way the derivative term does not spike on the first real
        step. A call where ROS time did not advance, or moved backwards, re-bases the
        timer and leaves the previous command unchanged.
        """
        current_time = rospy.Time.now()
        if self.last_time is None:
            self.last_time = current_time
            self.prev_error = self._position_errors()
            return

        dt = (current_time - self.last_time).to_sec()
        self.last_time = current_time
        if dt <= 0:
            return
        self._pid_step(dt)

    def _position_errors(self):
        """Return the per-joint error ``desired - measured``."""
        desired = self.desired_positions
        measured = self.joint_positions
        return [d - m for d, m in zip(desired, measured)]

    def _pid_step(self, dt):
        """Update ``self.joint_velocity`` for one control period of ``dt`` seconds (dt > 0)."""
        # Snapshot callback-owned lists once so both joints use a consistent view.
        errors = self._position_errors()
        desired_velocity = self.desired_velocity

        for i, error in enumerate(errors):
            # A non-zero requested velocity caps the speed below max_speed.
            speed_limit = self.max_speed[i]
            if desired_velocity[i] != 0:
                speed_limit = min(speed_limit, abs(desired_velocity[i]))

            # Anti-windup: the integral is clamped to the same bound as the output.
            self.integral[i] = max(min(self.integral[i] + error * dt, speed_limit), -speed_limit)

            derivative = (error - self.prev_error[i]) / dt
            self.prev_error[i] = error

            output = self.kp * error + self.ki * self.integral[i] + self.kd * derivative
            self.joint_velocity[i] = max(min(output, speed_limit), -speed_limit)

            # Inside the tolerance band: stop the joint and drop accumulated integral.
            if abs(error) < self.POSITION_TOLERANCE:
                self.joint_velocity[i] = 0.0
                self.integral[i] = 0.0

    def start(self):
        """Run the 50 Hz control loop until ROS shuts down.

        Until the first /joint_states message arrives, the loop only logs a
        (throttled) waiting message. After that, each cycle updates and publishes
        the velocity command.
        """
        rate = rospy.Rate(50)  # 50 Hz
        while not rospy.is_shutdown():
            if self.init:
                # Throttled to once per second instead of every 20 ms cycle.
                rospy.loginfo_throttle(1.0, "Waiting for initial joint state...")
            else:
                self.update_joint_state()
                self.publish_joint_state()
            rate.sleep()