@@ -126,7 +126,7 @@ void index_put__batch_rule(
   auto values_ = moveBatchDimToFront(values, values_bdim);
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
   std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
-  at::index_put_(self_, List<optional<Tensor>>(indices_), values, accumulate);
+  at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
 }
 
 // plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -158,6 +158,54 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
   return self;
 }
 
+void _index_put_impl__batch_rule(
+    Tensor& self,
+    optional<int64_t> self_bdim,
+    ArrayRef<optional<Tensor>> indices,
+    ArrayRef<optional<int64_t>> indices_bdims,
+    const Tensor& values,
+    optional<int64_t> values_bdim,
+    bool accumulate,
+    bool unsafe) {
+  if (!self_bdim.has_value()) {
+    vmapIncompatibleInplaceError("_index_put_impl_");
+  }
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto values_ = moveBatchDimToFront(values, values_bdim);
+  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
+  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
+  at::_index_put_impl_(self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
+}
+
+// plumbing done since we don't support List<optional<Tensor>> in codegen
+Tensor& _index_put_impl__plumbing(Tensor & self, const List<optional<Tensor>> & indices
+, const Tensor & values, bool accumulate, bool unsafe) {
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+  Tensor self_value;
+  optional<int64_t> self_bdim;
+  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
+  std::vector<optional<Tensor>> indices_value;
+  std::vector<optional<int64_t>> indices_bdims;
+  for (const auto&& indRef : indices) {
+    optional<Tensor> ind = indRef;
+    optional<Tensor> index;
+    optional<int64_t> index_bdim;
+    if (ind.has_value()) {
+      std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
+    }
+    indices_value.push_back(index);
+    indices_bdims.push_back(index_bdim);
+  }
+  Tensor values_value;
+  optional<int64_t> values_bdim;
+  std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level);
+  _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
+  return self;
+}
+
 namespace {
 
 template <typename Func, typename ...Args>
@@ -496,6 +544,7 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
+  m.impl("_index_put_impl_", _index_put_impl__plumbing);
   m.impl("slice_scatter", slice_scatter_decomp);
   m.impl("select_scatter", select_scatter_decomp);
   m.impl("index_copy", index_copy_decomp);