Skip to content

Commit 1af1ae2

Browse files
authored
Fix getitem (#364)
Fixes #363. This PR:
- adds a batch rule for _index_put_impl_
- fixes the index_put_ batch rule
- adds a new OpInfo so we can actually test this
- fixes the fallback paths to error out on Tensor?[], since otherwise they are very wrong.
1 parent 3ba93d3 commit 1af1ae2

File tree

5 files changed

+97
-6
lines changed

5 files changed

+97
-6
lines changed

functorch/csrc/BatchRulesScatterOps.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void index_put__batch_rule(
126126
auto values_ = moveBatchDimToFront(values, values_bdim);
127127
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
128128
std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim, values_bdim);
129-
at::index_put_(self_, List<optional<Tensor>>(indices_), values, accumulate);
129+
at::index_put_(self_, List<optional<Tensor>>(indices_), values_, accumulate);
130130
}
131131

132132
// plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -158,6 +158,54 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
158158
return self;
159159
}
160160

161+
// Batch rule for the in-place _index_put_impl_. The mutated tensor must carry
// a batch dimension itself; writing batched values into an unbatched tensor is
// not expressible for an in-place op.
void _index_put_impl__batch_rule(
    Tensor& self,
    optional<int64_t> self_bdim,
    ArrayRef<optional<Tensor>> indices,
    ArrayRef<optional<int64_t>> indices_bdims,
    const Tensor& values,
    optional<int64_t> values_bdim,
    bool accumulate,
    bool unsafe) {
  if (!self_bdim.has_value()) {
    vmapIncompatibleInplaceError("_index_put_impl_");
  }
  auto batched_self = moveBatchDimToFront(self, self_bdim);
  auto batched_values = moveBatchDimToFront(values, values_bdim);
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
  // Align each (optional) index tensor with the front batch dim of self.
  auto batched_indices = batchIndices(
      indices, indices_bdims, batched_self.size(0), self_bdim, values_bdim);
  at::_index_put_impl_(
      batched_self, List<optional<Tensor>>(batched_indices), batched_values,
      accumulate, unsafe);
}
179+
180+
// plumbing done since we don't support List<optional<Tensor>> in codegen
181+
Tensor& _index_put_impl__plumbing(Tensor & self, const List<optional<Tensor>> & indices
182+
, const Tensor & values, bool accumulate, bool unsafe) {
183+
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
184+
auto maybe_layer = maybeCurrentDynamicLayer();
185+
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
186+
int64_t cur_level = maybe_layer->layerId();
187+
Tensor self_value;
188+
optional<int64_t> self_bdim;
189+
std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
190+
std::vector<optional<Tensor>> indices_value;
191+
std::vector<optional<int64_t>> indices_bdims;
192+
for (const auto&& indRef : indices) {
193+
optional<Tensor> ind = indRef;
194+
optional<Tensor> index;
195+
optional<int64_t> index_bdim;
196+
if (ind.has_value()) {
197+
std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
198+
}
199+
indices_value.push_back(index);
200+
indices_bdims.push_back(index_bdim);
201+
}
202+
Tensor values_value;
203+
optional<int64_t> values_bdim;
204+
std::tie(values_value, values_bdim) = unwrapTensorAtLevel(values, cur_level);
205+
_index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
206+
return self;
207+
}
208+
161209
namespace {
162210

163211
template<typename Func, typename ...Args>
@@ -496,6 +544,7 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
496544
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
497545
m.impl("index.Tensor", index_plumbing);
498546
m.impl("index_put_", index_put__plumbing);
547+
m.impl("_index_put_impl_", _index_put_impl__plumbing);
499548
m.impl("slice_scatter", slice_scatter_decomp);
500549
m.impl("select_scatter", select_scatter_decomp);
501550
m.impl("index_copy", index_copy_decomp);

functorch/csrc/BatchedFallback.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
6767
return std::any_of(
6868
schema.arguments().begin(),
6969
schema.arguments().end(),
70-
[] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
70+
[] (const Argument& arg) {
71+
return arg.type()->isSubtypeOf(ListType::ofTensors()) ||
72+
arg.type()->isSubtypeOf(ListType::ofOptionalTensors());
73+
});
7174
}
7275

7376
static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {

test/functorch_additional_op_db.py

+39
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,42 @@ def generator():
258258
sample_inputs_func=sample_inputs_embedding,
259259
supports_out=False,
260260
))
261+
262+
263+
def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    # Sample inputs for Tensor.__getitem__: a (5, 5, 5) tensor indexed with a
    # variety of int-list / slice / Ellipsis combinations.
    side = 5
    index_args = [
        ([1, 2],),
        (slice(0, 3),),
        ([slice(0, 3), 1],),
        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
        ([slice(None), slice(None), [0, 3]],),
        ([slice(None), [0, 3], slice(None)],),
        ([[0, 3], slice(None), slice(None)],),
        ([[0, 3], [1, 2], slice(None)],),
        ([[0, 3], ],),
        ([[0, 3], slice(None)],),
        ([[0, 3], Ellipsis],),
        # index_backward is not CompositeCompliant TODO.
        # ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
    ]

    def make_sample(args):
        # One fresh base tensor per indexing pattern.
        base = make_tensor((side, side, side), device, dtype,
                           low=None, high=None, requires_grad=requires_grad)
        return SampleInput(base, args=args)

    return tuple(make_sample(args) for args in index_args)
286+
287+
288+
# TODO: split PyTorch's __getitem__. The problem is we don't support indexing
# with masks with vmap.
_getitem_opinfo = OpInfo(
    '__getitem__',
    variant_test_name='functorch',
    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
    supports_out=False,
    supports_inplace_autograd=False,
    supports_scripting=False,
    op=torch.Tensor.__getitem__,
    assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
    sample_inputs_func=sample_inputs_getitem,
)
additional_op_db.append(_getitem_opinfo)

test/test_ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def vjp_of_vjp(*args_and_cotangents):
549549
xfail('nn.functional.fractional_max_pool3d'),
550550
xfail('as_strided'),
551551
xfail('nn.functional.fractional_max_pool2d'),
552-
xfail('__getitem__'),
552+
xfail('__getitem__', ''),
553553
xfail('index_put'),
554554
xfail('lu_solve'),
555555
})
@@ -744,7 +744,7 @@ def test_vmapjvpall(self, device, dtype, op):
744744
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
745745
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
746746
xfail('view_as_complex'),
747-
xfail('__getitem__'),
747+
xfail('__getitem__', ''),
748748
xfail('cholesky'),
749749
xfail('complex'),
750750
xfail('copysign'),
@@ -865,7 +865,7 @@ def test():
865865
# fallback path doesn't work
866866
xfail('H'),
867867
# All of the following are bugs and need to be fixed
868-
xfail('__getitem__'),
868+
xfail('__getitem__', ''),
869869
xfail('clamp', ''),
870870
xfail('dsplit'),
871871
xfail('fill_'),

test/test_vmap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3219,7 +3219,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
32193219
xfail('to_sparse'),
32203220
xfail('vdot'),
32213221
xfail('vsplit'),
3222-
xfail('__getitem__'),
3222+
xfail('__getitem__', ''),
32233223
xfail('all'),
32243224
xfail('any'),
32253225
xfail('count_nonzero'),

0 commit comments

Comments
 (0)