From 66b621c0076a1b99d0a6187321fa8524a513d1dd Mon Sep 17 00:00:00 2001
From: zhangyuqin
Date: Thu, 17 Apr 2025 20:11:19 +0800
Subject: [PATCH] [Fix] fix vpp fthenb in dynamic shape mode

---
 .../fleet/meta_parallel/pipeline_parallel.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 2b5f05ca38c13..2a819b153d5fb 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -2708,11 +2708,17 @@ def forward_backward_pipeline(
             )
             self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
 
-            output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
+            # NOTE: the fused `send_forward_recv_backward` is intentionally
+            # not used here, as it can hang in dynamic shape mode.
+            self._p2p_helper.send_forward(
                 output_tensor,
                 self.is_pipeline_last_stage(ignore_virtual=True),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            output_tensor_grad = self._p2p_helper.recv_backward(
+                self.is_pipeline_last_stage(ignore_virtual=True),
+                batch_p2p_comm=self._use_batch_p2p_comm,
+            )
             # Unlike normal FthenB, in 1F1B steps, we recv output_tensor_grad
             # for the current step, but not for the next step
             cur_backward_virtual_pp_rank = self._get_virtual_pp_rank(
@@ -2736,7 +2742,13 @@ def forward_backward_pipeline(
                 backward_send_recv_buffer_queue.put(input_tensor_grad)
 
             if not last_iter:
-                input_tensor = self._p2p_helper.send_backward_recv_forward(
+                # NOTE: the fused `send_backward_recv_forward` is intentionally
+                # not used here, as it can hang in dynamic shape mode.
+                input_tensor = self._p2p_helper.recv_forward(
+                    self.is_pipeline_first_stage(ignore_virtual=True),
+                    batch_p2p_comm=self._use_batch_p2p_comm,
+                )
+                self._p2p_helper.send_backward(
                     input_tensor_grad,
                     self.is_pipeline_first_stage(ignore_virtual=True),
                     batch_p2p_comm=self._use_batch_p2p_comm,
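
---

Reviewer note (illustrative only, not part of the patch): the sketch below is a minimal, standalone rendering of the communication pattern the patch adopts in the steady 1F1B steps. `P2PHelper` and `steady_1f1b_step` are hypothetical stand-ins for Paddle's internal `_p2p_helper` and the loop body in the diff, not real Paddle APIs; they exist so the un-fusing of the send/recv pairs can be read and run in isolation.

```python
"""Illustrative sketch only -- NOT part of the patch. `P2PHelper` is a
hypothetical stub that mimics the method names of Paddle's internal
`_p2p_helper` so the patched call pattern can be run standalone."""


class P2PHelper:
    """Hypothetical stand-in for the real p2p communication helper."""

    def send_forward(self, tensor, is_last_stage, batch_p2p_comm=False):
        print(f"send activation to next stage (last_stage={is_last_stage})")

    def recv_backward(self, is_last_stage, batch_p2p_comm=False):
        print(f"recv gradient from next stage (last_stage={is_last_stage})")
        return "output_tensor_grad"

    def recv_forward(self, is_first_stage, batch_p2p_comm=False):
        print(f"recv activation from prev stage (first_stage={is_first_stage})")
        return "input_tensor"

    def send_backward(self, grad, is_first_stage, batch_p2p_comm=False):
        print(f"send gradient to prev stage (first_stage={is_first_stage})")


def steady_1f1b_step(p2p, output_tensor, input_tensor_grad, last_iter=False):
    """Mirrors the patched control flow of the 1F1B steps in the diff."""
    # After the patch: send and recv are issued as separate calls instead of
    # one fused send_forward_recv_backward, so the blocking recv is never
    # batched together with the send.
    p2p.send_forward(output_tensor, is_last_stage=True)
    output_tensor_grad = p2p.recv_backward(is_last_stage=True)

    input_tensor = None
    if not last_iter:
        # Likewise, recv_forward + send_backward replace the fused
        # send_backward_recv_forward.
        input_tensor = p2p.recv_forward(is_first_stage=True)
        p2p.send_backward(input_tensor_grad, is_first_stage=True)
    return output_tensor_grad, input_tensor


if __name__ == "__main__":
    steady_1f1b_step(P2PHelper(), "output_tensor", "input_tensor_grad")
```

Per the NOTE comments in the diff, the fused calls can hang in dynamic shape mode; presumably each recv there involves a shape-metadata exchange before the tensor transfer, and issuing it as its own call rather than inside a fused/batched send-recv is what the patch relies on to avoid that hang.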