Skip to content

fix: ReActAgent #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions examples/agents/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,22 @@
# See the License for the specific language goveself.rning permissions and
# limitations under the License.

from rai.agents import AgentRunner, ReActAgent
from rai.communication.ros2 import ROS2Connector, ROS2Context, ROS2HRIConnector
from rai.agents import AgentRunner
from rai.agents.langchain.react_agent import ReActAgent
from rai.communication.ros2 import ROS2Connector, ROS2Context
from rai.communication.ros2.connectors.hri_connector import ROS2HRIConnector
from rai.tools.ros2 import ROS2Toolkit


@ROS2Context()
def main():
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
ros2_connector = ROS2Connector()
hri_connector = ROS2HRIConnector()
target_connectors = {"/to_human": hri_connector}
source_connector = ("/from_human", hri_connector)
agent = ReActAgent(
connectors={"hri": connector},
target_connectors=target_connectors,
source_connector=source_connector,
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
) # type: ignore
runner = AgentRunner([agent])
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from rai.agents.base import BaseAgent
from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.react_agent import ReActAgent
from rai.agents.langchain.react_agent import ReActAgent
from rai.agents.runner import AgentRunner, wait_for_shutdown
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner
Expand Down
172 changes: 172 additions & 0 deletions src/rai_core/rai/agents/langchain/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Deque, Dict, List, Literal, Optional, Tuple, TypedDict

from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from pydantic import BaseModel

from rai.agents.base import BaseAgent
from rai.agents.langchain import HRICallbackHandler
from rai.agents.langchain.runnables import ReActAgentState
from rai.communication.hri_connector import HRIConnector, HRIMessage
from rai.initialization import get_tracing_callbacks


class BaseState(TypedDict):
messages: List[BaseMessage]


class HRIConfig(BaseModel):
source: str
targets: List[str]


class LangChainAgent(BaseAgent):
def __init__(
self,
target_connectors: Dict[str, HRIConnector],
source_connector: Tuple[str, HRIConnector],
runnable: Runnable,
state: BaseState | None = None,
new_message_behavior: Literal[
"take_all",
"keep_last",
"queue",
"interuppt_take_all",
"interuppt_keep_last",
] = "interuppt_keep_last",
max_size: int = 100,
):
super().__init__()
self.logger = logging.getLogger(__name__)
self.agent = runnable
self.new_message_behavior = new_message_behavior
self.tracing_callbacks = get_tracing_callbacks()
self.state = state or ReActAgentState(messages=[])
self.source_connector = source_connector
self.callback = HRICallbackHandler(
connectors=target_connectors,
aggregate_chunks=True,
logger=self.logger,
)

self.source, self.source_connector = source_connector
self.source_connector.register_callback(
self.source, self.source_callback, msg_type="rai_interfaces/msg/HRIMessage"
)
self._received_messages: Deque[HRIMessage] = deque()
self.max_size = max_size

self.thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._executor = ThreadPoolExecutor(max_workers=1)
self._interupt_event = threading.Event()
self._agent_ready_event = threading.Event()

def run(self):
if self.thread is not None:
raise RuntimeError("Agent is already running")
self.thread = threading.Thread(target=self._run_loop)
self.thread.start()
self._agent_ready_event.set()

def source_callback(self, msg: HRIMessage):
if self.max_size is not None and len(self._received_messages) >= self.max_size:
self.logger.warning("Buffer overflow. Dropping olders message")
self._received_messages.popleft()
if "interuppt" in self.new_message_behavior:
self._executor.submit(self.interuppt_agent_and_run)
self.logger.info(f"Received message: {msg}, {type(msg)}")
self._received_messages.append(msg)

def interuppt_agent_and_run(self):
self.logger.info("Interuppting agent...")
self._interupt_event.set()
self._agent_ready_event.wait()
self._interupt_event.clear()
self.logger.info("Interuppting agent: DONE")

def run_agent(self):
self._agent_ready_event.clear()
try:
if len(self._received_messages) == 0:
self.logger.info("Waiting for messages...")
time.sleep(0.5)
return
self.logger.info("Running agent...")
reduced_message = self._reduce_messages()
langchain_message = reduced_message.to_langchain()
self.state["messages"].append(langchain_message)
for _ in self.agent.stream(
self.state,
config={"callbacks": [self.callback, *self.tracing_callbacks]},
):
if self._interupt_event.is_set():
break
finally:
self._agent_ready_event.set()

def _run_loop(self):
while not self._stop_event.is_set():
time.sleep(0.01)
if self._agent_ready_event.is_set():
self.run_agent()

def stop(self):
self._stop_event.set()
self._interupt_event.set()
self._agent_ready_event.wait()
if self.thread is not None:
self.logger.info("Stopping the agent. Please wait...")
self.thread.join()
self.thread = None
self.logger.info("Agent stopped")

def _reduce_messages(self) -> HRIMessage:
text = ""
images = []
audios = []
source_messages = list()
if "take_all" in self.new_message_behavior:
# Take all starting from the oldest
while len(self._received_messages) > 0:
source_messages.append(self._received_messages.popleft())
elif "keep_last" in self.new_message_behavior:
# Take the recently added message
source_messages.append(self._received_messages.pop())
self._received_messages.clear()
elif self.new_message_behavior == "queue":
# Take the first message from the queue. Let other messages wait.
source_messages.append(self._received_messages.popleft())
else:
raise ValueError(
f"Invalid new_message_behavior: {self.new_message_behavior}"
)
for source_message in source_messages:
text += f"{source_message.text}\n"
images.extend(source_message.images)
audios.extend(source_message.audios)
return HRIMessage(
text=text,
images=images,
audios=audios,
message_author="human",
)
23 changes: 10 additions & 13 deletions src/rai_core/rai/agents/langchain/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import threading
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
Expand All @@ -27,7 +27,7 @@
class HRICallbackHandler(BaseCallbackHandler):
def __init__(
self,
connectors: dict[str, HRIConnector[HRIMessage]],
connectors: Dict[str, HRIConnector[HRIMessage]],
aggregate_chunks: bool = False,
splitting_chars: Optional[List[str]] = None,
max_buffer_size: int = 200,
Expand All @@ -47,21 +47,18 @@ def _should_split(self, token: str) -> bool:
return token in self.splitting_chars

def _send_all_targets(self, tokens: str, done: bool = False):
self.logger.info(
f"Sending {len(tokens)} tokens to {len(self.connectors)} connectors"
)
for connector_name, connector in self.connectors.items():
for target, connector in self.connectors.items():
self.logger.info(f"Sending {len(tokens)} tokens to targer: {target}")
try:
connector.send_all_targets(
AIMessage(content=tokens),
self.current_conversation_id,
self.current_chunk_id,
done,
message = AIMessage(content=tokens)
to_send = connector.T_class.from_langchain(
message, self.current_conversation_id, self.current_chunk_id, done
)
self.logger.debug(f"Sent {len(tokens)} tokens to {connector_name}")
connector.send_message(to_send, target)
self.logger.debug(f"Sent {len(tokens)} tokens to hri_connector.")
except Exception as e:
self.logger.error(
f"Failed to send {len(tokens)} tokens to {connector_name}: {e}"
f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
)

def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
Expand Down
44 changes: 44 additions & 0 deletions src/rai_core/rai/agents/langchain/react_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool

from rai.agents.langchain import create_react_runnable
from rai.agents.langchain.agent import LangChainAgent
from rai.agents.langchain.runnables import ReActAgentState
from rai.communication.hri_connector import HRIConnector


class ReActAgent(LangChainAgent):
def __init__(
self,
target_connectors: Dict[str, HRIConnector],
source_connector: Tuple[str, HRIConnector],
llm: Optional[BaseChatModel] = None,
tools: Optional[List[BaseTool]] = None,
state: Optional[ReActAgentState] = None,
system_prompt: Optional[str] = None,
):
runnable = create_react_runnable(
llm=llm, tools=tools, system_prompt=system_prompt
)
super().__init__(
target_connectors=target_connectors,
source_connector=source_connector,
runnable=runnable,
state=state,
)
Loading