refactor: (headless) cleanup mainly arm

This commit is contained in:
David Sharpe
2026-03-16 00:26:35 -05:00
parent 62fd1b110d
commit 292b3a742d

View File

@@ -6,23 +6,28 @@ from rclpy.duration import Duration
import signal import signal
import time import time
import atexit
import os import os
import sys import sys
import threading
import glob
import pwd import pwd
import grp import grp
from math import copysign from math import copysign
from std_srvs.srv import Trigger from std_srvs.srv import Trigger
from std_msgs.msg import String from std_msgs.msg import Header
from geometry_msgs.msg import Twist, TwistStamped from geometry_msgs.msg import Twist, TwistStamped
from control_msgs.msg import JointJog from control_msgs.msg import JointJog
from astra_msgs.msg import CoreControl, ArmManual, BioControl from astra_msgs.msg import CoreControl, ArmManual, BioControl
from astra_msgs.msg import CoreCtrlState from astra_msgs.msg import CoreCtrlState
import warnings
# Literally headless
warnings.filterwarnings(
"ignore",
message="Your system is avx2 capable but pygame was not built with support for it.",
)
import pygame import pygame
os.environ["SDL_VIDEODRIVER"] = "dummy" # Prevents pygame from trying to open a display os.environ["SDL_VIDEODRIVER"] = "dummy" # Prevents pygame from trying to open a display
@@ -49,6 +54,7 @@ control_qos = qos.QoSProfile(
STICK_DEADZONE = float(os.getenv("STICK_DEADZONE", "0.05")) STICK_DEADZONE = float(os.getenv("STICK_DEADZONE", "0.05"))
ARM_DEADZONE = float(os.getenv("ARM_DEADZONE", "0.2"))
class Headless(Node): class Headless(Node):
@@ -70,6 +76,14 @@ class Headless(Node):
pygame.joystick.init() pygame.joystick.init()
super().__init__("headless") super().__init__("headless")
# TODO: move the STOP_MSGs somewhere better
global ARM_STOP_JOG_MSG
ARM_STOP_JOG_MSG = JointJog(
header=Header(frame_id="base_link", stamp=self.get_clock().now().to_msg()),
joint_names=self.all_joint_names,
velocities=[0.0] * len(self.all_joint_names),
)
################################################## ##################################################
# Preamble # Preamble
@@ -131,11 +145,15 @@ class Headless(Node):
self.get_parameter("use_old_topics").get_parameter_value().bool_value self.get_parameter("use_old_topics").get_parameter_value().bool_value
) )
self.declare_parameter("use_bio", False)
self.use_bio = self.get_parameter("use_bio").get_parameter_value().bool_value
self.declare_parameter("arm_mode", "manual") self.declare_parameter("arm_mode", "manual")
self.arm_mode = ( self.arm_mode = (
self.get_parameter("arm_mode").get_parameter_value().string_value self.get_parameter("arm_mode").get_parameter_value().string_value
) )
# NOTE: only applicable if use_old_topics == True
self.declare_parameter("arm_manual_scheme", "old") self.declare_parameter("arm_manual_scheme", "old")
self.arm_manual_scheme = ( self.arm_manual_scheme = (
self.get_parameter("arm_manual_scheme").get_parameter_value().string_value self.get_parameter("arm_manual_scheme").get_parameter_value().string_value
@@ -168,23 +186,24 @@ class Headless(Node):
################################################## ##################################################
# New Topics # New Topics
self.core_twist_pub_ = self.create_publisher( if not self.use_old_topics:
Twist, "/core/twist", qos_profile=control_qos self.core_twist_pub_ = self.create_publisher(
) Twist, "/core/twist", qos_profile=control_qos
self.core_state_pub_ = self.create_publisher( )
CoreCtrlState, "/core/control/state", qos_profile=control_qos self.core_state_pub_ = self.create_publisher(
) CoreCtrlState, "/core/control/state", qos_profile=control_qos
)
self.arm_manual_pub_ = self.create_publisher( self.arm_manual_pub_ = self.create_publisher(
JointJog, "/arm/manual_new", qos_profile=control_qos JointJog, "/arm/manual_new", qos_profile=control_qos
) )
self.arm_ik_twist_publisher = self.create_publisher( self.arm_ik_twist_publisher = self.create_publisher(
TwistStamped, "/servo_node/delta_twist_cmds", qos_profile=control_qos TwistStamped, "/servo_node/delta_twist_cmds", qos_profile=control_qos
) )
self.arm_ik_jointjog_publisher = self.create_publisher( self.arm_ik_jointjog_publisher = self.create_publisher(
JointJog, "/servo_node/delta_joint_cmds", qos_profile=control_qos JointJog, "/servo_node/delta_joint_cmds", qos_profile=control_qos
) )
################################################## ##################################################
# Timers # Timers
@@ -223,6 +242,7 @@ class Headless(Node):
self.bio_publisher.publish(BIO_STOP_MSG) self.bio_publisher.publish(BIO_STOP_MSG)
else: else:
self.core_twist_pub_.publish(CORE_STOP_TWIST_MSG) self.core_twist_pub_.publish(CORE_STOP_TWIST_MSG)
self.arm_manual_pub_.publish(ARM_STOP_JOG_MSG)
def send_controls(self): def send_controls(self):
"""Read the gamepad state and publish control messages""" """Read the gamepad state and publish control messages"""
@@ -255,26 +275,35 @@ class Headless(Node):
self.gamepad.rumble(0.6, 0.7, 75) self.gamepad.rumble(0.6, 0.7, 75)
self.ctrl_mode = new_ctrl_mode self.ctrl_mode = new_ctrl_mode
self.get_logger().info(f"Switched to {self.ctrl_mode} control mode") self.get_logger().info(f"Switched to {self.ctrl_mode} control mode")
if self.ctrl_mode == "arm" and self.use_bio:
self.get_logger().warning("NOTE: Using bio instead of arm.")
# Actually send the controls # Actually send the controls
if self.ctrl_mode == "core": if self.ctrl_mode == "core":
self.send_core() self.send_core()
if self.use_old_topics: if self.use_old_topics:
self.arm_publisher.publish(ARM_STOP_MSG) if self.use_bio:
self.bio_publisher.publish(BIO_STOP_MSG)
else:
self.arm_publisher.publish(ARM_STOP_MSG)
# New topics shouldn't need to constantly send zeroes imo
else: else:
self.send_arm() if self.use_bio:
# self.send_bio() self.send_bio()
else:
self.send_arm()
if self.use_old_topics: if self.use_old_topics:
self.core_publisher.publish(CORE_STOP_MSG) self.core_publisher.publish(CORE_STOP_MSG)
# Ditto
def send_core(self): def send_core(self):
# Collect controller state # Collect controller state
left_stick_x = deadzone(self.gamepad.get_axis(0)) left_stick_x = stick_deadzone(self.gamepad.get_axis(0))
left_stick_y = deadzone(self.gamepad.get_axis(1)) left_stick_y = stick_deadzone(self.gamepad.get_axis(1))
left_trigger = deadzone(self.gamepad.get_axis(2)) left_trigger = stick_deadzone(self.gamepad.get_axis(2))
right_stick_x = deadzone(self.gamepad.get_axis(3)) right_stick_x = stick_deadzone(self.gamepad.get_axis(3))
right_stick_y = deadzone(self.gamepad.get_axis(4)) right_stick_y = stick_deadzone(self.gamepad.get_axis(4))
right_trigger = deadzone(self.gamepad.get_axis(5)) right_trigger = stick_deadzone(self.gamepad.get_axis(5))
button_a = self.gamepad.get_button(0) button_a = self.gamepad.get_button(0)
button_b = self.gamepad.get_button(1) button_b = self.gamepad.get_button(1)
button_x = self.gamepad.get_button(2) button_x = self.gamepad.get_button(2)
@@ -345,12 +374,12 @@ class Headless(Node):
def send_arm(self): def send_arm(self):
# Collect controller state # Collect controller state
left_stick_x = deadzone(self.gamepad.get_axis(0)) left_stick_x = stick_deadzone(self.gamepad.get_axis(0))
left_stick_y = deadzone(self.gamepad.get_axis(1)) left_stick_y = stick_deadzone(self.gamepad.get_axis(1))
left_trigger = deadzone(self.gamepad.get_axis(2)) left_trigger = stick_deadzone(self.gamepad.get_axis(2))
right_stick_x = deadzone(self.gamepad.get_axis(3)) right_stick_x = stick_deadzone(self.gamepad.get_axis(3))
right_stick_y = deadzone(self.gamepad.get_axis(4)) right_stick_y = stick_deadzone(self.gamepad.get_axis(4))
right_trigger = deadzone(self.gamepad.get_axis(5)) right_trigger = stick_deadzone(self.gamepad.get_axis(5))
button_a = self.gamepad.get_button(0) button_a = self.gamepad.get_button(0)
button_b = self.gamepad.get_button(1) button_b = self.gamepad.get_button(1)
button_x = self.gamepad.get_button(2) button_x = self.gamepad.get_button(2)
@@ -419,28 +448,16 @@ class Headless(Node):
# X: _ # X: _
# Y: linear actuator out # Y: linear actuator out
ARM_THRESHOLD = 0.2
# Right stick: EF yaw and axis 3 # Right stick: EF yaw and axis 3
arm_input.effector_yaw = ( arm_input.effector_yaw = stick_to_arm_direction(right_stick_x)
0 if abs(right_stick_x) < ARM_THRESHOLD else int(copysign(1, right_stick_x)) arm_input.axis3 = -1 * stick_to_arm_direction(right_stick_y)
)
arm_input.axis3 = (
0 if abs(right_stick_y) < ARM_THRESHOLD else int(-1 * copysign(1, right_stick_y))
)
# Left stick: axis 1 and 2 # Left stick: axis 1 and 2
arm_input.axis1 = ( arm_input.axis1 = stick_to_arm_direction(left_stick_x)
0 if abs(left_stick_x) < ARM_THRESHOLD else int(copysign(1, left_stick_x)) arm_input.axis2 = -1 * stick_to_arm_direction(left_stick_y)
)
arm_input.axis2 = (
0 if abs(left_stick_y) < ARM_THRESHOLD else int(-1 * copysign(1, left_stick_y))
)
# D-pad: axis 0 and _ # D-pad: axis 0 and _
arm_input.axis0 = ( arm_input.axis0 = int(dpad_input[0])
0 if dpad_input[0] == 0 else int(copysign(1, dpad_input[0]))
)
# Triggers: EF Grippers # Triggers: EF Grippers
if left_trigger > 0 and right_trigger > 0: if left_trigger > 0 and right_trigger > 0:
@@ -493,36 +510,44 @@ class Headless(Node):
ARM_THRESHOLD = 0.2 ARM_THRESHOLD = 0.2
# Right stick: EF yaw and axis 3 # Right stick: EF yaw and axis 3
arm_input.velocities[self.all_joint_names.index("wrist_yaw_joint")] = ( arm_input.velocities[self.all_joint_names.index("wrist_yaw_joint")] = float(
float(copysign(1, right_stick_x)) if abs(right_stick_x) >= ARM_THRESHOLD else 0.0 stick_to_arm_direction(right_stick_x)
) )
arm_input.velocities[self.all_joint_names.index("axis_3_joint")] = ( arm_input.velocities[self.all_joint_names.index("axis_3_joint")] = float(
float(-1 * copysign(1, right_stick_y)) if abs(right_stick_y) >= ARM_THRESHOLD else 0.0 -1 * stick_to_arm_direction(right_stick_y)
) )
# Left stick: axis 1 and 2 # Left stick: axis 1 and 2
arm_input.velocities[self.all_joint_names.index("axis_1_joint")] = ( arm_input.velocities[self.all_joint_names.index("axis_1_joint")] = float(
float(copysign(1, left_stick_x)) if abs(left_stick_x) >= ARM_THRESHOLD else 0.0 stick_to_arm_direction(left_stick_x)
) )
arm_input.velocities[self.all_joint_names.index("axis_2_joint")] = ( arm_input.velocities[self.all_joint_names.index("axis_2_joint")] = float(
float(-1 * copysign(1, left_stick_y)) if abs(left_stick_y) >= ARM_THRESHOLD else 0.0 -1 * stick_to_arm_direction(left_stick_y)
) )
# D-pad: axis 0 and _ # D-pad: axis 0 and _
arm_input.velocities[self.all_joint_names.index("axis_0_joint")] = ( arm_input.velocities[self.all_joint_names.index("axis_0_joint")] = float(
float(copysign(1, dpad_input[0])) if dpad_input[0] != 0 else 0.0 dpad_input[0]
) )
# Triggers: EF Grippers # Triggers: EF Grippers
if left_trigger > 0 and right_trigger > 0: if left_trigger > 0 and right_trigger > 0:
arm_input.velocities[self.all_joint_names.index("ef_gripper_left_joint")] = 0.0 arm_input.velocities[
self.all_joint_names.index("ef_gripper_left_joint")
] = 0.0
elif left_trigger > 0: elif left_trigger > 0:
arm_input.velocities[self.all_joint_names.index("ef_gripper_left_joint")] = -1.0 arm_input.velocities[
self.all_joint_names.index("ef_gripper_left_joint")
] = -1.0
elif right_trigger > 0: elif right_trigger > 0:
arm_input.velocities[self.all_joint_names.index("ef_gripper_left_joint")] = 1.0 arm_input.velocities[
self.all_joint_names.index("ef_gripper_left_joint")
] = 1.0
# Bumpers: EF roll # Bumpers: EF roll
arm_input.velocities[self.all_joint_names.index("wrist_roll_joint")] = right_bumper - left_bumper arm_input.velocities[self.all_joint_names.index("wrist_roll_joint")] = (
right_bumper - left_bumper
)
# A: brake # A: brake
# TODO: Brake mode # TODO: Brake mode
@@ -533,7 +558,7 @@ class Headless(Node):
self.arm_manual_pub_.publish(arm_input) self.arm_manual_pub_.publish(arm_input)
# IK # IK
elif self.arm_mode == "ik": elif self.arm_mode == "ik" and not self.use_old_topics:
arm_twist = TwistStamped() arm_twist = TwistStamped()
arm_twist.header.frame_id = "base_link" arm_twist.header.frame_id = "base_link"
arm_twist.header.stamp = self.get_clock().now().to_msg() arm_twist.header.stamp = self.get_clock().now().to_msg()
@@ -583,12 +608,12 @@ class Headless(Node):
def send_bio(self): def send_bio(self):
# Collect controller state # Collect controller state
left_stick_x = deadzone(self.gamepad.get_axis(0)) left_stick_x = stick_deadzone(self.gamepad.get_axis(0))
left_stick_y = deadzone(self.gamepad.get_axis(1)) left_stick_y = stick_deadzone(self.gamepad.get_axis(1))
left_trigger = deadzone(self.gamepad.get_axis(2)) left_trigger = stick_deadzone(self.gamepad.get_axis(2))
right_stick_x = deadzone(self.gamepad.get_axis(3)) right_stick_x = stick_deadzone(self.gamepad.get_axis(3))
right_stick_y = deadzone(self.gamepad.get_axis(4)) right_stick_y = stick_deadzone(self.gamepad.get_axis(4))
right_trigger = deadzone(self.gamepad.get_axis(5)) right_trigger = stick_deadzone(self.gamepad.get_axis(5))
button_a = self.gamepad.get_button(0) button_a = self.gamepad.get_button(0)
button_b = self.gamepad.get_button(1) button_b = self.gamepad.get_button(1)
button_x = self.gamepad.get_button(2) button_x = self.gamepad.get_button(2)
@@ -604,7 +629,7 @@ class Headless(Node):
) )
# Drill motor (FAERIE) # Drill motor (FAERIE)
if deadzone(left_trigger) > 0 or deadzone(right_trigger) > 0: if left_trigger > 0 or right_trigger > 0:
bio_input.drill = int( bio_input.drill = int(
30 * (right_trigger - left_trigger) 30 * (right_trigger - left_trigger)
) # Max duty cycle 30% ) # Max duty cycle 30%
@@ -612,13 +637,20 @@ class Headless(Node):
self.bio_publisher.publish(bio_input) self.bio_publisher.publish(bio_input)
def deadzone(value: float, threshold=STICK_DEADZONE) -> float: def stick_deadzone(value: float, threshold=STICK_DEADZONE) -> float:
"""Apply a deadzone to a joystick input so the motors don't sound angry""" """Apply a deadzone to a joystick input so the motors don't sound angry"""
if abs(value) < threshold: if abs(value) < threshold:
return 0 return 0
return value return value
def stick_to_arm_direction(value: float, threshold=ARM_DEADZONE) -> int:
"""Apply a larger deadzone to a stick input and make digital/binary instead of analog"""
if abs(value) < threshold:
return 0
return int(copysign(1, value))
def is_user_in_group(group_name: str) -> bool: def is_user_in_group(group_name: str) -> bool:
# Copied from https://zetcode.com/python/os-getgrouplist/ # Copied from https://zetcode.com/python/os-getgrouplist/
try: try:
@@ -637,10 +669,19 @@ def is_user_in_group(group_name: str) -> bool:
return False return False
def exit_handler(signum, frame):
print("Caught SIGTERM. Exiting...")
rclpy.try_shutdown()
sys.exit(0)
def main(args=None): def main(args=None):
try: try:
rclpy.init(args=args) rclpy.init(args=args)
# Catch termination signals and exit cleanly
signal.signal(signal.SIGTERM, exit_handler)
node = Headless() node = Headless()
rclpy.spin(node) rclpy.spin(node)
except (KeyboardInterrupt, ExternalShutdownException): except (KeyboardInterrupt, ExternalShutdownException):
@@ -650,7 +691,4 @@ def main(args=None):
if __name__ == "__main__": if __name__ == "__main__":
signal.signal(
signal.SIGTERM, lambda signum, frame: sys.exit(0)
) # Catch termination signals and exit cleanly
main() main()