# Software for the Autonomous Robotic Observation and Behavioral Analysis system
#
# Actuator firmware and driver
#
# Copyright 2025 Fatemeh Rekabi-Bana, Tomas Krajnik 
#
# Commercial use of the software requires written consent of the copyright holders. 
#
# For Open Research and Educational use, the following applies:
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ROS node rr_xytable: serial driver for the XY-table main actuator.

Created on Tue Jan 10 10:01:49 2023

@author: frekabi
"""

import os
import sys
script_path = os.path.abspath(os.path.dirname(__file__))

# Make modules that sit next to this script importable; append the directory
# only if it is not already on the search path.
if all(entry != script_path for entry in sys.path):
    sys.path.append(script_path)

import rospy
from geometry_msgs.msg import PoseStamped, Pose, Twist 
from sensor_msgs.msg  import JointState
import sys, select, os
import tty, termios
from std_msgs.msg import String, Int32MultiArray, Header 
import socket
import struct
import time
import serial
import math
from threading import Thread
from std_srvs.srv import Trigger, TriggerResponse
# from actionlib_msgs.msg import GoalStatusArray
from rr_msgs.msg import ActuatorStatus
#from rr_xy_operator.srv import ResetPosition,ResetPositionResponse
from rr_msgs.srv import ResetPosition,ResetPositionResponse
import numpy as np
#from Vib_ctl import Vib_ctl

# Default serial device of the XY-table controller, passed to RR_XY_TABLE.
# NOTE(review): RR_XY_TABLE.__init__ currently overwrites it with the
# ~serial_port parameter (default /dev/ttyUSB0) -- confirm which is intended.
Port = '/dev/xy_table_0'

# Register this process with ROS as node "rr_xytable" at import time, before
# the RR_XY_TABLE object (parameters, publishers, subscribers) is created.
rospy.init_node("rr_xytable")
    
class RR_XY_TABLE:
    """Serial driver node for the XY-table main actuator.

    Every control cycle (``rospy.Rate(20)``, see ``start``) the node sends one
    command packet to the table controller over a serial port, reads back a
    position reply and republishes it on ROS topics.

    Command packet (host -> controller), little-endian, 28 bytes::

        0xFF 0xFE 0xFD | x_cmd x_vel x_acc y_cmd y_vel y_acc (6 x float32) | chk

    Position-correction packet (host -> controller), queued by the
    ``rr_xy_reset`` service, 12 bytes::

        0xFF 0xFE 0xFE | x y (2 x float32, mm) | chk

    Reply packet (controller -> host), 12 bytes::

        3 header bytes | x y (2 x float32, mm) | chk

    ``chk`` is the byte-sum of the float payload modulo 256.

    Frames: values exchanged with the controller are in millimetres, and its
    x axis is mirrored relative to the ROS side, i.e. ROS x [m] =
    ``xpos_upper_bound - x_controller / 1000``. Y is not mirrored.
    """

    def __init__(self, Port):
        """Open the serial link, initialise state and register ROS interfaces.

        Args:
            Port: serial device path. NOTE(review): it is overwritten by the
                ``~serial_port`` parameter (default ``/dev/ttyUSB0``) below,
                so it currently has no effect; kept for API compatibility.
        """
        self.ser_port = Port
        # Only consumed by the (currently disabled) vibration-controller hook.
        self.sys_info_file = rospy.get_param('~sys_info_file', 'None')
        self.ser_port = rospy.get_param('~serial_port', '/dev/ttyUSB0')

        self.ser = serial.Serial(port=self.ser_port,
                                 baudrate=115200,
                                 parity=serial.PARITY_NONE,
                                 stopbits=serial.STOPBITS_ONE,
                                 bytesize=serial.EIGHTBITS)
        # timeout=0 makes read() non-blocking: it returns only the bytes that
        # are already buffered (possibly none).
        self.ser.timeout = 0.00
        self.ser.close()
        self.ser.open()
        print('is open: ' + str(self.ser.is_open))
        print('baudrate: ' + str(self.ser.baudrate))
        print('name: ' + str(self.ser.name))
        print('bytesize: ' + str(self.ser.bytesize))
        print('parities: ' + str(self.ser.parity))
        print('stopbits: ' + str(self.ser.stopbits))

        # Wait before talking to the controller after (re)opening the port.
        # NOTE(review): presumably the board resets when the port opens -- confirm.
        time.sleep(5)
        self.rate = rospy.Rate(20)  # control loop frequency (Hz)
        self.dt = 0.05              # control period (s); keep equal to 1 / rate

        self.key = 'r'         # last key read from the terminal; 'c' stops the node
        self.node_work = True  # the main loop runs while this is True

        # Set-points sent to the controller (controller frame).
        self.xpos_command = 0  # mm
        self.ypos_command = 0  # mm
        self.xpos_vel = 10  # mm/s
        self.ypos_vel = 10  # mm/s
        self.xpos_acc = 5  # mm/s^2
        self.ypos_acc = 5  # mm/s^2

        # Travel limits used to clamp position commands (controller frame).
        self.xpos_max = 580  # mm
        self.xpos_min = 0  # mm
        self.ypos_max = 610  # mm
        self.ypos_min = 0  # mm

        # Thresholds at which a limit switch is reported active; compared with
        # the measured position expressed in the ROS frame, in mm.
        self.xpos_lmswitch_left = 579.99  # mm
        self.ypos_lmswitch_top = 609.99  # mm
        self.xpos_lmswitch_right = 0.01  # mm
        self.ypos_lmswitch_bottom = 0.01  # mm

        self.xpos_offset = 0
        self.ypos_offset = 0
        # X travel in metres; used to mirror x between the ROS and controller frames.
        self.xpos_upper_bound = 0.58  # m
        self.ypos_upper_bound = 0.61  # m
        print("1st x command is: %f  and y command is %f" % (self.xpos_command, self.ypos_command))

        # Outgoing ROS interfaces and the messages reused for them.
        self.MA_rec_pos_pub = rospy.Publisher("/RR_gs/ma_rc_pos", PoseStamped, queue_size=1)
        self.calib_pos_pub = rospy.Publisher("xy_position", PoseStamped, queue_size=1)
        self.danode_pub = rospy.Publisher("xy_data", JointState, queue_size=1)
        self.limit_switch_pub = rospy.Publisher("xy_status", ActuatorStatus, queue_size=1)
        self.Ma_rec_pos_msg = PoseStamped()
        self.Ma_rec_pos_msg.header.stamp = rospy.Time.now()
        self.Ma_rec_pos_msg.pose.position.x = 0
        self.Ma_rec_pos_msg.pose.position.y = 0
        self.xy_data_msg = JointState()
        self.xy_data_msg.header.stamp = rospy.Time.now()
        self.act_status = ActuatorStatus()

        # Packet header bytes: motion command and position correction.
        self.chksum = 0
        self.h1 = 0xff
        self.h2 = 0xfe
        self.h3 = 0xfd
        self.h1_err_reset = 0xff
        self.h2_err_reset = 0xfe
        self.h3_err_reset = 0xfe

        # Limit-switch flags, mirrored from act_status by _update_limit_switches.
        self.LMSW_X_S = False
        self.LMSW_X_E = False
        self.LMSW_Y_S = False
        self.LMSW_Y_E = False

        self.data = self._pack_command()
        print(self.data)

        # Position integrated by calibvel_calback (controller frame, mm).
        self.calib_pos_x = self.xpos_max * 1
        self.calib_pos_y = 0

        self.packet_health_check = 0  # number of bad/corrupted replies so far

        # Pending position correction queued by the reset service.
        self.corr_srv = False
        self.data_corr = b''
        self.chksum_corr = 0

        # Last four measured positions (mm, newest first), kept for the
        # disabled vibration-controller hook in start().
        self.xpos_state_4 = [0, 0, 0, 0]
        self.ypos_state_4 = [0, 0, 0, 0]

        # Stopping with the 'c' key needs an interactive terminal; without one
        # (tcgetattr would raise) the node is stopped via ROS shutdown instead.
        self.settings = None
        self.stop_thread = None
        if sys.stdin is not None and sys.stdin.isatty():
            self.settings = termios.tcgetattr(sys.stdin)
            self.stop_thread = Thread(target=self.stopkey, args=())
            self.stop_thread.daemon = True
            self.stop_thread.start()
        else:
            rospy.loginfo('stdin is not a terminal; keyboard stop key disabled')

        # Register inputs last so callbacks never run on a half-initialised object.
        self.MA_motion_sub = rospy.Subscriber("/RR_gs/ma_xmotion_command", Pose, self.motion_calback, queue_size=1)  # main actuator X/Y position command
        self.MA_vel_sub = rospy.Subscriber("/RR_gs/ma_vel_motion_command", Twist, self.vel_calback, queue_size=1)
        self.calib_vel_sub = rospy.Subscriber("calib_vel", Twist, self.calibvel_calback, queue_size=1)
        self.xy_cmd_sub = rospy.Subscriber("xy_cmd", JointState, self.cmd_com_calback, queue_size=1)
        self.reset_service = rospy.Service('rr_xy_reset', ResetPosition, self.reset_serv_resp)

    def stopkey(self):
        """Poll the keyboard until 'c' is read; ``start`` stops on that key."""
        while True:
            self.getKey()
            if self.key == 'c':
                break
            time.sleep(1)

    def motion_calback(self, msg):
        """Set the position command from a Pose (m), stored in mm."""
        self.xpos_command = msg.position.x * 1000  # convert from m to mm
        self.ypos_command = msg.position.y * 1000  # convert from m to mm

    def vel_calback(self, msg):
        """Set the velocity command from a Twist (m/s), stored in mm/s."""
        self.xpos_vel = msg.linear.x * 1000  # convert from m/s to mm/s
        self.ypos_vel = msg.linear.y * 1000  # convert from m/s to mm/s

    def calibvel_calback(self, msg):
        """Jog the table from a ``calib_vel`` Twist (assumed m/s).

        The velocity is integrated over one control period ``dt`` into
        ``calib_pos_x/y`` (x is subtracted, matching the mirrored x axis used
        elsewhere). The integrated position, clamped to the travel limits,
        becomes the position command. The speed magnitude, capped at 20 mm/s,
        becomes the velocity command.
        """
        # NOTE(review): calib_pos_x/y themselves are not clamped, so they keep
        # integrating past the travel limits while the command stays clamped.
        calib_vel_x = msg.linear.x * 1000
        calib_vel_y = msg.linear.y * 1000

        self.calib_pos_x = self.calib_pos_x - calib_vel_x * self.dt
        self.calib_pos_y = self.calib_pos_y + calib_vel_y * self.dt

        self.xpos_vel = self.limit_func(abs(calib_vel_x), 20, 0)
        self.ypos_vel = self.limit_func(abs(calib_vel_y), 20, 0)

        self.xpos_command = self.limit_func(self.calib_pos_x, self.xpos_max, self.xpos_min)
        self.ypos_command = self.limit_func(self.calib_pos_y, self.ypos_max, self.ypos_min)
        print('current calibration pos in y: %f' % (self.ypos_command))

    def cmd_com_calback(self, msg):
        """Handle an ``xy_cmd`` JointState: absolute target plus speed.

        ``msg.position[0:2]`` is the target (x, y) in metres in the ROS frame.
        ``msg.velocity[0:2]`` gives the speeds in m/s; only their magnitudes are
        used, capped at 20 mm/s. Messages with fewer than two positions or
        velocities are ignored with a warning, leaving the commands unchanged.
        """
        if len(msg.position) < 2 or len(msg.velocity) < 2:
            rospy.logwarn('xy_cmd needs 2 positions and 2 velocities, got %d and %d; ignored',
                          len(msg.position), len(msg.velocity))
            return

        calib_vel_x = msg.velocity[0] * 1000
        calib_vel_y = msg.velocity[1] * 1000

        # Mirror x into the controller frame; y maps directly.
        self.calib_pos_x = (self.xpos_upper_bound * 1000) - (msg.position[0] * 1000)
        self.calib_pos_y = msg.position[1] * 1000

        self.xpos_vel = self.limit_func(abs(calib_vel_x), 20, 0)
        self.ypos_vel = self.limit_func(abs(calib_vel_y), 20, 0)

        self.xpos_command = self.limit_func(self.calib_pos_x, self.xpos_max, self.xpos_min)
        self.ypos_command = self.limit_func(self.calib_pos_y, self.ypos_max, self.ypos_min)

    def start(self):
        """Run the control loop until 'c' is pressed or ROS shuts down.

        Each cycle sends the current command packet and reads one reply. If the
        reset service queued a correction packet, it is sent as well. The
        serial port is closed whenever the loop exits, including on exceptions
        such as ``rospy.ROSInterruptException`` raised by ``rate.sleep()``.
        """
        self.ser.flush()
        try:
            while self.node_work and not rospy.is_shutdown():
                # Acceleration set-points. The disabled vibration controller
                # would compute ux/uy from xpos_state_4/ypos_state_4; a constant
                # input is used instead.
                ux = 3
                self.xpos_acc = self.limit_func(ux, 10, -2) + 3
                uy = 3
                self.ypos_acc = self.limit_func(uy, 10, -2) + 3

                self.data = self._pack_command()
                self.ser.write(self.data)
                print('Command %6.3f %6.3f %6.3f %6.3f. ' % (self.xpos_command, self.ypos_command, self.xpos_vel, self.ypos_vel), end='')
                time.sleep(0.01)  # wait briefly for the reply before reading
                self.serial_readout()

                if self.corr_srv:
                    # Clear first: a request arriving during the write is then
                    # sent on the next cycle instead of being dropped.
                    self.corr_srv = False
                    time.sleep(0.01)
                    self.ser.write(self.data_corr)
                    time.sleep(0.01)
                    self.serial_readout()

                if self.key == 'c':
                    self.node_work = False
                    break
                self.rate.sleep()
        finally:
            self.ser.close()

    def getKey(self):
        """Wait up to 1 s for one raw keypress and store it in ``self.key``.

        ``self.key`` becomes ``''`` on timeout. The terminal settings saved in
        ``__init__`` are always restored, even if reading fails.
        """
        tty.setraw(sys.stdin.fileno())
        try:
            rlist, _, _ = select.select([sys.stdin], [], [], 1)
            self.key = sys.stdin.read(1) if rlist else ''
        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self.settings)

    def limit_func(self, inp, mx, mn):
        """Clamp ``inp`` to ``[mn, mx]``."""
        if inp < mn:
            return mn
        elif inp > mx:
            return mx
        else:
            return inp

    def ros_publisher(self, event):
        """Timer callback: republish the last measured pose on both pose topics."""
        self.MA_rec_pos_pub.publish(self.Ma_rec_pos_msg)
        self.calib_pos_pub.publish(self.Ma_rec_pos_msg)

    def reset_serv_resp(self, req):
        """``rr_xy_reset`` handler: queue a position correction for the controller.

        ``req.x``/``req.y`` (m, ROS frame) are converted to the controller frame.
        They become the new position command and are packed into
        ``data_corr``, which ``start`` sends on its next cycle. The response
        reports the last measured position, i.e. before the correction is
        applied.
        """
        print('The reset service is called')

        x_correction = (self.xpos_upper_bound * 1000) - (req.x * 1000)
        y_correction = req.y * 1000
        self.xpos_command = x_correction
        self.ypos_command = y_correction

        # Build the packet before raising the flag so the control loop never
        # sends a stale or empty correction.
        self.data_corr = self._pack_correction(x_correction, y_correction)
        self.corr_srv = True

        time.sleep(0.01)

        response = ResetPositionResponse()
        response.xMeasure = self.Ma_rec_pos_msg.pose.position.x
        response.yMeasure = self.Ma_rec_pos_msg.pose.position.y
        return response

    def serial_readout(self):
        """Read one 12-byte reply from the controller and publish its contents.

        Reads of 10 bytes or fewer are ignored. A reply that fails to unpack or
        to publish is counted in ``packet_health_check`` and pending input is
        flushed, so the next read starts on fresh data.
        """
        rec_packet = self.ser.read(12)
        print('Serial read ' + str(len(rec_packet)) + ' bytes ', end='')
        try:
            if len(rec_packet) > 10:
                self._process_reply(rec_packet)
        except Exception:
            # Best effort: a bad reply must not stop the control loop.
            self.packet_health_check = self.packet_health_check + 1
            print('The received packet is corrupted for %f times' % (self.packet_health_check))
            self._flush_input()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _checksum(payload):
        """Return the protocol checksum: byte-sum of *payload* modulo 256."""
        return sum(bytearray(payload)) % 256

    def _pack_command(self):
        """Build the 28-byte motion command packet and update ``self.chksum``."""
        payload = struct.pack('<6f',
                              float(self.xpos_command), float(self.xpos_vel), float(self.xpos_acc),
                              float(self.ypos_command), float(self.ypos_vel), float(self.ypos_acc))
        self.chksum = self._checksum(payload)
        return struct.pack('<3B', self.h1, self.h2, self.h3) + payload + struct.pack('<B', self.chksum)

    def _pack_correction(self, x_correction, y_correction):
        """Build the 12-byte position-correction packet and update ``self.chksum_corr``."""
        payload = struct.pack('<2f', float(x_correction), float(y_correction))
        self.chksum_corr = self._checksum(payload)
        return (struct.pack('<3B', self.h1_err_reset, self.h2_err_reset, self.h3_err_reset)
                + payload + struct.pack('<B', self.chksum_corr))

    def _flush_input(self):
        """Discard pending serial input (up to 10000 bytes) to resynchronise."""
        rec_packet = self.ser.read(10000)
        print('Flushed ' + str(len(rec_packet)) + ' bytes ', end='\r\n')

    def _update_limit_switches(self, rec_X, rec_Y):
        """Derive the four limit-switch states from a measured position (mm).

        ``rec_X`` is in the controller frame, so it is mirrored into the ROS
        frame before being compared with the thresholds.
        """
        x_mm = self.xpos_upper_bound * 1000 - rec_X
        self.act_status.switchRight = x_mm <= self.xpos_lmswitch_right
        self.act_status.switchLeft = x_mm >= self.xpos_lmswitch_left
        self.act_status.switchBottom = rec_Y <= self.ypos_lmswitch_bottom
        self.act_status.switchTop = rec_Y >= self.ypos_lmswitch_top
        self.LMSW_X_S = self.act_status.switchRight
        self.LMSW_X_E = self.act_status.switchLeft
        self.LMSW_Y_S = self.act_status.switchBottom
        self.LMSW_Y_E = self.act_status.switchTop

    def _process_reply(self, rec_packet):
        """Validate one reply packet and publish switches, pose and xy_data.

        Raises ``struct.error`` if the packet is not exactly 12 bytes long.
        """
        rec_h1, rec_h2, rec_h3, rec_X, rec_Y, rec_chksum = struct.unpack('<3B 2f B', rec_packet)
        print("%x %x %x %3.2f %3.2f %x. Comm errors %i" % (rec_h1, rec_h2, rec_h3, rec_X, rec_Y, rec_chksum, self.packet_health_check), end='\r\n')

        rec_chksum_check = self._checksum(rec_packet[3:11])
        if rec_chksum != rec_chksum_check:
            print('The error has occured and the Calculated Checksum is: %f, and received Checksum is: %f' % (rec_chksum_check, rec_chksum), end='\r\n')
            self.packet_health_check = self.packet_health_check + 1
            self._flush_input()
            return

        self._update_limit_switches(rec_X, rec_Y)
        self.limit_switch_pub.publish(self.act_status)

        # Pose in the ROS frame (m); x is mirrored back from the controller frame.
        self.Ma_rec_pos_msg.pose.position.x = self.xpos_upper_bound - rec_X / 1000
        self.Ma_rec_pos_msg.pose.position.y = rec_Y / 1000
        self.Ma_rec_pos_msg.header.stamp = rospy.Time.now()
        self.MA_rec_pos_pub.publish(self.Ma_rec_pos_msg)
        self.calib_pos_pub.publish(self.Ma_rec_pos_msg)

        # Shift the 4-sample position histories (newest first, mm).
        self.xpos_state_4[1:4] = self.xpos_state_4[0:3]
        self.xpos_state_4[0] = self.xpos_upper_bound * 1000 - rec_X
        self.ypos_state_4[1:4] = self.ypos_state_4[0:3]
        self.ypos_state_4[0] = rec_Y

        self.xy_data_msg.header.stamp = self.Ma_rec_pos_msg.header.stamp
        self.xy_data_msg.position = [self.xpos_upper_bound - rec_X / 1000, rec_Y / 1000]
        self.xy_data_msg.velocity = [self.xpos_vel, self.ypos_vel]
        self.xy_data_msg.effort = [self.xpos_acc, self.ypos_acc]
        self.danode_pub.publish(self.xy_data_msg)
            
            
            
            #print('current pos message is published')
           
        
        
    
        
    

            

            
if __name__ == '__main__':
    communication = RR_XY_TABLE(Port)
    try:
        communication.start()
    except rospy.ROSInterruptException:
        # Raised by rate.sleep() when ROS shuts down (e.g. rosnode kill);
        # this is a normal exit, not an error.
        pass
    
        
        
        
    
            
        
        
        
        
    
