@@ -52,7 +52,7 @@ def __init__(
52
52
self .stepped = False
53
53
self .closed = False
54
54
55
- self .position = 0
55
+ self ._position = 0
56
56
self .history = None # Used in case of server failures to regenerate attention caches on new servers
57
57
self .next_session = None
58
58
@@ -102,11 +102,12 @@ def step(
102
102
n_input_tokens = inputs .shape [1 ]
103
103
if self .history is None :
104
104
self .history = inputs
105
- elif self .history .shape [1 ] == self .position :
105
+ elif self .history .shape [1 ] == self ._position :
106
106
self .history = torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
107
- assert (
108
- self .history .shape [1 ] == self .position + n_input_tokens
109
- ), f"Broken input cache: { self .span = } { self .history .shape = } { self .position = } { n_input_tokens = } "
107
+ assert self .history .shape [1 ] == self ._position + n_input_tokens , (
108
+ f"Broken input cache: span={ self .span } shape={ self .history .shape } "
109
+ f"position={ self ._position } n_input_tokens={ n_input_tokens } "
110
+ )
110
111
111
112
if not self .stepped :
112
113
inputs = self .history # Pass full inputs including prefix
@@ -173,7 +174,7 @@ def step(
173
174
outputs [0 ].shape == inputs .shape
174
175
), f"output activation shape is different from input shape: { outputs [0 ].shape } != { inputs .shape } "
175
176
176
- self .position += n_input_tokens
177
+ self ._position += n_input_tokens
177
178
178
179
return outputs [0 ]
179
180
@@ -363,10 +364,6 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
363
364
# If there is a failed span, this code replaces it, otherwise it just adds new ones
364
365
if server_idx < n_prev_spans :
365
366
updated_sessions [0 ].history = self ._server_sessions [server_idx ].history
366
- updated_sessions [0 ].position = self ._position
367
- assert (
368
- updated_sessions [0 ].history .shape [1 ] == self ._position
369
- ), f"Broken input cache: { updated_sessions [0 ].history .shape = } { self ._position = } "
370
367
self ._server_sessions [server_idx : server_idx + 1 ] = updated_sessions
371
368
372
369
# Update links to the next server session for direct server-to-server communication via rpc_push()
0 commit comments