From d949a5c165d223f5b564e2c7dcc17f3b00561d8d Mon Sep 17 00:00:00 2001
From: vmoens
Date: Sat, 13 Aug 2022 09:50:30 +0100
Subject: [PATCH] adapt vmap for tensordict

---
 functorch/_src/vmap.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py
index df5755fbc..c8a5cc449 100644
--- a/functorch/_src/vmap.py
+++ b/functorch/_src/vmap.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
+from torchrl.data.tensordict.tensordict import TensorDictBase
 import functools
 from torch import Tensor
 from typing import Any, Callable, Optional, Tuple, Union, List
@@ -84,7 +85,7 @@ def _process_batched_inputs(
                 f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): '
                 f'Got in_dim={in_dim} for an input but in_dim must be either '
                 f'an integer dimension or None.')
-        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
+        if isinstance(in_dim, int) and not isinstance(arg, (Tensor, TensorDictBase)):
             raise ValueError(
                 f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): '
                 f'Got in_dim={in_dim} for an input but the input is of type '
@@ -109,6 +110,8 @@ def _create_batched_inputs(
         flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
     # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
     batched_inputs = [arg if in_dim is None else
+                      arg.apply(lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level))
+                      if isinstance(arg, TensorDictBase) else
                       _add_batch_dim(arg, in_dim, vmap_level)
                       for in_dim, arg in zip(flat_in_dims, flat_args)]
     return tree_unflatten(batched_inputs, args_spec)
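
Note (not part of the patch): a minimal sketch of the usage this change is meant to enable, assuming functorch's public `vmap` and a `TensorDict` class importable from the same `torchrl.data.tensordict.tensordict` module as `TensorDictBase`; the key names and shapes below are illustrative only.

    import torch
    from functorch import vmap
    from torchrl.data.tensordict.tensordict import TensorDict

    # A tensordict carrying a batch of 4 observation/action pairs.
    td = TensorDict(
        {"obs": torch.randn(4, 3), "act": torch.randn(4, 2)},
        batch_size=[4],
    )

    def norms(td):
        # Written as if it received a single, unbatched tensordict.
        return td["obs"].norm() + td["act"].norm()

    # Before this patch, _process_batched_inputs rejects the tensordict
    # because it is not a Tensor; with it, vmap maps over dim 0 by using
    # TensorDictBase.apply to batch every tensor the dict holds.
    out = vmap(norms, in_dims=(0,))(td)  # out: shape [4]

Design note: rather than registering TensorDict as a pytree node, the patch treats the whole tensordict as a single vmap input and pushes `_add_batch_dim` into its leaves via `apply`, so the dict's structure is preserved across the batched call.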