Skip to content

Commit 586ab23

Browse files
committed
chore: cleanup
docs: update type: ignore docs: add diagram
1 parent 5bc91db commit 586ab23

File tree

5 files changed

+28
-28
lines changed

5 files changed

+28
-28
lines changed

docs/agents/spatiotemporal.md

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# SpatioTemporalAgent
22

3+
![SpatioTemporalAgent](../imgs/spatiotemporal.png)
4+
35
## Overview
46

57
The `SpatioTemporalAgent` is an intelligent agent designed to capture, process, and store spatiotemporal data - information that combines both spatial (location, pose) and temporal (time-based) aspects. It's particularly useful in robotics and autonomous systems where understanding both the spatial context and its evolution over time is crucial.
@@ -34,7 +36,7 @@ The agent is configured through `SpatioTemporalConfig`:
3436
- `collection_name`: Collection for data storage
3537
- `image_to_text_model`: AI model for image description
3638
- `context_compression_model`: AI model for context compression
37-
- `embeddings`: Embeddings model
39+
- `vector_db`: Vector store
3840
- `time_interval`: Data collection frequency
3941

4042
## ROS2 Implementation
@@ -62,20 +64,21 @@ docker run -d --name rai-mongo -p 27017:27017 mongo
6264
2. **Agent Configuration**
6365

6466
```python
67+
import rclpy
6568
from rai.agents.spatiotemporal import ROS2SpatioTemporalAgent, ROS2SpatioTemporalConfig
66-
from rai.utils.model_initialization import get_llm_model, get_embeddings_model
69+
from rai.utils.model_initialization import get_llm_model, get_vectorstore
6770

6871
config = ROS2SpatioTemporalConfig(
6972
robot_frame="base_link",
7073
world_frame="world",
7174
db_url="mongodb://localhost:27017/",
7275
db_name="rai",
73-
collection_name="spatiotemporal_data",
76+
collection_name="spatiotemporal_collection",
7477
image_to_text_model=get_llm_model("simple_model"),
7578
context_compression_model=get_llm_model("simple_model"),
7679
time_interval=10.0,
77-
camera_topics=["/camera/color/image_raw"],
78-
embeddings=get_embeddings_model(),
80+
camera_topics=["/camera/camera/color/image_raw"],
81+
vector_db=get_vectorstore(),
7982
)
8083

8184
agent = ROS2SpatioTemporalAgent(config)

docs/imgs/spatiotemporal.png

791 KB
Loading

src/rai_core/rai/agents/spatiotemporal/ros2.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ROS2SpatioTemporalConfig(SpatioTemporalConfig):
3737

3838

3939
class ROS2SpatioTemporalAgent(SpatioTemporalAgent):
40-
def __init__(self, config: ROS2SpatioTemporalConfig, *args, **kwargs):
40+
def __init__(self, config: ROS2SpatioTemporalConfig, *args, **kwargs): # type: ignore
4141
super().__init__(config, *args, **kwargs)
4242
self.config = config
4343
self.connector = ROS2ARIConnector()
@@ -48,7 +48,7 @@ def _get_images(self) -> Dict[Annotated[str, "camera topic"], str]:
4848
images: Dict[Annotated[str, "camera topic"], str] = {}
4949
for camera_topic in self.config.camera_topics:
5050
try:
51-
_, artifact = self.get_image_tool._run(topic=camera_topic)
51+
_, artifact = self.get_image_tool._run(topic=camera_topic) # type: ignore
5252
image = artifact["images"][0]
5353
images[camera_topic] = image
5454
except Exception as e:
@@ -67,21 +67,21 @@ def _get_tf(self) -> Optional[PoseStamped]:
6767
return None
6868
ps = PoseStamped(
6969
header=Header(
70-
stamp=tf_stamped.header.stamp.sec
71-
+ tf_stamped.header.stamp.nanosec * 1e-9,
72-
frame_id=tf_stamped.header.frame_id,
70+
stamp=tf_stamped.header.stamp.sec # type: ignore
71+
+ tf_stamped.header.stamp.nanosec * 1e-9, # type: ignore
72+
frame_id=tf_stamped.header.frame_id, # type: ignore
7373
),
7474
pose=Pose(
7575
position=Point(
76-
x=tf_stamped.transform.translation.x,
77-
y=tf_stamped.transform.translation.y,
78-
z=tf_stamped.transform.translation.z,
76+
x=tf_stamped.transform.translation.x, # type: ignore
77+
y=tf_stamped.transform.translation.y, # type: ignore
78+
z=tf_stamped.transform.translation.z, # type: ignore
7979
),
8080
orientation=Quaternion(
81-
x=tf_stamped.transform.rotation.x,
82-
y=tf_stamped.transform.rotation.y,
83-
z=tf_stamped.transform.rotation.z,
84-
w=tf_stamped.transform.rotation.w,
81+
x=tf_stamped.transform.rotation.x, # type: ignore
82+
y=tf_stamped.transform.rotation.y, # type: ignore
83+
z=tf_stamped.transform.rotation.z, # type: ignore
84+
w=tf_stamped.transform.rotation.w, # type: ignore
8585
),
8686
),
8787
)

src/rai_core/rai/agents/spatiotemporal/spatiotemporal_agent.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
from rai.agents.base import BaseAgent
2929
from rai.messages.multimodal import HumanMultimodalMessage
3030

31-
EMBEDDINGS_FIELD_NAME = "embeddings"
32-
SEARCH_INDEX_NAME = "embeddings_search_index"
33-
3431

3532
class Header(BaseModel):
3633
stamp: Annotated[float, "timestamp"]
@@ -100,7 +97,7 @@ def __init__(
10097
**kwargs : dict
10198
Additional keyword arguments.
10299
"""
103-
super().__init__(*args, **kwargs)
100+
super().__init__(*args, **kwargs) # type: ignore
104101
self.config = config
105102

106103
self.db = MongoClient(self.config.db_url)[self.config.db_name] # type: ignore
@@ -133,7 +130,7 @@ def _insert_into_vectorstore(self, data: SpatioTemporalData):
133130
self.logger.info("Inserting embeddings into vector store")
134131

135132
print(
136-
self.config.vector_db.add_texts(
133+
self.config.vector_db.add_texts( # type: ignore
137134
texts=[data.temporal_context + data.image_text_descriptions],
138135
metadatas=[{"id": data.id}],
139136
ids=[data.id],
@@ -252,7 +249,7 @@ def _get_image_text_descriptions(
252249
[text_description_prompt, human_message]
253250
),
254251
)
255-
if not isinstance(ai_msg.content, str):
252+
if not isinstance(ai_msg.content, str): # type: ignore
256253
raise ValueError("AI message content is not a string")
257254
text_descriptions[source] = ai_msg.content
258255

@@ -289,7 +286,8 @@ def _compress_context(self, history: List[BaseMessage]) -> str:
289286
)
290287

291288
robots_history: List[Dict[str, str]] = [
292-
{"role": msg.type, "content": msg.content} for msg in history
289+
{"role": msg.type, "content": msg.content} # type: ignore
290+
for msg in history
293291
]
294292
if len(robots_history) == 0:
295293
return ""
@@ -301,6 +299,6 @@ def _compress_context(self, history: List[BaseMessage]) -> str:
301299
[system_prompt, human_message]
302300
),
303301
)
304-
if not isinstance(ai_msg.content, str):
302+
if not isinstance(ai_msg.content, str): # type: ignore
305303
raise ValueError("AI message content is not a string")
306304
return ai_msg.content

src/rai_core/rai/utils/model_initialization.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import logging
1616
import os
1717
from dataclasses import dataclass
18-
from typing import List, Literal, Optional, cast
18+
from typing import Any, List, Literal, Optional, cast
1919

2020
import coloredlogs
2121
import tomli
@@ -109,7 +109,7 @@ def load_config() -> RAIConfig:
109109
def get_llm_model(
110110
model_type: Literal["simple_model", "complex_model"],
111111
vendor: Optional[str] = None,
112-
**kwargs,
112+
**kwargs: Any,
113113
):
114114
config = load_config()
115115
if vendor is None:
@@ -184,7 +184,6 @@ def get_vectorstore():
184184
from langchain_community.vectorstores import FAISS
185185

186186
if os.path.exists(config.vectorstore.uri):
187-
print("I EXIST")
188187
return FAISS.load_local(
189188
config.vectorstore.uri,
190189
embeddings=get_embeddings_model(),

0 commit comments

Comments
 (0)