Refactor PEER #123

base: main

Changes from all commits
File 1

@@ -35,6 +35,11 @@ def __init__(self, config, layer, selector=None):
         self.selector = selector or TaskNameSelector()
         self._default_expert_name = None
         self.expert_infos = {}
         self.experts = nn.ModuleDict({})
+
+    @property
+    def num_experts(self):
+        return len(self.experts)
+
     @property
     def default_expert_name(self):

Review comment on `def num_experts(self):`: let's use len(self)
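A minimal sketch of what that suggestion could look like; the class below is an illustrative stand-in for the real base container, not code from this PR:

from torch import nn


class ContainerSketch(nn.Module):
    """Toy container illustrating the `len(self)` suggestion."""

    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleDict({})

    def __len__(self):
        # Containers that keep no per-expert modules (e.g. the PEER container
        # below) can override __len__ instead of the property.
        return len(self.experts)

    @property
    def num_experts(self):
        # Delegating to len(self) lets every subclass report its own count.
        return len(self)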
File 2

@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
 import torch
 from torch import nn

@@ -7,14 +9,26 @@
     MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
+from mttl.models.modifiers.base import Modifier, ModifierConfig
+
+# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier

 # diff architectures name those layers differently
 DOWN_NAMES = ["fc1", "c_fc"]
 UP_NAMES = ["fc2", "c_proj"]


-class PEERMLPContainer(ExpertContainer):
+@dataclass
+class PEERConfig(ModifierConfig):
+    n_heads: int = 8
+    moe__num_experts: int = 100
+    emb_dim: int = 128
+    down_proj_layer: str = "fc1"
+    up_proj_layer: str = "fc2"
+
+
+@Modifier.register("peer", config_cls=PEERConfig)
+class PEERMLPContainer(ExpertContainer, Modifier):
     """
     PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)

Review comment on `moe__num_experts: int = 100`: two __
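The comment flags the accidental double underscore; the pre-existing code spells the field with a single underscore (`moe_num_experts`). A sketch of the rename, shown as a plain dataclass for self-containment (the real config subclasses ModifierConfig):

from dataclasses import dataclass


@dataclass
class PEERConfigSketch:
    # Same defaults as in the diff, with the field renamed to the
    # single-underscore spelling used by the rest of the code.
    n_heads: int = 8
    moe_num_experts: int = 100
    emb_dim: int = 128
    down_proj_layer: str = "fc1"
    up_proj_layer: str = "fc2"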
@@ -33,7 +47,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(config, module)
-        self.num_experts = 0
+        self._num_experts = 0
         down_names = DOWN_NAMES + [
             config.down_proj_layer
         ]  # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@ def __init__(
         self.dtype = next(self.layer.parameters()).dtype

         self.layer = nn.Identity()
+        self.expert_name = None
         self.layer.in_features = self.input_dim
-        self.experts = PEERModifier(config)
+
+    @property
+    def num_experts(self):
+        return self._num_experts

     def initialize_experts(self, expert_config: PEERConfig) -> None:
-        self.num_experts = expert_config.moe_num_experts
+        self._num_experts = expert_config.moe__num_experts
         assert (
-            self.num_experts**0.5
+            self._num_experts**0.5
         ).is_integer(), "Number of experts must be a square number"
         self.peer_weight_down_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.input_dim,
             dtype=self.dtype,
         )
         self.peer_weight_up_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.output_dim,
             dtype=self.dtype,
         )
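For context on the square-number assertion: PEER (the paper linked in the docstring) retrieves experts with product keys, so the N experts live on a sqrt(N) x sqrt(N) grid of sub-key pairs. A rough, self-contained sketch of that retrieval idea; the tensor names below are illustrative and not the ones used in this PR:

import torch

num_experts = 100                    # must be a perfect square
n_keys = int(num_experts ** 0.5)     # 10 sub-keys per side
dim = 32

query = torch.randn(dim)
sub_keys_1 = torch.randn(n_keys, dim // 2)
sub_keys_2 = torch.randn(n_keys, dim // 2)

# Score each half of the query against its own sub-key set...
scores_1 = sub_keys_1 @ query[: dim // 2]    # (n_keys,)
scores_2 = sub_keys_2 @ query[dim // 2 :]    # (n_keys,)

# ...then the score of expert (i, j) is scores_1[i] + scores_2[j]: an N-way
# search is done with only 2 * sqrt(N) dot products.
full_scores = scores_1[:, None] + scores_2[None, :]   # (n_keys, n_keys)
top_score, flat_idx = full_scores.flatten().topk(1)
i, j = divmod(int(flat_idx), n_keys)   # expert index on the sqrt(N) x sqrt(N) grid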
@@ -96,10 +114,22 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
         return self.on_add_expert(expert, **kwargs)

     def on_add_expert(self, expert: Expert, **kwargs) -> None:
+        """
+        'initialize_experts' is called from here in order not to break logic in expert model
+        """
         expert_config: PEERConfig = expert.expert_config
-        if self.num_experts == expert_config.moe_num_experts:
+        if self._num_experts == expert_config.moe__num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
-
-    def __getitem__(self, key):
-        pass
+        self.expert_infos[expert.name] = expert.expert_info
+        self.expert_name = expert.name
+
+    def __getitem__(self, name):
+        if name != self.expert_name:
+            raise ValueError(
+                f"Expert with name {name} does not exist in this container."
+            )
+        return self
+
+    def __len__(self):
+        return self._num_experts
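A hypothetical usage check of the new container protocol (the function and argument names are assumptions, not part of the PR): the container reports the PEER expert count through both __len__ and num_experts, and item lookup only accepts the single registered expert name.

def check_peer_container_protocol(container, expert_name: str, n_experts: int) -> None:
    # __len__ and the num_experts property now agree on the PEER expert count.
    assert len(container) == container.num_experts == n_experts
    # Lookup by the one registered name returns the container itself...
    assert container[expert_name] is container
    # ...and any other name raises, as in the diff above.
    try:
        container["some_other_name"]
    except ValueError:
        pass
    else:
        raise AssertionError("expected a ValueError for unknown expert names")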
File 3

@@ -24,6 +24,14 @@
 class LightningTrainingMixin:
+
+    @property
+    def experts_names(self):
+        return self.model.experts_names
+
+    def get_expert_instance(self, name):
+        return self.model.get_expert_instance(name)
+
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")

Review comment on `def get_expert_instance(self, name):`: this is only available for multiexpert, i am not sure to add this dependency here
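One possible way to address that concern (a sketch, not code from this PR): delegate only when the wrapped model actually exposes the multi-expert API, and fail with an explicit message otherwise.

class LightningTrainingMixinSketch:
    # Like the real mixin, this assumes the host class sets self.model.
    def get_expert_instance(self, name):
        getter = getattr(self.model, "get_expert_instance", None)
        if getter is None:
            raise NotImplementedError(
                f"{type(self.model).__name__} is not a multi-expert model; "
                "get_expert_instance is not available on this module."
            )
        return getter(name)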
@@ -215,7 +223,7 @@ def load_from_checkpoint(
         return model


-class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
+class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
     def __init__(
         self,
         model_object: PreTrainedModel = None,
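Swapping the base-class order changes Python's method resolution order, presumably so that attributes defined on the mixin shadow same-named attributes on the checkpoint base class. A toy illustration; the classes below are stand-ins, not the real mttl ones:

class Mixin:
    @property
    def experts_names(self):
        return ["from-mixin"]


class Base:
    @property
    def experts_names(self):
        return ["from-base"]


class Module(Mixin, Base):  # mixin listed first, as in the new MoEModule
    pass


print([c.__name__ for c in Module.__mro__])  # ['Module', 'Mixin', 'Base', 'object']
print(Module().experts_names)                # ['from-mixin']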
File 4

@@ -3,9 +3,9 @@
 from copy import deepcopy

 import torch
-import wandb
 from pytorch_lightning import seed_everything

+import wandb
 from mttl.arguments import EvaluationConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
     WeightedLinearMergeConfig,
 )
 from mttl.models.lightning.callbacks import LossCallback
-from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
+from mttl.models.lightning.expert_module import (
+    ExpertModule,
+    MoEModule,
+    MultiExpertModule,
+)
 from mttl.models.modifiers.lora import LoRAConfig
 from mttl.utils import remote_login
@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
         module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
         module.add_expert_instance(expert, is_default=True)

+    elif args.merge_or_route in ["peer"]:
+        module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
+        module.model.model.load_state_dict(an_expert.expert_weights, strict=False)
+
     elif args.merge_or_route == "uniform_lora_after_op":
         # Here we merge the LoRA experts after the outer product we cannot really do it
         # with the lib transform, cause this would require storing large matrices in memory

Review comment on `elif args.merge_or_route in ["peer"]:`: can you explain what you are trying to do?
File 5

@@ -7,6 +7,7 @@
 from mttl.arguments import Args, ExpertConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.logging import logger, setup_logging
+from mttl.models.library.expert import Expert
 from mttl.models.library.expert_library import ExpertLibrary
 from mttl.models.lightning.callbacks import (
     DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
     if isinstance(module, MoEModule):
         with expert_library.batched_commit():
             for expert_name in module.experts_names:
-                expert = module.get_expert_instance(expert_name)
+                expert: Expert = module.get_expert_instance(expert_name)
+                expert.expert_info.training_config = args
                 expert_library.add_expert(expert, expert_name)
     elif isinstance(module, ExpertModule):
         expert = module.as_expert()

Review comment on `expert.expert_info.training_config = args`: remove this, not a good idea to store complex object in training_config, it will be transformed to a Dict in the next PR :)
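A sketch of the direction that comment points at: attach a plain dict rather than the Args object itself. Whether Args is a dataclass is an assumption here, and the helper name is hypothetical:

from dataclasses import asdict, is_dataclass


def training_config_as_dict(args) -> dict:
    # Serialize the config to plain Python types before attaching it to the
    # expert, so nothing complex ends up in training_config.
    return asdict(args) if is_dataclass(args) else dict(vars(args))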
Review comment: not all containers have "experts", see my last PR on making LoRA faster