15
15
import json
16
16
import logging
17
17
import time
18
+ import uuid
18
19
from abc import abstractmethod
19
20
from typing import Annotated , Any , Dict , List , Optional , cast
20
21
21
- from langchain_core .embeddings import Embeddings
22
22
from langchain_core .language_models import BaseChatModel
23
23
from langchain_core .messages import AIMessage , BaseMessage , HumanMessage , SystemMessage
24
+ from langchain_core .vectorstores import VectorStore
24
25
from pydantic import BaseModel , ConfigDict , Field
25
26
from pymongo import MongoClient
26
27
@@ -60,12 +61,12 @@ class PoseStamped(BaseModel):
60
61
61
62
62
63
class SpatioTemporalData (BaseModel ):
64
+ id : str = Field (default_factory = lambda : str (uuid .uuid4 ()))
63
65
timestamp : Annotated [float , "timestamp" ]
64
66
images : Dict [Annotated [str , "camera topic" ], str ] = Field (repr = False )
65
67
tf : Optional [PoseStamped ]
66
68
temporal_context : Annotated [str , "compressed history of messages" ]
67
69
image_text_descriptions : Annotated [str , "text descriptions of images" ]
68
- embeddings : List [float ] = Field (default_factory = list , repr = False )
69
70
70
71
71
72
class SpatioTemporalConfig (BaseModel ):
@@ -74,7 +75,7 @@ class SpatioTemporalConfig(BaseModel):
74
75
collection_name : str
75
76
image_to_text_model : BaseChatModel
76
77
context_compression_model : BaseChatModel
77
- embeddings : Embeddings
78
+ vector_db : VectorStore
78
79
time_interval : float
79
80
80
81
model_config = ConfigDict (arbitrary_types_allowed = True )
@@ -104,10 +105,9 @@ def __init__(
104
105
105
106
self .db = MongoClient (self .config .db_url )[self .config .db_name ] # type: ignore
106
107
self .collection = self .db [self .config .collection_name ] # type: ignore
107
- self ._initialize_embeddings_search_index ()
108
108
self .logger = logging .getLogger (__name__ )
109
109
110
- def insert_into_db (self , data : SpatioTemporalData ):
110
+ def _insert_into_db (self , data : SpatioTemporalData ):
111
111
"""
112
112
Insert spatiotemporal data into the database.
113
113
@@ -121,20 +121,51 @@ def insert_into_db(self, data: SpatioTemporalData):
121
121
)
122
122
self .collection .insert_one (data .model_dump ()) # type: ignore
123
123
124
+ def _insert_into_vectorstore (self , data : SpatioTemporalData ):
125
+ """
126
+ Insert embeddings of the spatiotemporal data into the vector store.
127
+
128
+ Parameters
129
+ ----------
130
+ data : SpatioTemporalData
131
+ The spatiotemporal data to be inserted.
132
+ """
133
+ self .logger .info ("Inserting embeddings into vector store" )
134
+
135
+ print (
136
+ self .config .vector_db .add_texts (
137
+ texts = [data .temporal_context + data .image_text_descriptions ],
138
+ metadatas = [{"id" : data .id }],
139
+ ids = [data .id ],
140
+ )
141
+ )
142
+
124
143
def run (self ):
125
144
"""
126
145
Run the agent in a loop, executing tasks at specified intervals.
127
146
"""
128
147
while True :
129
- ts = time .time ()
130
- self .logger .info ("Starting new interval" )
131
- self .on_interval ()
132
- te = time .time ()
133
- if te - ts > self .config .time_interval :
134
- self .logger .warning (
135
- f"Time interval exceeded. Expected { self .config .time_interval :.2f} s, got { te - ts :.2f} s"
136
- )
137
- time .sleep (max (0 , self .config .time_interval - (te - ts )))
148
+ try :
149
+ ts = time .time ()
150
+ self .logger .info ("Starting new interval" )
151
+ self .on_interval ()
152
+ te = time .time ()
153
+ if te - ts > self .config .time_interval :
154
+ self .logger .warning (
155
+ f"Time interval exceeded. Expected { self .config .time_interval :.2f} s, got { te - ts :.2f} s"
156
+ )
157
+ time .sleep (max (0 , self .config .time_interval - (te - ts )))
158
+ except KeyboardInterrupt :
159
+ # seriously hacky
160
+ from langchain_community .vectorstores import FAISS
161
+
162
+ from rai .utils .model_initialization import load_config
163
+
164
+ self .logger .info ("Saving vector store" )
165
+
166
+ config = load_config ()
167
+ cast (FAISS , self .config .vector_db ).save_local (config .vectorstore .uri )
168
+ raise
138
169
139
170
def on_interval (self ):
140
171
"""
@@ -154,41 +185,15 @@ def on_interval(self):
154
185
self .logger .info ("Retrieving temporal context" )
155
186
temporal_context = self ._get_robots_history ()
156
187
157
- embedding_text = temporal_context + str (image_text_descriptions .values ())
158
- self .logger .info ("Embedding text" )
159
- embeddings = self ._embed_text (embedding_text )
160
-
161
188
data = SpatioTemporalData (
162
189
timestamp = time .time (),
163
190
images = images ,
164
191
tf = tf ,
165
192
temporal_context = temporal_context ,
166
193
image_text_descriptions = json .dumps (image_text_descriptions ),
167
- embeddings = embeddings ,
168
- )
169
- self .insert_into_db (data )
170
-
171
- def _initialize_embeddings_search_index (self ):
172
- self .collection .create_search_index (
173
- {
174
- "definition" : {
175
- "mappings" : {
176
- "dynamic" : True ,
177
- "fields" : {
178
- EMBEDDINGS_FIELD_NAME : {
179
- "dimensions" : 1536 ,
180
- "similarity" : "dotProduct" ,
181
- "type" : "knnVector" ,
182
- },
183
- },
184
- },
185
- },
186
- "name" : SEARCH_INDEX_NAME ,
187
- }
188
194
)
189
-
190
- def _embed_text (self , text : str ) -> List [float ]:
191
- return self .config .embeddings .embed_query (text )
195
+ self ._insert_into_db (data )
196
+ self ._insert_into_vectorstore (data )
192
197
193
198
@abstractmethod
194
199
def _get_images (
0 commit comments