
[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask #10285

Open: wants to merge 4 commits into gh/kimishpatel/183/base

Conversation

@kimishpatel (Contributor) commented Apr 17, 2025

Stack from ghstack (oldest at bottom):

Previously we assumed that the custom sdpa op always does causal attention.
This diff adds an option to the module swap pass so that custom sdpa leverages
an attention mask instead of assuming causal attention.

Differential Revision: D73222736
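For context, here is a minimal sketch of what the swapped-in SDPA module could look like with this option. The `custom_sdpa` call signature follows the snippets in the review thread below; the class name, constructor, and forward signature are assumptions made for illustration, not the exact code in the diff:

import torch

class SDPACustom(torch.nn.Module):
    # Hypothetical name for the module the swap pass installs.
    def __init__(self, use_attention_mask: bool = False):
        super().__init__()
        # When False, keep the previous behavior: always-causal attention.
        self.use_attention_mask = use_attention_mask

    def forward(self, q, k, v, input_pos, mask=None):
        if self.use_attention_mask:
            # New path: pass the caller's mask through; is_causal must be False.
            return torch.ops.llama.custom_sdpa(
                q, k, v,
                input_pos[0].item(),
                mask,   # attention mask
                0,      # dropout probability. Ignored by the code
                False,  # is_causal
            )
        # Old path: no mask; the op applies causal masking itself.
        return torch.ops.llama.custom_sdpa(
            q, k, v,
            input_pos[0].item(),
            None,  # no attention mask
            0,     # dropout probability. Ignored by the code
            True,  # is_causal
        )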

@pytorch-bot (bot) commented Apr 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10285

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Pending, 1 Unrelated Failure

As of commit 06831fd with merge base 06f912d:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D73222736

Review thread on this hunk of the diff:

    0,  # dropout probability. Ignored by the code
    True,  # is_causal
)
if self.use_attention_mask:
@lucylq (Contributor) commented Apr 17, 2025
nit: maybe we can move this up to where we handle `if self.enable_dynamic_shape`, and do something similar:

if not self.use_attention_mask:
    mask = None

And then have one call to custom_sdpa

output = torch.ops.llama.custom_sdpa(
    q,
    k,
    v,
    input_pos[0].item(),
    mask,  # Attention mask
    0,  # dropout probability. Ignored by the code
    False,  # is_causal
)

@kimishpatel (Contributor, Author) replied:
Note that is_causal is True in one case but False in the other.
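If the pass did want a single call site, one hedged way to reconcile that difference (illustrative only, not what this diff does) would be to derive is_causal from whether a mask is supplied:

# Illustrative only: a single custom_sdpa call covering both paths.
if not self.use_attention_mask:
    mask = None

output = torch.ops.llama.custom_sdpa(
    q,
    k,
    v,
    input_pos[0].item(),
    mask,  # attention mask (None on the causal path)
    0,  # dropout probability. Ignored by the code
    mask is None,  # is_causal: True only when no explicit mask is given
)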

@kimishpatel added the label release notes: examples (Changes to any of our example LLMs integrations, such as Llama3 and Llava) on Apr 18, 2025

Labels

- CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
- fb-exported
- release notes: examples: Changes to any of our example LLMs integrations, such as Llama3 and Llava

3 participants