# llm_fsm_multi_env.py
#
# Batched (multi-environment) skill state machine driven by a declarative
# TaskSpec: a list of skills to execute in order plus a list of success
# predicates. All per-env state is kept in torch tensors of length N
# (N = env.num_envs) so that one Python-side FSM drives every env at once.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch

# --------------------------
# Types
# --------------------------
Vec3T = torch.Tensor  # (N,3)
QuatT = torch.Tensor  # (N,4) wxyz
MaskT = torch.Tensor  # (N,) bool

SkillName = Literal[
    "pick",
    "release",
    "place",
    "drop",
    "align",
    "insert",
    "remove",
    "rotate",
    "rotate_until_limit",
    "fasten",
    "loosen",
    "open",
    "close",
    "clamp",
    "unclamp",
    "press",
    "toggle",
    "insert_into",
    "remove_from",
    "regrasp",
    "adjust_pose",
]

PredicateType = Literal[
    "is_grasped",
    "object_on_surface",
    "object_in_container",
    "object_in_fixture",
    "pose_in_tolerance",
    "depth_in_range",
    "fixture_state",
    "button_pressed",
    "switch_state",
    "thread_engaged",
]


# --------------------------
# Spec dataclasses
# --------------------------
@dataclass
class SkillSpec:
    """One step of a task: a skill name plus its keyword arguments."""

    name: SkillName
    args: Dict[str, Any]


@dataclass
class PredicateSpec:
    """One success condition: a predicate type plus its parameters."""

    type: PredicateType
    args: Dict[str, Any]  # free-form params


@dataclass
class GoalSpec:
    """Goal of a task: a text summary and the predicates that must all hold."""

    summary: str
    success_conditions: List[PredicateSpec]


@dataclass
class TaskSpec:
    """Full task description: identifiers, ordered skills and the goal."""

    task_id: str
    activity: str
    task_name: str
    skills: List[SkillSpec]
    goal: GoalSpec

    @staticmethod
    def from_dict(task: Dict[str, Any]) -> "TaskSpec":
        """Build a TaskSpec from a plain dict (e.g. parsed JSON).

        Missing *or explicitly null* ``skills``, ``args``, ``goal`` and
        ``success_conditions`` entries are treated as empty. For predicates,
        every key except ``"type"`` is collected into ``PredicateSpec.args``.

        Raises:
            TypeError: if a skill or predicate entry is not a dict.
            KeyError: if a skill lacks ``"name"`` or a predicate lacks ``"type"``.
        """
        skills: List[SkillSpec] = []
        # `or []` / `or {}` also covers an explicit None (JSON null).
        for s in task.get("skills") or []:
            if not isinstance(s, dict):
                raise TypeError(f"Skill entry must be a dict, got: {type(s)}")
            skills.append(SkillSpec(name=s["name"], args=dict(s.get("args") or {})))

        goal_in = task.get("goal") or {}
        preds: List[PredicateSpec] = []
        for p in goal_in.get("success_conditions") or []:
            if not isinstance(p, dict):
                raise TypeError(f"Predicate entry must be a dict, got: {type(p)}")
            p_type = p["type"]
            # Copy so the caller's dict is not mutated when "type" is removed.
            p_args = dict(p)
            p_args.pop("type", None)
            preds.append(PredicateSpec(type=p_type, args=p_args))

        goal = GoalSpec(summary=goal_in.get("summary", ""), success_conditions=preds)
        return TaskSpec(
            task_id=task.get("task_id", ""),
            activity=task.get("activity", ""),
            task_name=task.get("task_name", ""),
            skills=skills,
            goal=goal,
        )


# --------------------------
# Minimal Batched Env API (YOU implement these)
# --------------------------
class BatchedEnvAPI:
    """
    You adapt this to your IsaacLab env.
    Required to be batched over num_envs (N).
    All returned tensors are on the same device as env (typically CUDA).
    """

    device: torch.device
    num_envs: int

    # --- asset queries ---
    def get_pose_w(self, name: str) -> tuple[Vec3T, QuatT]:
        """Return pose of an entity by name: pos (N,3), quat (N,4) wxyz."""
        raise NotImplementedError

    def contains(self, fixture_or_container: str, obj: str) -> MaskT:
        """Return (N,) bool if obj is in fixture/container region."""
        raise NotImplementedError

    def fixture_state(self, fixture: str, state_name: str) -> Any:
        """
        Return a batched state. Common choices:
        - bool tensor (N,) for open/closed
        - float tensor (N,) for joint position / open ratio
        - int tensor (N,) for discrete states
        """
        raise NotImplementedError

    def button_pressed(self, fixture: str, button: str) -> MaskT:
        """Return (N,) bool: whether `button` on `fixture` is pressed."""
        raise NotImplementedError

    def switch_state(self, fixture: str, switch: str) -> Any:
        """Return the batched state of `switch` on `fixture`.

        The batched evaluator in this file requires the result to be a tensor.
        """
        raise NotImplementedError

    # --- stepping / resets ---
    def step_physics(self, action: Optional[torch.Tensor] = None) -> None:
        """Advance sim one step. action can be dummy if you drive controllers elsewhere."""
        raise NotImplementedError

    def reset(self, env_ids: Optional[torch.Tensor] = None) -> None:
        """Reset all envs if env_ids is None else only subset env_ids (1D int tensor)."""
        raise NotImplementedError


# --------------------------
# Robot Facade (Batched) - YOU wire to IK/hand control
# --------------------------
@dataclass
class MoveResultBatched:
    """Result of a batched motion command."""

    reached: MaskT  # (N,) bool


class RobotFacadeBatched:
    """
    Replace internals with your actual IsaacLab controllers.
    This interface is batched: everything operates on all envs, masked by active_mask.
    """

    def __init__(self, env: BatchedEnvAPI):
        self.env = env
        self.device = env.device
        self.N = env.num_envs
        # TODO: create & store your DifferentialIKController / hand joint controllers here
        # self.ik = DifferentialIKController(...)
        # self.hand = ...

    def _active_or_all(self, active_mask: Optional[MaskT]) -> MaskT:
        """Return `active_mask`, or an all-True (N,) mask when it is None."""
        if active_mask is None:
            return torch.ones(self.N, device=self.device, dtype=torch.bool)
        return active_mask

    def get_ee_pose_w(self) -> tuple[Vec3T, QuatT]:
        """Return EE pose (N,3),(N,4)."""
        raise NotImplementedError

    def is_grasped(self, object_name: str) -> MaskT:
        """Return (N,) grasp status. Implement using contact sensors + relative pose."""
        raise NotImplementedError

    def move_ee_pose(
        self,
        pos_w: Vec3T,
        quat_w: QuatT,
        active_mask: Optional[MaskT] = None,
        pos_tol: float = 0.01,
        rot_tol: float = 0.2,
    ) -> MoveResultBatched:
        """
        Move EE toward target pose for envs in active_mask.
        Return reached mask for ALL envs; unreachable envs should be False.

        NOTE: the current body is a placeholder that ignores the target and
        tolerances and reports every active env as reached immediately.
        """
        active_mask = self._active_or_all(active_mask)

        # TODO: implement your IK step here:
        # - compute joint targets for active envs
        # - write into controllers / action buffer
        # - evaluate reached: compare current ee pose vs target

        # Placeholder: pretend reached immediately for active envs
        reached = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        reached[active_mask] = True
        return MoveResultBatched(reached=reached)

    def open_hand(self, active_mask: Optional[MaskT] = None) -> None:
        """Open the hand for envs in active_mask (placeholder: no-op)."""
        active_mask = self._active_or_all(active_mask)
        # TODO: set finger joint targets open for active envs
        return

    def close_hand(self, active_mask: Optional[MaskT] = None) -> None:
        """Close the hand for envs in active_mask (placeholder: no-op)."""
        active_mask = self._active_or_all(active_mask)
        # TODO: set finger joint targets close for active envs
        return

    def push_along_axis(
        self,
        start_pos_w: Vec3T,
        quat_w: QuatT,
        axis_w: Vec3T,
        dist: float,
        active_mask: Optional[MaskT] = None,
        pos_tol: float = 0.01,
        rot_tol: float = 0.2,
    ) -> MoveResultBatched:
        """Target = start + axis*dist."""
        target = start_pos_w + axis_w * dist
        return self.move_ee_pose(
            target, quat_w, active_mask=active_mask, pos_tol=pos_tol, rot_tol=rot_tol
        )


# --------------------------
# Success Evaluator (Batched)
# --------------------------
class SuccessEvaluatorBatched:
    """Evaluates the goal predicates of a task for every env at once."""

    def __init__(self, env: BatchedEnvAPI, robot: RobotFacadeBatched, predicates: List[PredicateSpec]):
        self.env = env
        self.robot = robot
        self.predicates = predicates
        self.device = env.device
        self.N = env.num_envs

    def check(self) -> MaskT:
        """Return success mask (N,): the AND of all predicates.

        With no predicates the result is all-False, i.e. a task without
        success conditions never reports success.
        """
        if not self.predicates:
            return torch.zeros(self.N, device=self.device, dtype=torch.bool)

        ok = torch.ones(self.N, device=self.device, dtype=torch.bool)
        for p in self.predicates:
            ok = ok & self._eval(p)
        return ok

    def _eval(self, p: PredicateSpec) -> MaskT:
        """Evaluate a single predicate and return its (N,) bool mask.

        Raises:
            TypeError: if a fixture/switch state query does not return a tensor.
            ValueError: if the predicate type is not recognised.
            KeyError: if a required argument is missing from ``p.args``.
        """
        t = p.type
        a = p.args

        if t == "is_grasped":
            return self.robot.is_grasped(a["object"])

        if t == "object_in_fixture":
            return self.env.contains(a["fixture"], a["object"])

        if t == "object_in_container":
            return self.env.contains(a["container"], a["object"])

        if t == "fixture_state":
            cur = self.env.fixture_state(a["fixture"], a["state_name"])
            # common: cur is bool tensor
            if torch.is_tensor(cur):
                return cur == torch.tensor(a["value"], device=cur.device, dtype=cur.dtype)
            # if not tensor, you can adapt here
            raise TypeError("fixture_state must return a tensor for batched evaluator")

        if t == "button_pressed":
            return self.env.button_pressed(a["fixture"], a["button"])

        if t == "switch_state":
            cur = self.env.switch_state(a["fixture"], a["switch"])
            if torch.is_tensor(cur):
                # value might be str -> you need mapping; keep it tensor-based in env
                return cur == torch.tensor(a["value"], device=cur.device, dtype=cur.dtype)
            raise TypeError("switch_state must return a tensor for batched evaluator")

        if t == "pose_in_tolerance":
            # Position-only check; orientation is not compared.
            pos, _ = self.env.get_pose_w(a["object"])
            tgt_pos, _ = self.env.get_pose_w(a["target_frame"])
            pos_tol = float(a.get("pos_tol", 0.02))
            d = torch.linalg.norm(pos - tgt_pos, dim=-1)
            return d <= pos_tol

        if t == "depth_in_range":
            obj_pos, _ = self.env.get_pose_w(a["object"])
            hole_pos, _ = self.env.get_pose_w(a["fixture"])
            axis = a.get("axis_w", [0.0, 0.0, 1.0])
            # (1,3) broadcasts against the (N,3) offset. The axis is used as
            # given (not normalised).
            axis_w = torch.tensor(axis, device=self.device, dtype=obj_pos.dtype).view(1, 3)
            depth = torch.sum((obj_pos - hole_pos) * axis_w, dim=-1)
            return (depth >= float(a["min_depth"])) & (depth <= float(a["max_depth"]))

        # optional ones not implemented here:
        if t in ("object_on_surface", "thread_engaged"):
            return torch.zeros(self.N, device=self.device, dtype=torch.bool)

        raise ValueError(f"Unknown predicate type: {t}")


# --------------------------
# Batched Skills
# - Each skill keeps per-env internal phase/timers as tensors.
# - step(active_mask) returns done_mask for ALL envs (True only for envs that finished this skill).
# --------------------------
@dataclass
class SkillContext:
    """Handles shared by all skills: the env and the robot facade."""

    env: BatchedEnvAPI
    robot: RobotFacadeBatched


class BaseSkillBatched:
    """Base class for batched skills; owns the per-env phase counter."""

    def __init__(self, ctx: SkillContext):
        self.ctx = ctx
        self.device = ctx.env.device
        self.N = ctx.env.num_envs
        self.phase = torch.zeros(self.N, device=self.device, dtype=torch.long)  # per-env phase

    def reset(self, enter_mask: MaskT) -> None:
        """Called when some envs enter this skill (idx changes to this skill)."""
        self.phase[enter_mask] = 0

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance the skill for envs in active_mask; return (N,) done mask."""
        raise NotImplementedError


class PickSkillBatched(BaseSkillBatched):
    """Pick: move above the object, close the hand, wait until grasped."""

    def __init__(self, ctx: SkillContext, object_name: str, pregrasp_z: float = 0.08):
        super().__init__(ctx)
        self.object_name = object_name
        self.pregrasp_z = pregrasp_z

    def step(self, active_mask: MaskT) -> MaskT:
        """Run one tick; phases are checked in order, so an env may pass
        through several phases within a single call."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)

        # phase 0: move to pregrasp
        m0 = active_mask & (self.phase == 0)
        if m0.any():
            obj_pos, obj_quat = self.ctx.env.get_pose_w(self.object_name)
            target_pos = obj_pos.clone()
            target_pos[m0, 2] = obj_pos[m0, 2] + self.pregrasp_z
            r = self.ctx.robot.move_ee_pose(target_pos, obj_quat, active_mask=m0)
            reached = r.reached & m0
            self.phase[reached] = 1

        # phase 1: close hand (one-shot)
        m1 = active_mask & (self.phase == 1)
        if m1.any():
            self.ctx.robot.close_hand(active_mask=m1)
            self.phase[m1] = 2

        # phase 2: verify
        m2 = active_mask & (self.phase == 2)
        if m2.any():
            grasped = self.ctx.robot.is_grasped(self.object_name) & m2
            done[grasped] = True
            # keep phase=2 for those not yet grasped (could add timeout logic)

        return done


class ReleaseSkillBatched(BaseSkillBatched):
    """Release: open the hand, then report done."""

    def __init__(self, ctx: SkillContext, object_name: str):
        super().__init__(ctx)
        self.object_name = object_name

    def step(self, active_mask: MaskT) -> MaskT:
        """Open the hand for envs in phase 0 and mark envs in phase 1 done."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)

        m0 = active_mask & (self.phase == 0)
        if m0.any():
            self.ctx.robot.open_hand(active_mask=m0)
            self.phase[m0] = 1

        # immediate done (optionally verify not grasped)
        m1 = active_mask & (self.phase == 1)
        done[m1] = True
        return done


class PlaceSkillBatched(BaseSkillBatched):
    """Place: move above the target, descend to it, open the hand."""

    def __init__(self, ctx: SkillContext, object_name: str, target_name: str, above_z: float = 0.10):
        super().__init__(ctx)
        self.object_name = object_name
        self.target_name = target_name
        self.above_z = above_z

    def step(self, active_mask: MaskT) -> MaskT:
        """Run one tick of the three-phase place sequence."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)

        # phase 0: move above target
        m0 = active_mask & (self.phase == 0)
        if m0.any():
            tgt_pos, tgt_quat = self.ctx.env.get_pose_w(self.target_name)
            above = tgt_pos.clone()
            above[m0, 2] = tgt_pos[m0, 2] + self.above_z
            r = self.ctx.robot.move_ee_pose(above, tgt_quat, active_mask=m0)
            reached = r.reached & m0
            self.phase[reached] = 1

        # phase 1: move down to target
        m1 = active_mask & (self.phase == 1)
        if m1.any():
            tgt_pos, tgt_quat = self.ctx.env.get_pose_w(self.target_name)
            r = self.ctx.robot.move_ee_pose(tgt_pos, tgt_quat, active_mask=m1)
            reached = r.reached & m1
            self.phase[reached] = 2

        # phase 2: open hand then done
        m2 = active_mask & (self.phase == 2)
        if m2.any():
            self.ctx.robot.open_hand(active_mask=m2)
            done[m2] = True

        return done


class AlignSkillBatched(BaseSkillBatched):
    """
    Placeholder alignment: go to pre-insert pose above fixture frame.
    Replace internally with your pinch-axis + face selection + micro-adjust.
    """

    def __init__(self, ctx: SkillContext, object_name: str, target_fixture: str, standoff_z: float = 0.06):
        super().__init__(ctx)
        self.object_name = object_name
        self.target_fixture = target_fixture
        self.standoff_z = standoff_z

    def step(self, active_mask: MaskT) -> MaskT:
        """Move to the standoff pose; done for envs that reached it."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)

        m0 = active_mask & (self.phase == 0)
        if m0.any():
            hole_pos, hole_quat = self.ctx.env.get_pose_w(self.target_fixture)
            pre = hole_pos.clone()
            pre[m0, 2] = hole_pos[m0, 2] + self.standoff_z
            r = self.ctx.robot.move_ee_pose(pre, hole_quat, active_mask=m0)
            reached = r.reached & m0
            done[reached] = True

        return done


class InsertSkillBatched(BaseSkillBatched):
    """
    Placeholder insertion: move to pre, then push along -Z.
    Replace axis with fixture insertion axis, and add contact-based micro-adjust.
    """

    def __init__(
        self,
        ctx: SkillContext,
        object_name: str,
        target_fixture: str,
        pre_z: float = 0.03,
        depth: float = 0.02,
    ):
        super().__init__(ctx)
        self.object_name = object_name
        self.target_fixture = target_fixture
        self.pre_z = pre_z
        self.depth = depth

    def step(self, active_mask: MaskT) -> MaskT:
        """Run one tick: pre-insert move (phase 0), then push (phase 1)."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)

        # phase 0: pre-insert
        m0 = active_mask & (self.phase == 0)
        if m0.any():
            hole_pos, hole_quat = self.ctx.env.get_pose_w(self.target_fixture)
            pre = hole_pos.clone()
            pre[m0, 2] = hole_pos[m0, 2] + self.pre_z
            r = self.ctx.robot.move_ee_pose(pre, hole_quat, active_mask=m0)
            reached = r.reached & m0
            self.phase[reached] = 1

        # phase 1: push along axis
        m1 = active_mask & (self.phase == 1)
        if m1.any():
            hole_pos, hole_quat = self.ctx.env.get_pose_w(self.target_fixture)
            start = hole_pos.clone()
            start[m1, 2] = hole_pos[m1, 2] + self.pre_z
            axis_w = torch.zeros((self.N, 3), device=self.device, dtype=hole_pos.dtype)
            axis_w[:, 2] = -1.0
            r = self.ctx.robot.push_along_axis(start, hole_quat, axis_w, self.depth, active_mask=m1)
            reached = r.reached & m1
            done[reached] = True

        return done


# --------------------------
# Skill Factory (Batched)
# --------------------------
def build_skill_batched(ctx: SkillContext, spec: SkillSpec) -> BaseSkillBatched:
    """Instantiate the batched skill object described by `spec`.

    Raises:
        NotImplementedError: if the skill name has no batched implementation.
        KeyError: if a required argument is missing from ``spec.args``.
    """
    n = spec.name
    a = spec.args

    if n == "pick":
        return PickSkillBatched(ctx, a["object"])
    if n == "release":
        return ReleaseSkillBatched(ctx, a["object"])
    if n == "place":
        return PlaceSkillBatched(ctx, a["object"], a["target"])
    if n == "align":
        return AlignSkillBatched(ctx, a["object"], a["target_fixture"])
    if n == "insert":
        depth = float(a.get("depth", 0.02))
        return InsertSkillBatched(ctx, a["object"], a["target_fixture"], depth=depth)

    raise NotImplementedError(f"Skill not implemented (batched): {n}")


# --------------------------
# Batched Skill FSM
# --------------------------
class SkillFSMBatched:
    """
    One task_spec replicated across N envs.
    Maintain env-wise idx; execute appropriate skill for env subsets.
    """

    def __init__(self, ctx: SkillContext, skill_specs: List[SkillSpec]):
        self.ctx = ctx
        self.device = ctx.env.device
        self.N = ctx.env.num_envs

        self.skill_specs = skill_specs
        self.K = len(skill_specs)

        # Prebuild all skill objects (K of them), each with per-env phase tensor
        self.skills: List[BaseSkillBatched] = [build_skill_batched(ctx, s) for s in skill_specs]

        # env-wise pointer
        self.idx = torch.zeros(self.N, device=self.device, dtype=torch.long)  # current skill index per env
        self.done = torch.zeros(self.N, device=self.device, dtype=torch.bool)  # whether sequence finished

    def reset(self, env_ids: Optional[torch.Tensor] = None) -> None:
        """Reset FSM for all envs or subset.

        With an empty skill list there is nothing to execute, so the reset
        envs are marked as finished straight away instead of indexing a
        non-existent first skill.
        """
        if env_ids is None:
            mask = torch.ones(self.N, device=self.device, dtype=torch.bool)
        else:
            mask = torch.zeros(self.N, device=self.device, dtype=torch.bool)
            mask[env_ids] = True

        self.idx[mask] = 0

        if self.K == 0:
            # Empty sequence: finished by definition.
            self.done[mask] = True
            return

        self.done[mask] = False
        # envs enter skill 0
        self.skills[0].reset(mask)

    def step(self, active_mask: MaskT) -> MaskT:
        """
        Advance one tick for envs in active_mask.
        Returns seq_done_mask (N,) indicating which envs have finished all skills.

        Skills are visited in index order and the per-skill mask is computed
        from the live `idx`, so an env that finishes skill k during this call
        is also stepped by skill k+1 within the same call.
        """
        # Don't run already finished envs
        run_mask = active_mask & (~self.done)
        if not run_mask.any():
            return self.done.clone()

        # For each skill k, run for envs where idx==k
        for k in range(self.K):
            mk = run_mask & (self.idx == k)
            if not mk.any():
                continue

            skill_done = self.skills[k].step(mk)  # returns (N,)
            finished_here = mk & skill_done
            if not finished_here.any():
                continue

            # advance idx
            next_idx = k + 1
            if next_idx >= self.K:
                # sequence finished
                self.done[finished_here] = True
            else:
                self.idx[finished_here] = next_idx
                # envs "enter" next skill => reset its internal phase for these envs
                self.skills[next_idx].reset(finished_here)

        return self.done.clone()


# --------------------------
# Compiled Task Bundle (Batched)
# --------------------------
@dataclass
class CompiledTaskBatched:
    """Everything needed to run one task: env, robot, FSM and evaluator."""

    env: BatchedEnvAPI
    robot: RobotFacadeBatched
    fsm: SkillFSMBatched
    evaluator: SuccessEvaluatorBatched


# --------------------------
# Compiler (Batched)
# --------------------------
class LLMFSMCompilerBatched:
    """Turns a TaskSpec into a runnable CompiledTaskBatched."""

    def compile(self, env: BatchedEnvAPI, task_spec: TaskSpec) -> CompiledTaskBatched:
        """Build robot facade, FSM and evaluator for `task_spec` on `env`.

        The FSM is fully reset before being returned.
        """
        robot = RobotFacadeBatched(env)
        ctx = SkillContext(env=env, robot=robot)
        fsm = SkillFSMBatched(ctx, task_spec.skills)
        evaluator = SuccessEvaluatorBatched(env, robot, task_spec.goal.success_conditions)
        # full reset
        fsm.reset()
        return CompiledTaskBatched(env=env, robot=robot, fsm=fsm, evaluator=evaluator)


# --------------------------
# Runner loop (multi-env)
# --------------------------
def run_compiled_task_multi_env(
    compiled: CompiledTaskBatched,
    max_steps: int = 2000,
    auto_reset_success: bool = False,
) -> Dict[str, Any]:
    """Run the compiled task on all envs until every env succeeded/finished
    or `max_steps` ticks have elapsed.

    Args:
        compiled: bundle produced by ``LLMFSMCompilerBatched.compile``.
        max_steps: upper bound on loop iterations; ``<= 0`` runs no iteration.
        auto_reset_success: if True, envs that satisfy the goal at the start
            of a tick are reset (env + FSM) after that tick's physics step and
            their success flag is cleared to start a new episode.

    Returns:
        Dict with ``"success_mask"`` (N,) bool, ``"fsm_done_mask"`` (N,) bool
        and ``"steps"`` (number of loop iterations executed).
    """
    env = compiled.env
    N = env.num_envs
    device = env.device

    # initial reset
    env.reset()
    compiled.fsm.reset()

    success = torch.zeros(N, device=device, dtype=torch.bool)

    # Explicit counter: the loop variable would be unbound if the loop body
    # never runs (max_steps <= 0).
    steps = 0
    for t in range(max_steps):
        steps = t + 1

        # 1) check success before step (optional)
        success_now = compiled.evaluator.check()
        success = success | success_now

        # 2) active envs = not success and not fsm-done (you can choose policy)
        active = (~success) & (~compiled.fsm.done)

        # 3) advance FSM for active envs
        compiled.fsm.step(active)

        # 4) physics step
        env.step_physics(action=None)

        # 5) optional: auto-reset envs that succeeded (common for dataset generation)
        if auto_reset_success and success_now.any():
            env_ids = torch.nonzero(success_now, as_tuple=False).squeeze(-1)
            env.reset(env_ids)
            compiled.fsm.reset(env_ids)
            success[env_ids] = False  # start new episode for those envs

        # early stop: all envs succeeded or finished
        if (success | compiled.fsm.done).all():
            break

    return {
        "success_mask": success.detach().clone(),
        "fsm_done_mask": compiled.fsm.done.detach().clone(),
        "steps": steps,
    }


# --------------------------
# Example: constructing TaskSpec (no JSON parser here)
# --------------------------
def make_example_task_spec() -> TaskSpec:
    """Return a sample peg-insertion TaskSpec (pick, align, insert, release)."""
    skills = [
        SkillSpec("pick", {"object": "peg_small"}),
        SkillSpec("align", {"object": "peg_small", "target_fixture": "alignment_jig_hole"}),
        SkillSpec("insert", {"object": "peg_small", "target_fixture": "alignment_jig_hole", "depth": 0.02}),
        SkillSpec("release", {"object": "peg_small"}),
    ]
    preds = [
        PredicateSpec("object_in_fixture", {"object": "peg_small", "fixture": "alignment_jig_hole"}),
        PredicateSpec(
            "depth_in_range",
            {
                "object": "peg_small",
                "fixture": "alignment_jig_hole",
                "min_depth": 0.018,
                "max_depth": 0.022,
                "axis_w": [0, 0, 1],
            },
        ),
    ]
    return TaskSpec(
        task_id="peg_seat_001",
        activity="Peg Alignment & Seating",
        task_name="Insert Small Peg into Alignment Jig",
        skills=skills,
        goal=GoalSpec(summary="Insert peg into jig", success_conditions=preds),
    )