Skip to content

Commit 5bc91db

Browse files
committed
refactor: use faiss instead of mongodb for default setup
1 parent d04dcef commit 5bc91db

File tree

7 files changed

+103
-67
lines changed

7 files changed

+103
-67
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,5 @@ src/examples/*-demo
174174
artifact_database.pkl
175175

176176
imgui.ini
177+
178+
vectorstore_data/

config.toml

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ complex_model = "llama3.1:70b"
2121
embeddings_model = "llama3.2"
2222
base_url = "http://localhost:11434"
2323

24+
[vectorstore]
25+
type = "faiss"
26+
uri = "vectorstore_data"
27+
2428
[tracing]
2529
project = "rai"
2630

examples/agents/spatiotemporal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import rclpy
1616
from rai.agents.spatiotemporal import ROS2SpatioTemporalAgent, ROS2SpatioTemporalConfig
17-
from rai.utils.model_initialization import get_embeddings_model, get_llm_model
17+
from rai.utils.model_initialization import get_llm_model, get_vectorstore
1818

1919

2020
def create_agent():
@@ -28,7 +28,7 @@ def create_agent():
2828
context_compression_model=get_llm_model("simple_model"),
2929
time_interval=10.0,
3030
camera_topics=["/camera/camera/color/image_raw"],
31-
embeddings=get_embeddings_model(),
31+
vector_db=get_vectorstore(),
3232
)
3333
agent = ROS2SpatioTemporalAgent(config)
3434
return agent

src/rai_core/rai/agents/spatiotemporal/spatiotemporal_agent.py

+47-42
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import json
1616
import logging
1717
import time
18+
import uuid
1819
from abc import abstractmethod
1920
from typing import Annotated, Any, Dict, List, Optional, cast
2021

21-
from langchain_core.embeddings import Embeddings
2222
from langchain_core.language_models import BaseChatModel
2323
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
24+
from langchain_core.vectorstores import VectorStore
2425
from pydantic import BaseModel, ConfigDict, Field
2526
from pymongo import MongoClient
2627

@@ -60,12 +61,12 @@ class PoseStamped(BaseModel):
6061

6162

6263
class SpatioTemporalData(BaseModel):
64+
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
6365
timestamp: Annotated[float, "timestamp"]
6466
images: Dict[Annotated[str, "camera topic"], str] = Field(repr=False)
6567
tf: Optional[PoseStamped]
6668
temporal_context: Annotated[str, "compressed history of messages"]
6769
image_text_descriptions: Annotated[str, "text descriptions of images"]
68-
embeddings: List[float] = Field(default_factory=list, repr=False)
6970

7071

7172
class SpatioTemporalConfig(BaseModel):
@@ -74,7 +75,7 @@ class SpatioTemporalConfig(BaseModel):
7475
collection_name: str
7576
image_to_text_model: BaseChatModel
7677
context_compression_model: BaseChatModel
77-
embeddings: Embeddings
78+
vector_db: VectorStore
7879
time_interval: float
7980

8081
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -104,10 +105,9 @@ def __init__(
104105

105106
self.db = MongoClient(self.config.db_url)[self.config.db_name] # type: ignore
106107
self.collection = self.db[self.config.collection_name] # type: ignore
107-
self._initialize_embeddings_search_index()
108108
self.logger = logging.getLogger(__name__)
109109

110-
def insert_into_db(self, data: SpatioTemporalData):
110+
def _insert_into_db(self, data: SpatioTemporalData):
111111
"""
112112
Insert spatiotemporal data into the database.
113113
@@ -121,20 +121,51 @@ def insert_into_db(self, data: SpatioTemporalData):
121121
)
122122
self.collection.insert_one(data.model_dump()) # type: ignore
123123

124+
def _insert_into_vectorstore(self, data: SpatioTemporalData):
    """
    Insert embeddings of the spatiotemporal data into the vector store.

    The embedded text is the temporal context followed by the image text
    descriptions. The entry is stored under ``data.id`` (both as the vector
    store id and in the metadata), the same ``id`` value that is part of the
    dumped record written to the database, so a similarity-search hit can be
    mapped back to its full record.

    Parameters
    ----------
    data : SpatioTemporalData
        The spatiotemporal data to be inserted.
    """
    self.logger.info("Inserting embeddings into vector store")

    inserted_ids = self.config.vector_db.add_texts(
        texts=[data.temporal_context + data.image_text_descriptions],
        metadatas=[{"id": data.id}],
        ids=[data.id],
    )
    # Report the ids through the logger instead of printing to stdout on
    # every interval.
    self.logger.debug("Vector store accepted ids: %s", inserted_ids)
142+
124143
def run(self):
    """
    Run the agent in a loop, executing tasks at specified intervals.

    Every iteration calls ``on_interval`` and then sleeps for whatever is
    left of ``config.time_interval``; a warning is logged when an iteration
    takes longer than the interval.

    On ``KeyboardInterrupt`` the vector store is saved to the configured
    ``vectorstore.uri`` (when the store supports ``save_local``) before the
    interrupt is propagated.

    Raises
    ------
    KeyboardInterrupt
        Always re-raised after the save attempt.
    """
    # The handler always re-raises, so wrapping the whole loop is equivalent
    # to wrapping each iteration and keeps the loop body flat.
    try:
        while True:
            ts = time.time()
            self.logger.info("Starting new interval")
            self.on_interval()
            te = time.time()
            elapsed = te - ts
            if elapsed > self.config.time_interval:
                self.logger.warning(
                    f"Time interval exceeded. Expected {self.config.time_interval:.2f}s, got {elapsed:.2f}s"
                )
            time.sleep(max(0, self.config.time_interval - elapsed))
    except KeyboardInterrupt:
        # Not every VectorStore has ``save_local``; look the method up instead
        # of casting to FAISS so other backends do not fail with
        # AttributeError during shutdown.
        save_local = getattr(self.config.vector_db, "save_local", None)
        if callable(save_local):
            # Function-scope import, as in the original code.
            # NOTE(review): the target path is re-read from the config file;
            # passing it in through the agent config would be cleaner.
            from rai.utils.model_initialization import load_config

            self.logger.info("Saving vector store")
            try:
                save_local(load_config().vectorstore.uri)
            except Exception:
                # A failed save must not replace the interrupt being handled.
                self.logger.exception("Failed to save vector store")
        raise
138169

139170
def on_interval(self):
140171
"""
@@ -154,41 +185,15 @@ def on_interval(self):
154185
self.logger.info("Retrieving temporal context")
155186
temporal_context = self._get_robots_history()
156187

157-
embedding_text = temporal_context + str(image_text_descriptions.values())
158-
self.logger.info("Embedding text")
159-
embeddings = self._embed_text(embedding_text)
160-
161188
data = SpatioTemporalData(
162189
timestamp=time.time(),
163190
images=images,
164191
tf=tf,
165192
temporal_context=temporal_context,
166193
image_text_descriptions=json.dumps(image_text_descriptions),
167-
embeddings=embeddings,
168-
)
169-
self.insert_into_db(data)
170-
171-
def _initialize_embeddings_search_index(self):
172-
self.collection.create_search_index(
173-
{
174-
"definition": {
175-
"mappings": {
176-
"dynamic": True,
177-
"fields": {
178-
EMBEDDINGS_FIELD_NAME: {
179-
"dimensions": 1536,
180-
"similarity": "dotProduct",
181-
"type": "knnVector",
182-
},
183-
},
184-
},
185-
},
186-
"name": SEARCH_INDEX_NAME,
187-
}
188194
)
189-
190-
def _embed_text(self, text: str) -> List[float]:
191-
return self.config.embeddings.embed_query(text)
195+
self._insert_into_db(data)
196+
self._insert_into_vectorstore(data)
192197

193198
@abstractmethod
194199
def _get_images(

src/rai_core/rai/tools/spatiotemporal/spatiotemporal.py

+13-23
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,19 @@
1313
# limitations under the License.
1414

1515
from datetime import datetime
16-
from typing import Any, List, Type
16+
from typing import Any, Dict, List, Type
1717

18-
from langchain_core.embeddings import Embeddings
1918
from langchain_core.tools import BaseTool, BaseToolkit
19+
from langchain_core.vectorstores import VectorStore
2020
from pydantic import BaseModel, ConfigDict, Field
2121
from pymongo.collection import Collection
2222
from pymongo.mongo_client import MongoClient
2323

2424
from rai.agents.spatiotemporal.spatiotemporal_agent import (
25-
EMBEDDINGS_FIELD_NAME,
26-
SEARCH_INDEX_NAME,
2725
Pose,
2826
SpatioTemporalData,
2927
)
3028
from rai.agents.tool_runner import MultimodalArtifact
31-
from rai.utils.model_initialization import get_embeddings_model
3229

3330

3431
class SpatiotemporalToolkit(BaseToolkit):
@@ -38,6 +35,7 @@ class SpatiotemporalToolkit(BaseToolkit):
3835
mongodb_url: str = Field(default="mongodb://localhost:27017/")
3936
mongodb_db_name: str = Field(default="rai")
4037
mongodb_collection_name: str = Field(default="spatiotemporal_collection")
38+
vectorstore: VectorStore
4139

4240
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
4341

@@ -57,7 +55,7 @@ def get_tools(self) -> list[BaseTool]:
5755
collection=self.collection,
5856
),
5957
GetMemoriesOfObjectTool(
60-
collection=self.collection,
58+
collection=self.collection, vectorstore=self.vectorstore
6159
),
6260
]
6361

@@ -205,28 +203,20 @@ class GetMemoriesOfObjectTool(BaseTool):
205203
description: str = "Get the past memories of the robot of a specific object"
206204
args_schema: Type[GetMemoriesOfObjectToolInput] = GetMemoriesOfObjectToolInput
207205
collection: Collection[Any]
208-
embeddings: Embeddings = Field(default_factory=lambda: get_embeddings_model())
206+
vectorstore: VectorStore
209207

210208
response_model: str = "content_and_artifact"
211209

212210
def _run(self, object_name: str, n_results: int = 5):
213-
results = list(
214-
self.collection.aggregate(
215-
[
216-
{
217-
"$vectorSearch": {
218-
"index": SEARCH_INDEX_NAME,
219-
"path": EMBEDDINGS_FIELD_NAME,
220-
"queryVector": self.embeddings.embed_query(object_name),
221-
"numCandidates": 200,
222-
"limit": n_results,
223-
}
224-
}
225-
]
226-
),
227-
)
211+
documents = self.vectorstore.similarity_search(object_name, k=n_results)
212+
mongodb_data: List[Dict[str, Any]] = []
213+
for document in documents:
214+
id = document.id
215+
record = self.collection.find_one({"id": id})
216+
if record is not None:
217+
mongodb_data.append(record)
228218
images: List[str] = []
229-
parsed_results = list(map(SpatioTemporalData.model_validate, results))
219+
parsed_results = list(map(SpatioTemporalData.model_validate, mongodb_data))
230220
for result in parsed_results:
231221
for image in result.images.values():
232222
images.append(image)

src/rai_core/rai/utils/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from .model_initialization import get_embeddings_model, get_llm_model, get_vectorstore
16+
17+
__all__ = ["get_embeddings_model", "get_llm_model", "get_vectorstore"]

src/rai_core/rai/utils/model_initialization.py

+31
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,20 @@ class TracingConfig:
7373
langsmith: LangsmithConfig
7474

7575

76+
@dataclass
class VectorStoreConfig:
    """Settings built from the ``vectorstore`` section of the configuration."""

    # Backend identifier; ``get_vectorstore`` currently accepts only "faiss".
    type: str
    # Location of the persisted store; for FAISS it is used as the local
    # folder passed to ``save_local`` / ``load_local``.
    uri: str
80+
81+
7682
@dataclass
class RAIConfig:
    """Top-level configuration object returned by ``load_config``."""

    vendor: VendorConfig
    aws: AWSConfig
    openai: OpenAIConfig
    ollama: OllamaConfig
    tracing: TracingConfig
    # Vector store settings, built from the ``vectorstore`` config section.
    vectorstore: VectorStoreConfig
8390

8491

8592
def load_config() -> RAIConfig:
@@ -95,6 +102,7 @@ def load_config() -> RAIConfig:
95102
langfuse=LangfuseConfig(**config_dict["tracing"]["langfuse"]),
96103
langsmith=LangsmithConfig(**config_dict["tracing"]["langsmith"]),
97104
),
105+
vectorstore=VectorStoreConfig(**config_dict["vectorstore"]),
98106
)
99107

100108

@@ -167,6 +175,29 @@ def get_embeddings_model(vendor: str = None):
167175
raise ValueError(f"Unknown embeddings vendor: {vendor}")
168176

169177

178+
def get_vectorstore():
    """
    Create or load the vector store described by the ``vectorstore`` config.

    For the ``"faiss"`` type, an index previously saved under
    ``vectorstore.uri`` is loaded; when no saved index is found, a new one is
    created, saved to that location and returned.

    Returns
    -------
    The initialized vector store (a ``FAISS`` instance for the "faiss" type).

    Raises
    ------
    ValueError
        If ``vectorstore.type`` is not a supported backend.
    """
    config = load_config()
    store_type = config.vectorstore.type
    store_uri = config.vectorstore.uri
    logger.info(f"Initializing vector store: {store_type} in {store_uri}")

    if store_type != "faiss":
        raise ValueError(f"Unknown vector store type: {store_type}")

    from langchain_community.vectorstores import FAISS

    # ``save_local`` writes "<index_name>.faiss" with index_name defaulting to
    # "index". Checking for that file instead of the bare path means an empty
    # directory or an unrelated file at ``store_uri`` no longer makes
    # ``load_local`` fail; a fresh index is created instead.
    if os.path.isfile(os.path.join(store_uri, "index.faiss")):
        logger.info(f"Loading existing vector store from {store_uri}")
        # SECURITY: loading unpickles the stored docstore. The flag below
        # opts in to that; ``store_uri`` must only ever point at data written
        # by this application.
        return FAISS.load_local(
            store_uri,
            embeddings=get_embeddings_model(),
            allow_dangerous_deserialization=True,
        )

    # The index is seeded with a single placeholder text. That document has
    # no matching database record, so consumers must tolerate ids without one.
    index = FAISS.from_texts(["empty"], embedding=get_embeddings_model())
    index.save_local(store_uri)
    return index
199+
200+
170201
def get_tracing_callbacks(
171202
override_use_langfuse: bool = False, override_use_langsmith: bool = False
172203
) -> List[BaseCallbackHandler]:

0 commit comments

Comments
 (0)