Files
task_gen_with_llm/llm_task_generator.py

247 lines
7.5 KiB
Python
Raw Normal View History

2026-01-12 18:25:04 +09:00
# llm_task_generator.py
from __future__ import annotations
import os
import json
from typing import Any, Dict, List, Optional
from openai import OpenAI
# -----------------------------
# OpenAI client / model
# -----------------------------
def _get_openai_client() -> "OpenAI":
    """Create an OpenAI client using the ``OPENAI_API_KEY`` environment variable.

    Returns:
        A new ``OpenAI`` instance; ``api_key`` is whatever
        ``os.environ.get("OPENAI_API_KEY")`` yields (possibly ``None``).

    Raises:
        RuntimeError: if the module-level name ``OpenAI`` is ``None``; the
            message tells the user to install the ``openai`` package.
    """
    if OpenAI is not None:
        return OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    raise RuntimeError(
        "Python package `openai` is not installed. Install it (e.g. `pip install openai`) "
        "to use generate_task_spec_from_activity()."
    )
MODEL_NAME = "gpt-4.1-mini"  # change to "gpt-4.1" if needed
# -----------------------------
# Output contract (TaskSpec JSON)
# -----------------------------
# Prompt fragment describing the JSON structure and rules the LLM must follow.
# It is sent to the model verbatim (interpolated into the user prompt), so any
# edit to this text changes the request, not just the documentation.
TASK_SCHEMA_DESCRIPTION = """
Output JSON ONLY (no markdown, no comments), with this exact structure:
{
"task_id": string,
"activity": string,
"task_name": string,
"skills": [
{ "name": string, "args": object }
],
"goal": {
"summary": string,
"success_conditions": [
{ "type": string, ...predicate args... }
]
}
}
Rules:
- Use only allowed objects/fixtures/containers/tools provided.
- Use only allowed skills provided.
- Use only allowed goal predicates provided.
- Steps must be between min_steps and max_steps (inclusive).
- Do NOT output coordinates, poses, or any continuous control targets.
- Ensure preconditions: pick(object) must occur before place/align/insert/release for that object.
- Skill args must use the EXACT keys specified in the "Skill argument schema" section of the prompt.
"""
# Skill names used when the caller passes no `allowed_skills`; an activity
# definition may still override them with its own "allowed_skills" entry.
DEFAULT_ALLOWED_SKILLS = [
    "pick",
    "release",
    "place",
    "align",
    "insert",
]
# Goal-predicate names used when the caller passes no `allowed_predicates`;
# an activity definition may override them via "allowed_predicates".
DEFAULT_ALLOWED_PREDICATES = [
    "is_grasped",
    "object_in_fixture", "object_in_container",
    "pose_in_tolerance", "depth_in_range",
    "fixture_state",
    "button_pressed", "switch_state",
    "thread_engaged",
]
# Per-skill argument keys: "required" keys must be present in a skill's
# `args`, "optional" keys may be. Rendered into the prompt so the LLM uses
# these exact key names.
SKILL_ARG_SCHEMA: Dict[str, Dict[str, List[str]]] = {
    "pick": {
        "required": ["object"],
        "optional": [],
    },
    "release": {
        "required": ["object"],
        "optional": [],
    },
    "place": {
        "required": ["object", "target"],
        "optional": [],
    },
    "align": {
        "required": ["object", "target_fixture"],
        "optional": [],
    },
    "insert": {
        "required": ["object", "target_fixture"],
        "optional": ["depth"],
    },
}
def _format_skill_arg_schema(allowed_skills: List[str]) -> str:
lines: List[str] = []
for skill_name in allowed_skills:
spec = SKILL_ARG_SCHEMA.get(skill_name)
if spec is None:
continue
req = ", ".join(spec["required"]) if spec["required"] else "(none)"
opt = ", ".join(spec["optional"]) if spec["optional"] else "(none)"
lines.append(f'- {skill_name}: required args = [{req}], optional args = [{opt}]')
if not lines:
return "- (no schema available)"
return "\n".join(lines)
def _normalize_task_spec(task: Dict[str, Any]) -> Dict[str, Any]:
skills = task.get("skills", [])
if not isinstance(skills, list):
return task
for s in skills:
if not isinstance(s, dict):
continue
name = s.get("name")
args = s.get("args", {})
if not isinstance(args, dict):
args = {}
if name in ("align", "insert"):
if "target_fixture" not in args and "fixture" in args:
args["target_fixture"] = args.pop("fixture")
if name == "place":
if "target" not in args and "fixture" in args:
args["target"] = args.pop("fixture")
s["args"] = args
return task
def _build_task_prompt_from_activity(
    activity_key: str,
    activity_def: Dict[str, Any],
    *,
    allowed_skills: List[str],
    allowed_predicates: List[str],
    min_steps: int,
    max_steps: int,
) -> str:
    """Assemble the user prompt asking the LLM for one task spec of an activity.

    Args:
        activity_key: Identifier of the selected activity, echoed in the prompt.
        activity_def: Activity definition. Keys read here (all optional):
            ``description``, ``allowed_objects``, ``allowed_fixtures``,
            ``allowed_containers``, ``allowed_tools``, ``difficulty_tags``,
            ``allowed_skills`` and ``allowed_predicates``.
        allowed_skills: Fallback skill names when the activity defines none.
        allowed_predicates: Fallback predicate names when the activity defines none.
        min_steps: Lower bound on steps, stated in the prompt.
        max_steps: Upper bound on steps, stated in the prompt.

    Returns:
        The complete prompt text, ending with ``TASK_SCHEMA_DESCRIPTION`` and
        the final generation instruction.
    """
    description = activity_def.get("description", "")
    objects = activity_def.get("allowed_objects", [])
    fixtures = activity_def.get("allowed_fixtures", [])
    container_names = activity_def.get("allowed_containers", [])
    tool_names = activity_def.get("allowed_tools", [])
    difficulty = activity_def.get("difficulty_tags", [])

    # An activity may narrow (or replace) the skill / predicate vocabulary.
    skills_for_activity = activity_def.get("allowed_skills", allowed_skills)
    predicates_for_activity = activity_def.get("allowed_predicates", allowed_predicates)

    arg_schema_text = _format_skill_arg_schema(skills_for_activity)

    # The text below is sent to the model as-is; keep it flush-left so no
    # indentation leaks into the prompt.
    return f"""
You are generating a robotics task specification for a WORKBENCH dexterous manipulation setup.
SELECTED ACTIVITY KEY: {activity_key}
Activity description: {description}
Difficulty tags: {difficulty}
Allowed entities:
- objects: {objects}
- fixtures: {fixtures}
- containers: {container_names}
- tools: {tool_names}
Allowed skills (use EXACT names):
{skills_for_activity}
Allowed goal predicates (use EXACT names):
{predicates_for_activity}
Skill argument schema (MUST follow EXACT arg keys):
{arg_schema_text}
Hard constraints:
- Output JSON only.
- Do not invent new entities (must be from allowed lists).
- Do not invent new skills/predicates.
- Steps must be between {min_steps} and {max_steps}.
- Must satisfy skill preconditions:
* pick(object) must happen before align/insert/place/release for that object
* release(object) should only happen after object was picked
- No coordinates, no poses, no numeric target positions in skills.
{TASK_SCHEMA_DESCRIPTION}
Now generate ONE task JSON. Make it unique, concrete, and feasible for the selected activity.
"""
def _parse_json_strict(text: str) -> Dict[str, Any]:
text = text.strip()
# minimal cleanup for cases with accidental code fences
if text.startswith("```"):
text = text.strip("`")
# try to remove language hint
lines = text.splitlines()
if lines and lines[0].lower().startswith("json"):
text = "\n".join(lines[1:])
try:
return json.loads(text)
except json.JSONDecodeError as e:
raise RuntimeError(f"LLM output was not valid JSON.\n---\n{text}\n---") from e
def generate_task_spec_from_activity(
    activity_key: str,
    activity_def: Dict[str, Any],
    *,
    model: str = MODEL_NAME,
    temperature: float = 0.2,
    max_output_tokens: int = 900,
    min_steps: int = 5,
    max_steps: int = 10,
    allowed_skills: Optional[List[str]] = None,
    allowed_predicates: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Generate ONE TaskSpec JSON dict for a pre-defined activity.

    Builds a prompt from the activity definition, sends it to the OpenAI
    Responses API, parses the reply as JSON and normalizes skill argument keys.

    Args:
        activity_key: Identifier of the activity; always written to the
            result's ``"activity"`` field.
        activity_def: Activity definition passed to the prompt builder.
        model: Model name sent to the API.
        temperature: Sampling temperature sent to the API.
        max_output_tokens: Output token cap sent to the API.
        min_steps: Minimum number of steps requested in the prompt.
        max_steps: Maximum number of steps requested in the prompt.
        allowed_skills: Skill names to offer; ``DEFAULT_ALLOWED_SKILLS`` if ``None``.
        allowed_predicates: Predicate names to offer;
            ``DEFAULT_ALLOWED_PREDICATES`` if ``None``.

    Returns:
        A dict like
        ``{ task_id, activity, task_name, skills:[...], goal:{summary, success_conditions:[...]} }``.
        ``task_id`` falls back to ``"<activity_key>_auto_0001"`` and
        ``task_name`` falls back to ``task_id`` when missing or empty.

    Raises:
        ValueError: if ``min_steps`` is greater than ``max_steps``.
        RuntimeError: if the ``openai`` package is unavailable, or the model
            reply is not a valid JSON object.
    """
    # Fail before spending an API call on a self-contradictory prompt.
    if min_steps > max_steps:
        raise ValueError(
            f"min_steps ({min_steps}) must not be greater than max_steps ({max_steps})."
        )

    if allowed_skills is None:
        allowed_skills = DEFAULT_ALLOWED_SKILLS
    if allowed_predicates is None:
        allowed_predicates = DEFAULT_ALLOWED_PREDICATES

    prompt = _build_task_prompt_from_activity(
        activity_key,
        activity_def,
        allowed_skills=allowed_skills,
        allowed_predicates=allowed_predicates,
        min_steps=min_steps,
        max_steps=max_steps,
    )

    client = _get_openai_client()
    resp = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": "You output strictly valid JSON task specs for robotics. No extra text."},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_output_tokens=max_output_tokens,
    )

    task = _parse_json_strict(resp.output_text)
    # Valid JSON is not necessarily an object; the code below needs a dict.
    if not isinstance(task, dict):
        raise RuntimeError(
            f"LLM output was not a JSON object (got {type(task).__name__})."
        )
    task = _normalize_task_spec(task)

    # Ensure activity field matches selected activity_key (stability)
    task["activity"] = activity_key

    # Fill identifiers the model left out or left empty.
    if not task.get("task_id"):
        task["task_id"] = f"{activity_key}_auto_0001"
    if not task.get("task_name"):
        task["task_name"] = task["task_id"]
    return task