# File: task_gen_with_llm/llm_fsm_multi_env.py (Python, 678 lines, 23 KiB)
# Last updated: 2026-01-12 18:25:04 +09:00
# llm_fsm_multi_env.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple
import torch
# --------------------------
# Types
# --------------------------
# Batched tensor aliases; N is the number of parallel environments.
Vec3T = torch.Tensor  # (N, 3)
QuatT = torch.Tensor  # (N, 4), wxyz order
MaskT = torch.Tensor  # (N,) bool, one flag per env

# Skill names a task spec may reference. Only a subset has a batched
# implementation; build_skill_batched raises NotImplementedError for the rest.
SkillName = Literal[
    "pick", "release", "place", "drop",
    "align", "insert", "remove",
    "rotate", "rotate_until_limit",
    "fasten", "loosen",
    "open", "close",
    "clamp", "unclamp",
    "press", "toggle",
    "insert_into", "remove_from",
    "regrasp", "adjust_pose",
]

# Success-condition types a goal may reference (evaluated by
# SuccessEvaluatorBatched; "object_on_surface" and "thread_engaged" are
# currently always evaluated as False there).
PredicateType = Literal[
    "is_grasped",
    "object_on_surface",
    "object_in_container",
    "object_in_fixture",
    "pose_in_tolerance",
    "depth_in_range",
    "fixture_state",
    "button_pressed",
    "switch_state",
    "thread_engaged",
]
# --------------------------
# Spec dataclasses
# --------------------------
@dataclass
class SkillSpec:
    """One step of a task plan: a skill name plus its keyword arguments."""

    name: SkillName
    args: Dict[str, Any]


@dataclass
class PredicateSpec:
    """One success condition: a predicate type plus its parameters."""

    type: PredicateType
    args: Dict[str, Any]  # free-form params, read by SuccessEvaluatorBatched


@dataclass
class GoalSpec:
    """Task goal: a human-readable summary and the conditions that define success."""

    summary: str
    success_conditions: List[PredicateSpec]


@dataclass
class TaskSpec:
    """A complete task: identifiers, the ordered skill plan, and the goal."""

    task_id: str
    activity: str
    task_name: str
    skills: List[SkillSpec]
    goal: GoalSpec

    @staticmethod
    def from_dict(task: Dict[str, Any]) -> "TaskSpec":
        """Build a TaskSpec from a parsed (e.g. JSON) task dict.

        Expected layout::

            {
                "task_id": str, "activity": str, "task_name": str,
                "skills": [{"name": <skill>, "args": {...}}, ...],
                "goal": {
                    "summary": str,
                    "success_conditions": [{"type": <predicate>, <param>: ...}, ...],
                },
            }

        Predicate parameters sit next to "type"; every key except "type" is
        copied into ``PredicateSpec.args``. Optional fields that are missing
        *or explicitly null* fall back to empty values, so output such as
        ``"args": null`` or ``"goal": null`` is accepted. The input dict is
        not mutated.

        Raises:
            TypeError: if "goal", a skill entry or a predicate entry is not a dict.
            KeyError: if a skill has no "name" or a predicate has no "type".
        """

        def get_or(container: Dict[str, Any], key: str, default: Any) -> Any:
            # dict.get(key, default) still returns None for an explicit null value.
            value = container.get(key)
            return default if value is None else value

        skills: List[SkillSpec] = []
        for s in get_or(task, "skills", []):
            if not isinstance(s, dict):
                raise TypeError(f"Skill entry must be a dict, got: {type(s)}")
            skills.append(SkillSpec(name=s["name"], args=dict(get_or(s, "args", {}))))

        goal_in = get_or(task, "goal", {})
        if not isinstance(goal_in, dict):
            raise TypeError(f"Goal must be a dict, got: {type(goal_in)}")

        preds: List[PredicateSpec] = []
        for p in get_or(goal_in, "success_conditions", []):
            if not isinstance(p, dict):
                raise TypeError(f"Predicate entry must be a dict, got: {type(p)}")
            p_type = p["type"]
            p_args = {k: v for k, v in p.items() if k != "type"}
            preds.append(PredicateSpec(type=p_type, args=p_args))

        goal = GoalSpec(summary=get_or(goal_in, "summary", ""), success_conditions=preds)
        return TaskSpec(
            task_id=get_or(task, "task_id", ""),
            activity=get_or(task, "activity", ""),
            task_name=get_or(task, "task_name", ""),
            skills=skills,
            goal=goal,
        )
# --------------------------
# Minimal Batched Env API (YOU implement these)
# --------------------------
class BatchedEnvAPI:
    """Adapter interface over an IsaacLab environment running N envs in parallel.

    Subclass it and implement every method. All queries are batched over
    ``num_envs`` (N), and returned tensors live on ``device`` (the env's
    device, typically CUDA).
    """

    device: torch.device
    num_envs: int

    # ---- asset queries ----

    def get_pose_w(self, name: str) -> tuple[Vec3T, QuatT]:
        """World pose of the named entity: positions (N,3) and quaternions (N,4), wxyz."""
        raise NotImplementedError

    def contains(self, fixture_or_container: str, obj: str) -> MaskT:
        """(N,) bool: whether ``obj`` is inside the fixture/container region."""
        raise NotImplementedError

    def fixture_state(self, fixture: str, state_name: str) -> Any:
        """Batched state of a fixture.

        Typical return types:
          * bool tensor (N,) for open/closed,
          * float tensor (N,) for a joint position / open ratio,
          * int tensor (N,) for discrete states.
        """
        raise NotImplementedError

    def button_pressed(self, fixture: str, button: str) -> MaskT:
        """(N,) bool mask for ``button`` on ``fixture``."""
        raise NotImplementedError

    def switch_state(self, fixture: str, switch: str) -> Any:
        """Batched state of ``switch`` on ``fixture``."""
        raise NotImplementedError

    # ---- stepping / resets ----

    def step_physics(self, action: Optional[torch.Tensor] = None) -> None:
        """Advance the simulation by one step.

        ``action`` may be a dummy when controllers are driven elsewhere.
        """
        raise NotImplementedError

    def reset(self, env_ids: Optional[torch.Tensor] = None) -> None:
        """Reset every env when ``env_ids`` is None, otherwise only the given 1D int env ids."""
        raise NotImplementedError
# --------------------------
# Robot Facade (Batched) - YOU wire to IK/hand control
# --------------------------
@dataclass
class MoveResultBatched:
    """Outcome of one batched motion command."""

    reached: MaskT  # (N,) bool


class RobotFacadeBatched:
    """Batched robot command surface used by the skills.

    Every command acts on all envs, restricted to ``active_mask`` (all envs
    when the mask is None). The motion/hand methods are placeholders to be
    wired to the real IsaacLab controllers.
    """

    def __init__(self, env: BatchedEnvAPI):
        self.env = env
        self.device = env.device
        self.N = env.num_envs
        # TODO: create & store your DifferentialIKController / hand joint controllers here
        # self.ik = DifferentialIKController(...)
        # self.hand = ...

    def _full_mask(self, active_mask: Optional[MaskT]) -> MaskT:
        """Return ``active_mask``, or an all-True (N,) bool mask when it is None."""
        if active_mask is not None:
            return active_mask
        return torch.ones(self.N, device=self.device, dtype=torch.bool)

    def get_ee_pose_w(self) -> tuple[Vec3T, QuatT]:
        """End-effector world pose: positions (N,3) and quaternions (N,4)."""
        raise NotImplementedError

    def is_grasped(self, object_name: str) -> MaskT:
        """(N,) grasp status of ``object_name`` (implement via contact sensors + relative pose)."""
        raise NotImplementedError

    def move_ee_pose(
        self,
        pos_w: Vec3T,
        quat_w: QuatT,
        active_mask: Optional[MaskT] = None,
        pos_tol: float = 0.01,
        rot_tol: float = 0.2,
    ) -> MoveResultBatched:
        """Drive the EE toward (pos_w, quat_w) for the envs in ``active_mask``.

        The returned mask covers ALL envs; envs that are inactive or have not
        reached the target are False.
        """
        mask = self._full_mask(active_mask)
        # TODO: real IK step -- compute joint targets for the active envs, write
        # them to the controllers / action buffer, and compare the current EE
        # pose against the target using pos_tol / rot_tol.
        # Placeholder: every active env counts as having reached the target.
        reached = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        reached[mask] = True
        return MoveResultBatched(reached=reached)

    def open_hand(self, active_mask: Optional[MaskT] = None) -> None:
        """Command the fingers open for the envs in ``active_mask`` (placeholder)."""
        self._full_mask(active_mask)
        # TODO: set finger joint targets open for active envs

    def close_hand(self, active_mask: Optional[MaskT] = None) -> None:
        """Command the fingers closed for the envs in ``active_mask`` (placeholder)."""
        self._full_mask(active_mask)
        # TODO: set finger joint targets close for active envs

    def push_along_axis(
        self,
        start_pos_w: Vec3T,
        quat_w: QuatT,
        axis_w: Vec3T,
        dist: float,
        active_mask: Optional[MaskT] = None,
        pos_tol: float = 0.01,
        rot_tol: float = 0.2,
    ) -> MoveResultBatched:
        """Move the EE to ``start_pos_w + axis_w * dist`` while holding ``quat_w``."""
        goal_pos = start_pos_w + axis_w * dist
        return self.move_ee_pose(
            goal_pos,
            quat_w,
            active_mask=active_mask,
            pos_tol=pos_tol,
            rot_tol=rot_tol,
        )
# --------------------------
# Success Evaluator (Batched)
# --------------------------
class SuccessEvaluatorBatched:
    """Evaluate a task goal -- the AND of its predicates -- for all envs at once."""

    def __init__(self, env: BatchedEnvAPI, robot: RobotFacadeBatched, predicates: List[PredicateSpec]):
        self.env = env
        self.robot = robot
        self.predicates = predicates
        self.device = env.device
        self.N = env.num_envs

    def check(self) -> MaskT:
        """Return the (N,) bool success mask: True where every predicate holds.

        An empty predicate list yields all-False: a goal without success
        conditions never counts as achieved.
        """
        if len(self.predicates) == 0:
            return torch.zeros(self.N, device=self.device, dtype=torch.bool)
        ok = torch.ones(self.N, device=self.device, dtype=torch.bool)
        for p in self.predicates:
            ok = ok & self._eval(p)
        return ok

    @staticmethod
    def _equals_value(cur: Any, a: Dict[str, Any], source: str) -> MaskT:
        """Compare a batched state tensor with the predicate's ``value`` arg.

        Raises:
            TypeError: if the env returned something other than a tensor.
        """
        if torch.is_tensor(cur):
            # a str value needs a mapping; keep states tensor-based in the env
            return cur == torch.as_tensor(a["value"], device=cur.device, dtype=cur.dtype)
        raise TypeError(f"{source} must return a tensor for batched evaluator")

    @staticmethod
    def _quat_angle(q_a: QuatT, q_b: QuatT) -> torch.Tensor:
        """Rotation angle (radians) between two batches of quaternions, shape (N,).

        Inputs are normalized first; using |<q_a, q_b>| makes q and -q (the
        same rotation) compare as identical.
        """
        q_a = q_a / torch.linalg.norm(q_a, dim=-1, keepdim=True).clamp_min(1e-12)
        q_b = q_b / torch.linalg.norm(q_b, dim=-1, keepdim=True).clamp_min(1e-12)
        # clamp guards acos against |dot| creeping above 1 from rounding
        dot = torch.sum(q_a * q_b, dim=-1).abs().clamp(max=1.0)
        return 2.0 * torch.acos(dot)

    def _eval(self, p: PredicateSpec) -> MaskT:
        """Evaluate one predicate for all envs and return an (N,) bool mask.

        Args read from ``p.args`` per type:
            is_grasped:          object
            object_in_fixture:   object, fixture
            object_in_container: object, container
            fixture_state:       fixture, state_name, value
            button_pressed:      fixture, button
            switch_state:        fixture, switch, value
            pose_in_tolerance:   object, target_frame, pos_tol (default 0.02),
                                 rot_tol (optional, radians; orientation is
                                 only checked when given)
            depth_in_range:      object, fixture, min_depth, max_depth,
                                 axis_w (default [0, 0, 1]; normalized before use)

        Raises:
            ValueError: unknown predicate type, or a zero-length ``axis_w``.
            TypeError: fixture/switch state is not a tensor.
        """
        t = p.type
        a = p.args
        if t == "is_grasped":
            return self.robot.is_grasped(a["object"])
        if t == "object_in_fixture":
            return self.env.contains(a["fixture"], a["object"])
        if t == "object_in_container":
            return self.env.contains(a["container"], a["object"])
        if t == "fixture_state":
            cur = self.env.fixture_state(a["fixture"], a["state_name"])
            return self._equals_value(cur, a, "fixture_state")
        if t == "button_pressed":
            return self.env.button_pressed(a["fixture"], a["button"])
        if t == "switch_state":
            cur = self.env.switch_state(a["fixture"], a["switch"])
            return self._equals_value(cur, a, "switch_state")
        if t == "pose_in_tolerance":
            pos, quat = self.env.get_pose_w(a["object"])
            tgt_pos, tgt_quat = self.env.get_pose_w(a["target_frame"])
            pos_tol = float(a.get("pos_tol", 0.02))
            ok = torch.linalg.norm(pos - tgt_pos, dim=-1) <= pos_tol
            rot_tol = a.get("rot_tol")
            if rot_tol is not None:
                ok = ok & (self._quat_angle(quat, tgt_quat) <= float(rot_tol))
            return ok
        if t == "depth_in_range":
            obj_pos, _ = self.env.get_pose_w(a["object"])
            hole_pos, _ = self.env.get_pose_w(a["fixture"])
            # Validate/normalize on the host copy of the spec axis; a non-unit
            # axis would otherwise scale the projected depth.
            axis = torch.as_tensor(a.get("axis_w", [0.0, 0.0, 1.0]), dtype=torch.float64).reshape(-1, 3)
            norm = torch.linalg.norm(axis, dim=-1, keepdim=True)
            if bool((norm <= 0).any()):
                raise ValueError(f"depth_in_range: axis_w must be non-zero, got {a.get('axis_w')}")
            axis = (axis / norm).to(device=obj_pos.device, dtype=obj_pos.dtype)
            # (1,3) or (N,3) axis broadcasts against the (N,3) offset
            depth = torch.sum((obj_pos - hole_pos) * axis, dim=-1)
            return (depth >= float(a["min_depth"])) & (depth <= float(a["max_depth"]))
        if t in ("object_on_surface", "thread_engaged"):
            # Declared predicate types without an implementation yet: never satisfied.
            return torch.zeros(self.N, device=self.device, dtype=torch.bool)
        raise ValueError(f"Unknown predicate type: {t}")
# --------------------------
# Batched Skills
# - Each skill keeps per-env internal phase/timers as tensors.
# - step(active_mask) returns done_mask for ALL envs (True only for envs that finished this skill).
# --------------------------
@dataclass
class SkillContext:
    """Handles shared by every batched skill."""

    env: BatchedEnvAPI  # scene / pose queries
    robot: RobotFacadeBatched  # motion and hand commands
class BaseSkillBatched:
    """Common state for batched skills: one integer phase counter per env.

    Subclasses implement ``step(active_mask)``, which returns an (N,) bool
    mask that is True only for envs that finished the skill on that call.
    """

    def __init__(self, ctx: SkillContext):
        self.ctx = ctx
        env = ctx.env
        self.device = env.device
        self.N = env.num_envs
        # every env starts in phase 0
        self.phase = torch.zeros(self.N, device=self.device, dtype=torch.long)

    def reset(self, enter_mask: MaskT) -> None:
        """Put the envs in ``enter_mask`` back to phase 0 (they are entering this skill)."""
        self.phase[enter_mask] = 0

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance the skill for ``active_mask``; must be overridden."""
        raise NotImplementedError
class PickSkillBatched(BaseSkillBatched):
    """Batched pick: hover above the object, descend onto it, close, then verify.

    Per-env phases:
        0 - command the hand open and move the EE to ``pregrasp_z`` above the
            object position (the object orientation is used as EE orientation).
        1 - descend to ``grasp_z`` above the object position.
        2 - close the hand (one-shot command).
        3 - wait until ``robot.is_grasped(object_name)`` is True, then finish.

    Several phases can complete within one ``step`` call when the robot
    reports targets as reached immediately.
    """

    def __init__(
        self,
        ctx: SkillContext,
        object_name: str,
        pregrasp_z: float = 0.08,
        grasp_z: float = 0.0,
    ):
        super().__init__(ctx)
        self.object_name = object_name
        self.pregrasp_z = pregrasp_z  # hover height above the object before descending
        self.grasp_z = grasp_z  # EE height offset above the object position when closing

    def _pose_above_object(self, mask: MaskT, z_offset: float) -> tuple[Vec3T, QuatT]:
        """Object pose with the z of the masked rows raised by ``z_offset``."""
        obj_pos, obj_quat = self.ctx.env.get_pose_w(self.object_name)
        target_pos = obj_pos.clone()
        target_pos[mask, 2] = obj_pos[mask, 2] + z_offset
        return target_pos, obj_quat

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance the pick for ``active_mask``; return the (N,) mask of envs that finished."""
        done = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        robot = self.ctx.robot

        # phase 0: open the hand (in case it is closed) and move to the pregrasp hover pose
        m0 = active_mask & (self.phase == 0)
        if m0.any():
            robot.open_hand(active_mask=m0)
            target_pos, target_quat = self._pose_above_object(m0, self.pregrasp_z)
            r = robot.move_ee_pose(target_pos, target_quat, active_mask=m0)
            self.phase[r.reached & m0] = 1

        # phase 1: descend to the grasp height; closing at the hover height cannot grasp
        m1 = active_mask & (self.phase == 1)
        if m1.any():
            target_pos, target_quat = self._pose_above_object(m1, self.grasp_z)
            r = robot.move_ee_pose(target_pos, target_quat, active_mask=m1)
            self.phase[r.reached & m1] = 2

        # phase 2: close hand (one-shot)
        m2 = active_mask & (self.phase == 2)
        if m2.any():
            robot.close_hand(active_mask=m2)
            self.phase[m2] = 3

        # phase 3: verify; envs not yet grasped stay here (no timeout/retry logic yet)
        m3 = active_mask & (self.phase == 3)
        if m3.any():
            grasped = robot.is_grasped(self.object_name) & m3
            done[grasped] = True
        return done
class ReleaseSkillBatched(BaseSkillBatched):
    """Open the hand; an env finishes on the same call that issues the open command.

    No check is made that the object actually left the hand.
    """

    def __init__(self, ctx: SkillContext, object_name: str):
        super().__init__(ctx)
        self.object_name = object_name  # not read by step()

    def step(self, active_mask: MaskT) -> MaskT:
        """Issue the open command for newly active envs and report them finished."""
        opening = active_mask & (self.phase == 0)
        if opening.any():
            self.ctx.robot.open_hand(active_mask=opening)
            self.phase[opening] = 1
        finished = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        finished[active_mask & (self.phase == 1)] = True
        return finished
class PlaceSkillBatched(BaseSkillBatched):
    """Carry the held object to a target: hover above it, descend, open the hand.

    Per-env phases:
        0 - move the EE ``above_z`` over the target position (target orientation).
        1 - move the EE to the target pose itself.
        2 - open the hand; the env finishes on this call.
    """

    def __init__(self, ctx: SkillContext, object_name: str, target_name: str, above_z: float = 0.10):
        super().__init__(ctx)
        self.object_name = object_name
        self.target_name = target_name
        self.above_z = above_z

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance the place for ``active_mask``; return the (N,) mask of envs that finished."""
        finished = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        env = self.ctx.env
        robot = self.ctx.robot

        hovering = active_mask & (self.phase == 0)
        if hovering.any():
            goal_pos, goal_quat = env.get_pose_w(self.target_name)
            hover_pos = goal_pos.clone()
            hover_pos[hovering, 2] = goal_pos[hovering, 2] + self.above_z
            outcome = robot.move_ee_pose(hover_pos, goal_quat, active_mask=hovering)
            self.phase[outcome.reached & hovering] = 1

        lowering = active_mask & (self.phase == 1)
        if lowering.any():
            goal_pos, goal_quat = env.get_pose_w(self.target_name)
            outcome = robot.move_ee_pose(goal_pos, goal_quat, active_mask=lowering)
            self.phase[outcome.reached & lowering] = 2

        letting_go = active_mask & (self.phase == 2)
        if letting_go.any():
            robot.open_hand(active_mask=letting_go)
            finished[letting_go] = True
        return finished
class AlignSkillBatched(BaseSkillBatched):
    """Placeholder alignment: move the EE to a stand-off pose above the fixture frame.

    Single phase -- an env finishes once the robot reports the stand-off pose
    (``standoff_z`` above the fixture position, fixture orientation) reached.
    Intended to be replaced by pinch-axis + face selection + micro-adjust logic.
    """

    def __init__(self, ctx: SkillContext, object_name: str, target_fixture: str, standoff_z: float = 0.06):
        super().__init__(ctx)
        self.object_name = object_name  # not read by the placeholder step()
        self.target_fixture = target_fixture
        self.standoff_z = standoff_z

    def step(self, active_mask: MaskT) -> MaskT:
        """Move active envs to the stand-off pose; return the (N,) mask of envs that finished."""
        finished = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        approaching = active_mask & (self.phase == 0)
        if not approaching.any():
            return finished
        fix_pos, fix_quat = self.ctx.env.get_pose_w(self.target_fixture)
        standoff = fix_pos.clone()
        standoff[approaching, 2] = fix_pos[approaching, 2] + self.standoff_z
        outcome = self.ctx.robot.move_ee_pose(standoff, fix_quat, active_mask=approaching)
        finished[outcome.reached & approaching] = True
        return finished
class InsertSkillBatched(BaseSkillBatched):
    """Placeholder insertion: approach above the fixture, then push straight down.

    Per-env phases:
        0 - move the EE to ``pre_z`` above the fixture position (fixture orientation).
        1 - from that pre-insert point, push ``depth`` along world -Z; the env
            finishes when the robot reports the push target reached.
    Intended to be replaced by the fixture's insertion axis plus contact-based
    micro-adjustment.
    """

    def __init__(self, ctx: SkillContext, object_name: str, target_fixture: str, pre_z: float = 0.03, depth: float = 0.02):
        super().__init__(ctx)
        self.object_name = object_name
        self.target_fixture = target_fixture
        self.pre_z = pre_z
        self.depth = depth

    def _pre_insert_pos(self, fix_pos: Vec3T, mask: MaskT) -> Vec3T:
        """Copy of ``fix_pos`` with the masked rows raised by ``pre_z``."""
        raised = fix_pos.clone()
        raised[mask, 2] = fix_pos[mask, 2] + self.pre_z
        return raised

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance the insertion for ``active_mask``; return the (N,) mask of envs that finished."""
        finished = torch.zeros(self.N, device=self.device, dtype=torch.bool)
        env = self.ctx.env
        robot = self.ctx.robot

        approach = active_mask & (self.phase == 0)
        if approach.any():
            fix_pos, fix_quat = env.get_pose_w(self.target_fixture)
            outcome = robot.move_ee_pose(self._pre_insert_pos(fix_pos, approach), fix_quat, active_mask=approach)
            self.phase[outcome.reached & approach] = 1

        pushing = active_mask & (self.phase == 1)
        if pushing.any():
            fix_pos, fix_quat = env.get_pose_w(self.target_fixture)
            start = self._pre_insert_pos(fix_pos, pushing)
            down = torch.zeros((self.N, 3), device=self.device, dtype=fix_pos.dtype)
            down[:, 2] = -1.0
            outcome = robot.push_along_axis(start, fix_quat, down, self.depth, active_mask=pushing)
            finished[outcome.reached & pushing] = True
        return finished
# --------------------------
# Skill Factory (Batched)
# --------------------------
def build_skill_batched(ctx: SkillContext, spec: SkillSpec) -> BaseSkillBatched:
    """Instantiate the batched skill object implementing ``spec``.

    Required ``spec.args`` keys:
        pick / release:  "object"
        place:           "object", "target"
        align / insert:  "object", "target_fixture"

    Optional numeric keys, forwarded (as float) to the skill constructor only
    when present and not null; otherwise the constructor default applies:
        pick: "pregrasp_z"   place: "above_z"   align: "standoff_z"
        insert: "pre_z", "depth"

    Raises:
        KeyError: a required argument is missing.
        NotImplementedError: the skill name has no batched implementation.
    """
    n = spec.name
    a = spec.args

    def tuning(*keys: str) -> Dict[str, float]:
        # Forward only keys the spec actually sets, so constructor defaults stay authoritative.
        return {k: float(a[k]) for k in keys if a.get(k) is not None}

    if n == "pick":
        return PickSkillBatched(ctx, a["object"], **tuning("pregrasp_z"))
    if n == "release":
        return ReleaseSkillBatched(ctx, a["object"])
    if n == "place":
        return PlaceSkillBatched(ctx, a["object"], a["target"], **tuning("above_z"))
    if n == "align":
        return AlignSkillBatched(ctx, a["object"], a["target_fixture"], **tuning("standoff_z"))
    if n == "insert":
        return InsertSkillBatched(ctx, a["object"], a["target_fixture"], **tuning("pre_z", "depth"))
    raise NotImplementedError(f"Skill not implemented (batched): {n}")
# --------------------------
# Batched Skill FSM
# --------------------------
class SkillFSMBatched:
    """Run one skill sequence, replicated across N envs, as a per-env state machine.

    Each env keeps its own pointer ``idx`` into the skill list and a ``done``
    flag; on every tick a skill is stepped only for the envs pointing at it.
    An empty skill list is a trivially finished sequence.
    """

    def __init__(self, ctx: SkillContext, skill_specs: List[SkillSpec]):
        self.ctx = ctx
        self.device = ctx.env.device
        self.N = ctx.env.num_envs
        self.skill_specs = skill_specs
        self.K = len(skill_specs)
        # Prebuild all K skill objects; each holds its own per-env phase tensor.
        self.skills: List[BaseSkillBatched] = [build_skill_batched(ctx, s) for s in skill_specs]
        # current skill index per env
        self.idx = torch.zeros(self.N, device=self.device, dtype=torch.long)
        # whether each env finished the whole sequence (immediately true for an empty plan)
        self.done = torch.full((self.N,), self.K == 0, device=self.device, dtype=torch.bool)

    def reset(self, env_ids: Optional[torch.Tensor] = None) -> None:
        """Restart the sequence for all envs (``env_ids`` None) or only the given envs.

        With an empty skill list there is no skill 0 to enter, so the selected
        envs are marked done instead.
        """
        if env_ids is None:
            mask = torch.ones(self.N, device=self.device, dtype=torch.bool)
        else:
            mask = torch.zeros(self.N, device=self.device, dtype=torch.bool)
            mask[env_ids] = True
        self.idx[mask] = 0
        if self.K == 0:
            self.done[mask] = True
            return
        self.done[mask] = False
        # envs enter skill 0
        self.skills[0].reset(mask)

    def step(self, active_mask: MaskT) -> MaskT:
        """Advance one tick for the envs in ``active_mask`` that are not done.

        Skills are visited in index order and ``idx`` is updated in place, so
        an env that finishes skill k on this tick also runs skill k+1 on the
        same tick.

        Returns:
            (N,) bool copy of ``done``: envs that have finished all skills.
        """
        # Don't run already finished envs
        run_mask = active_mask & (~self.done)
        if not run_mask.any():
            return self.done.clone()
        for k in range(self.K):
            mk = run_mask & (self.idx == k)
            if not mk.any():
                continue
            skill_done = self.skills[k].step(mk)  # (N,)
            finished_here = mk & skill_done
            if not finished_here.any():
                continue
            next_idx = k + 1
            if next_idx >= self.K:
                # sequence finished
                self.done[finished_here] = True
            else:
                self.idx[finished_here] = next_idx
                # envs "enter" the next skill => reset its phase for those envs
                self.skills[next_idx].reset(finished_here)
        return self.done.clone()
# --------------------------
# Compiled Task Bundle (Batched)
# --------------------------
@dataclass
class CompiledTaskBatched:
    """Everything needed to run one compiled task on a batched env."""

    env: BatchedEnvAPI  # the environment the task is bound to
    robot: RobotFacadeBatched  # command facade shared by skills and evaluator
    fsm: SkillFSMBatched  # per-env skill sequencer
    evaluator: SuccessEvaluatorBatched  # goal predicate checker
# --------------------------
# Compiler (Batched)
# --------------------------
class LLMFSMCompilerBatched:
    """Turn a TaskSpec into batched runtime components bound to one env."""

    def compile(self, env: BatchedEnvAPI, task_spec: TaskSpec) -> CompiledTaskBatched:
        """Wire the robot facade, skill FSM and success evaluator for ``task_spec``.

        The FSM is reset for all envs before returning.

        Raises:
            NotImplementedError: from ``build_skill_batched`` when a skill has
                no batched implementation.
        """
        robot = RobotFacadeBatched(env)
        fsm = SkillFSMBatched(SkillContext(env=env, robot=robot), task_spec.skills)
        evaluator = SuccessEvaluatorBatched(env, robot, task_spec.goal.success_conditions)
        fsm.reset()  # every env starts at skill 0
        return CompiledTaskBatched(env=env, robot=robot, fsm=fsm, evaluator=evaluator)
# --------------------------
# Runner loop (multi-env)
# --------------------------
def run_compiled_task_multi_env(
    compiled: CompiledTaskBatched,
    max_steps: int = 2000,
    auto_reset_success: bool = False,
) -> Dict[str, Any]:
    """Run a compiled task on all envs until each env succeeded or finished its plan.

    Each tick:
        1. evaluate the goal,
        2. step the FSM for envs that neither succeeded nor finished,
        3. advance physics once,
        4. with ``auto_reset_success``, reset (env + FSM) the envs that were
           successful at the start of the tick and clear their success flag.
    The loop stops early once every env is successful or FSM-done, or after
    ``max_steps`` ticks. The goal is evaluated once more after the loop so
    success produced by the final physics step (e.g. the last skill finishing)
    is not lost.

    Returns:
        dict with "success_mask" (N,) bool, "fsm_done_mask" (N,) bool, and
        "steps" (number of ticks executed; 0 when ``max_steps`` <= 0).
    """
    env = compiled.env
    fsm = compiled.fsm
    evaluator = compiled.evaluator

    # initial reset
    env.reset()
    fsm.reset()
    success = torch.zeros(env.num_envs, device=env.device, dtype=torch.bool)

    steps = 0  # explicit counter: a loop variable would be unbound when max_steps <= 0
    for _ in range(max_steps):
        steps += 1
        # 1) goal check before acting
        success_now = evaluator.check()
        success = success | success_now
        # 2) only envs that have neither succeeded nor finished keep acting
        active = (~success) & (~fsm.done)
        # 3) advance FSM for active envs
        fsm.step(active)
        # 4) physics step
        env.step_physics(action=None)
        # 5) optional: auto-reset envs that succeeded (common for dataset generation)
        if auto_reset_success and success_now.any():
            env_ids = torch.nonzero(success_now, as_tuple=False).squeeze(-1)
            env.reset(env_ids)
            fsm.reset(env_ids)
            success[env_ids] = False  # start a new episode for those envs
        # early stop: all envs succeeded or finished
        if (success | fsm.done).all():
            break

    # The early stop can trigger right after the FSM finishes, before the goal
    # was evaluated on the post-step state -- check it one final time.
    success = success | evaluator.check()

    return {
        "success_mask": success.detach().clone(),
        "fsm_done_mask": fsm.done.detach().clone(),
        "steps": steps,
    }
# --------------------------
# Example: constructing TaskSpec (no JSON parser here)
# --------------------------
def make_example_task_spec() -> TaskSpec:
    """Return a hand-built peg-insertion TaskSpec (no JSON parsing involved).

    Plan: pick -> align -> insert (depth 0.02) -> release. Success requires
    the peg to be in the jig hole with a depth between 0.018 and 0.022.
    """
    peg = "peg_small"
    hole = "alignment_jig_hole"

    plan = [
        SkillSpec("pick", {"object": peg}),
        SkillSpec("align", {"object": peg, "target_fixture": hole}),
        SkillSpec("insert", {"object": peg, "target_fixture": hole, "depth": 0.02}),
        SkillSpec("release", {"object": peg}),
    ]
    # NOTE(review): InsertSkillBatched pushes along world -Z, while this depth
    # is measured along +Z from the fixture frame; confirm the fixture frame
    # convention so the depth window sits on the intended side.
    conditions = [
        PredicateSpec("object_in_fixture", {"object": peg, "fixture": hole}),
        PredicateSpec(
            "depth_in_range",
            {"object": peg, "fixture": hole, "min_depth": 0.018, "max_depth": 0.022, "axis_w": [0, 0, 1]},
        ),
    ]
    return TaskSpec(
        task_id="peg_seat_001",
        activity="Peg Alignment & Seating",
        task_name="Insert Small Peg into Alignment Jig",
        skills=plan,
        goal=GoalSpec(summary="Insert peg into jig", success_conditions=conditions),
    )