# Software for the Autonomous Robotic Observation and Behavioral Analysis system
#
# Agent ROS driver example
#
# Copyright 2025 Agent Tomas Roucek 
#
# Commercial use of the software requires written consent of the copyright holders. 
#
# For Open Research and Educational use, the following applies:
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#!/usr/bin/env python

import rospy
from sensor_msgs.msg import JointState
import termios
import tty
import sys
import select
from std_msgs.msg import Header

class KeyboardJointController:
    """Keyboard teleoperation node for a two-joint device.

    The node subscribes to ``/joint_states`` to learn the current joint
    positions, reads single key presses from the controlling terminal and
    publishes the commanded state as ``sensor_msgs/JointState`` on
    ``/desired_joint_states``.

    All per-joint state is kept in two-element lists.  Index 0 is bounded by
    the ``j1_*`` limits and index 1 by the ``j2_*`` limits.

    NOTE(review): index 0 is published under the name ``'joint_2'`` and its
    position limits are read from the ``~j2_*`` parameters (index 1 likewise
    uses ``'joint_1'`` / ``~j1_*``), while the velocity-mode limits are read
    from parameters whose numbers match the attribute names.  This mapping is
    preserved exactly as found -- confirm it against the robot description.
    """

    # Byte delivered for Ctrl-C while the terminal is in raw mode (raw mode
    # turns off signal generation, so no SIGINT is raised for it).
    _QUIT_KEY = '\x03'

    # key -> (index into the joint lists, sign of the applied step).
    # NOTE(review): the original key handler is not present in this file; the
    # bindings follow the start-up message "Use W/S for joint 1, A/D for
    # joint 2", taking "joint 1" to be index 0 (the index the "Joint 1" limit
    # code operates on) and the first key of each pair as the positive
    # direction -- confirm.
    _KEY_BINDINGS = {
        'w': (0, 1.0),
        's': (0, -1.0),
        'a': (1, 1.0),
        'd': (1, -1.0),
    }

    def __init__(self):
        """Start the ROS node, load limits and set up publisher/subscriber."""
        rospy.init_node('keyboard_joint_controller', anonymous=True)
        # True until the first usable /joint_states message has been received.
        self.init = True
        # Selects which limit set and which command field the keys act on.
        self.using_velocity = False

        # Limits used in position mode (see the class NOTE about the names).
        self.j1_pose_min = rospy.get_param('~j2_min', 0.23)
        self.j1_pose_max = rospy.get_param('~j2_max', 2.4)
        self.j2_pose_min = rospy.get_param('~j1_min', 0.0)
        self.j2_pose_max = rospy.get_param('~j1_max', 4.1 * 3)

        # Limits used in velocity mode.
        self.j1_velocity_min = rospy.get_param('~j1_velocity_min', -1.75)
        self.j1_velocity_max = rospy.get_param('~j1_velocity_max', 0.1)
        self.j2_velocity_min = rospy.get_param('~j2_velocity_min', -1.5)
        self.j2_velocity_max = rospy.get_param('~j2_velocity_max', 9)

        # Active limits; the node starts in position mode.  update_limits()
        # re-selects them on every cycle.
        self.j1_min = self.j1_pose_min
        self.j1_max = self.j1_pose_max
        self.j2_min = self.j2_pose_min
        self.j2_max = self.j2_pose_max

        # Step applied per key press and the per-joint factors used by the
        # velocity-mode limit handling.
        self.increment = rospy.get_param('~increment', 0.01)
        self.correction1 = rospy.get_param('~correction1', 10)
        self.correction2 = rospy.get_param('~correction2', 30)

        # Joint state; every list is indexed the same way as joint_names.
        self.joint_names = ['joint_2', 'joint_1']
        self.joint_positions = [0.0, 0.0]
        self.joint_velocity = [0.0, 0.0]
        self.desired_velocity = [0.0, 0.0]

        self.joint_state_pub = rospy.Publisher('/desired_joint_states', JointState, queue_size=10)
        rospy.Subscriber('/joint_states', JointState, self.pose_callback)

        rospy.loginfo("Keyboard Joint Controller Node Initialized.")
        rospy.loginfo("Use W/S for joint 1, A/D for joint 2.")
        # Terminal attributes to restore after each raw-mode key read.
        self.settings = termios.tcgetattr(sys.stdin)

    def pose_callback(self, joint_msg):
        """Copy the first two measured positions into the local state.

        The copy happens only for the first message after start-up or while
        velocity mode is active; otherwise the message is ignored.

        Args:
            joint_msg: message exposing a ``position`` sequence
                (``sensor_msgs/JointState`` on ``/joint_states``).
        """
        if not (self.init or self.using_velocity):
            return
        if len(joint_msg.position) < 2:
            # A shorter message cannot initialise both joints; keep waiting
            # instead of raising IndexError in the subscriber thread.
            rospy.logwarn_throttle(
                5.0, "Ignoring JointState with fewer than two positions.")
            return
        # assumes the incoming order matches joint_names -- TODO confirm
        self.joint_positions[0] = joint_msg.position[0]
        self.joint_positions[1] = joint_msg.position[1]
        self.init = False

    def get_key(self):
        """Return one pending character from stdin, or '' after 0.1 s.

        The terminal is switched to raw mode for the duration of the read and
        is always restored to the attributes saved in ``__init__``, even when
        the read fails.
        """
        tty.setraw(sys.stdin.fileno())
        try:
            ready, _, _ = select.select([sys.stdin], [], [], 0.1)
            key = sys.stdin.read(1) if ready else ''
        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self.settings)
        return key

    def update_joint_state(self, key):
        """Apply one key press to the commanded state and enforce the limits.

        A bound key (case-insensitive) changes the target of its joint by
        ``self.increment``: the position in position mode, the desired
        velocity in velocity mode.  Any other key, including '', changes
        nothing.  ``update_limits`` runs in every case.

        Args:
            key: single-character string as returned by ``get_key``.
        """
        binding = self._KEY_BINDINGS.get(key.lower()) if key else None
        if binding is not None:
            index, sign = binding
            step = sign * self.increment
            if self.using_velocity:
                self.desired_velocity[index] += step
            else:
                self.joint_positions[index] += step
        self.update_limits()

    def _limited_velocity(self, position, desired, lower, upper, correction):
        """Return the velocity command for one joint in velocity mode.

        When ``position + desired`` lies strictly between ``lower`` and
        ``upper`` the desired velocity is returned unchanged.  Otherwise the
        result is the negated distance of ``position`` from the reached bound
        multiplied by ``desired + self.increment * correction``.
        """
        target = position + desired
        gain = desired + self.increment * correction
        if target >= upper:
            return -(position - upper) * gain
        if target <= lower:
            return -(position - lower) * gain
        return desired

    def update_limits(self):
        """Select the active limit set and apply it to the commanded state.

        Position mode clamps both joint positions into their active range.
        Velocity mode derives ``joint_velocity`` from ``desired_velocity`` via
        ``_limited_velocity`` and prints the intermediate values.
        """
        if not self.using_velocity:
            self.j1_min = self.j1_pose_min
            self.j1_max = self.j1_pose_max
            self.j2_min = self.j2_pose_min
            self.j2_max = self.j2_pose_max

            self.joint_positions[0] = max(self.j1_min, min(self.joint_positions[0], self.j1_max))
            self.joint_positions[1] = max(self.j2_min, min(self.joint_positions[1], self.j2_max))
            return

        self.j1_min = self.j1_velocity_min
        self.j1_max = self.j1_velocity_max
        self.j2_min = self.j2_velocity_min
        self.j2_max = self.j2_velocity_max

        self.joint_velocity[0] = self._limited_velocity(
            self.joint_positions[0], self.desired_velocity[0],
            self.j1_min, self.j1_max, self.correction1)
        self.joint_velocity[1] = self._limited_velocity(
            self.joint_positions[1], self.desired_velocity[1],
            self.j2_min, self.j2_max, self.correction2)

        # Diagnostic output kept from the original implementation.
        print(self.desired_velocity, self.joint_velocity, self.j1_velocity_max)
        print(self.joint_positions[0],self.j1_velocity_min,  self.j1_velocity_max)
        print(self.joint_positions[1] , self.j2_velocity_min, self.j2_velocity_max)

    def publish_joint_state(self):
        """Publish names, positions and velocities on /desired_joint_states."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()
        joint_state_msg.name = self.joint_names
        joint_state_msg.position = self.joint_positions
        joint_state_msg.velocity = self.joint_velocity
        self.joint_state_pub.publish(joint_state_msg)

    def start(self):
        """Run the read-key / update / publish loop until shutdown.

        Nothing is published until the first joint state has been received.
        The loop also ends when Ctrl-C is read as a key.  The rate is set to
        50 Hz, but ``get_key`` waits up to 0.1 s when no key is pending.
        """
        rate = rospy.Rate(50)  # 50 Hz
        while not rospy.is_shutdown():
            if self.init:
                rospy.loginfo_throttle(1.0, "Waiting for initial joint state...")
                rate.sleep()
                continue
            key = self.get_key()
            if key == self._QUIT_KEY:
                # Ctrl-C pressed while the terminal was raw: no SIGINT was
                # generated, so leave the loop explicitly.
                break
            self.update_joint_state(key)
            self.publish_joint_state()
            rate.sleep()

if __name__ == '__main__':
    # Script entry point: build the controller node and hand control to its
    # main loop.  A ROS interrupt (raised by rospy when the node is shut down
    # during a sleep) is treated as a normal exit rather than an error.
    try:
        KeyboardJointController().start()
    except rospy.ROSInterruptException:
        pass
