From 0f2c6f9a8bfc0a58c9e677e122adeba9983982f0 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 18 Oct 2024 10:59:25 -0400
Subject: [PATCH 1/6] peer container is a modifier

---
 mttl/models/containers/base.py           |  5 +++
 mttl/models/containers/peer_container.py | 56 +++++++++++++++++++-----
 mttl/models/expert_model.py              |  2 +-
 mttl/models/modifiers/mlp.py             | 29 +----------
 4 files changed, 51 insertions(+), 41 deletions(-)

diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 1f1406f2f..6e7915742 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -35,7 +35,12 @@ def __init__(self, config, layer, selector=None):
         self.selector = selector or TaskNameSelector()
         self._default_expert_name = None
         self.expert_infos = {}
+        self.experts = nn.ModuleDict({})
 
+    @property
+    def num_experts(self):
+        return len(self.experts)
+
     @property
     def default_expert_name(self):
         return self._default_expert_name
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index 99a9a3045..b945eca0e 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
 import torch
 from torch import nn
 
@@ -7,14 +9,26 @@
     MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
+from mttl.models.modifiers.base import Modifier, ModifierConfig
+
+# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
 
 # diff architectures name those layers differently
 DOWN_NAMES = ["fc1", "c_fc"]
 UP_NAMES = ["fc2", "c_proj"]
 
 
-class PEERMLPContainer(ExpertContainer):
+@dataclass
+class PEERConfig(ModifierConfig):
+    n_heads: int = 8
+    moe_num_experts: int = 100
+    emb_dim: int = 128
+    down_proj_layer: str = "fc1"
+    up_proj_layer: str = "fc2"
+
+
+@Modifier.register("peer", config_cls=PEERConfig)
+class PEERMLPContainer(ExpertContainer, Modifier):
     """
     PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)
 
@@ -33,7 +47,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(config, module)
-        self.num_experts = 0
+        self._num_experts = 0
         down_names = DOWN_NAMES + [
             config.down_proj_layer
         ]  # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@ def __init__(
         self.dtype = next(self.layer.parameters()).dtype
         self.layer = nn.Identity()
+        self.expert_name = None
         self.layer.in_features = self.input_dim
-        self.experts = PEERModifier(config)
+
+    @property
+    def num_experts(self):
+        return self._num_experts
 
     def initialize_experts(self, expert_config: PEERConfig) -> None:
-        self.num_experts = expert_config.moe_num_experts
+        self._num_experts = expert_config.moe_num_experts
         assert (
-            self.num_experts**0.5
+            self._num_experts**0.5
         ).is_integer(), "Number of experts must be a square number"
         self.peer_weight_down_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.input_dim,
             dtype=self.dtype,
         )
         self.peer_weight_up_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.output_dim,
             dtype=self.dtype,
         )
@@ -96,10 +114,24 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
         return self.on_add_expert(expert, **kwargs)
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
+        """
+        'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object that is passed here
+        """
         expert_config: PEERConfig = expert.expert_config
-        if self.num_experts == expert_config.moe_num_experts:
+        if self._num_experts == expert_config.moe_num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
-
-    def __getitem__(self, key):
-        pass
+        self.expert_infos[expert.name] = expert.expert_info
+        if expert.expert_weights:
+            self.load_state_dict(expert.expert_weights)
+        self.expert_name = expert.name
+
+    def __getitem__(self, name):
+        if name != self.expert_name:
+            raise ValueError(
+                f"Expert with name {name} does not exist in this container."
+            )
+        return self
+
+    def __len__(self):
+        return self._num_experts
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index 288861522..5aebfd71d 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -163,7 +163,7 @@ def experts_containers(self) -> List[ExpertContainer]:
         containers = []
         for _, module in self.model.named_modules():
             for _, child in dict(module.named_children()).items():
-                if isinstance(child, ExpertContainer) and len(child.experts) > 0:
+                if isinstance(child, ExpertContainer) and child.num_experts > 0:
                     containers.append(child)
         return containers
diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 8d2b4cf35..66ac3d1eb 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,31 +68,4 @@ def parallel_linear_forward(
         hidden_states = input[indices]
         hidden_states = mlp._modifier_forward(hidden_states)
         output.index_add_(0, indices, hidden_states.to(input.dtype))
-    return mlps[0].layer(input) + output
-
-
-@dataclass
-class PEERConfig(ModifierConfig):
-    n_heads: int = 8
-    moe_num_experts: int = 100
-    emb_dim: int = 128
-    down_proj_layer: str = "fc1"
-    up_proj_layer: str = "fc2"
-
-
-@Modifier.register("peer", config_cls=PEERConfig)
-class PEERModifier(Modifier):
-    """
-    Peer modifier basically does nothing, the job is done in the container.
-    """
-
-    def __init__(
-        self,
-        config: PEERConfig,
-        **kwargs,
-    ):
-        super().__init__()
-        self.config = config
-
-    def __len__(self):
-        return self.config.moe_num_experts
+    return mlps[0].layer(input) + output
\ No newline at end of file

From 2ff6d899aac94dc98393fd944d706b47cf8d5c8b Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 18 Oct 2024 11:55:20 -0400
Subject: [PATCH 2/6] make sure peer can be stored to library

---
 mttl/models/lightning/expert_module.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index 5d2918588..dc274ded8 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,6 +24,14 @@
 
 
 class LightningTrainingMixin:
+    
+    @property
+    def experts_names(self):
+        return self.model.experts_names
+    
+    def get_expert_instance(self, name):
+        return self.model.get_expert_instance(name)
+    
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")
@@ -215,7 +223,7 @@ def load_from_checkpoint(
         return model
 
 
-class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
+class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
     def __init__(
         self,
         model_object: PreTrainedModel = None,
@@ -269,7 +277,7 @@ def load_from_checkpoint(
         )
         model.load_state_dict(ckpt["state_dict"], strict=False)
         return model
-
+    
     def training_step(self, batch, _):
         output, context = self.forward(**batch, return_context=True)
         loss = output.loss

From 8abb4de4cf47cb001ae5a2cfb95db3fcaa5a0a8c Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 18 Oct 2024 11:57:29 -0400
Subject: [PATCH 3/6] black

---
 mttl/models/modifiers/mlp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 66ac3d1eb..2ce59fef3 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,4 +68,4 @@ def parallel_linear_forward(
         hidden_states = input[indices]
         hidden_states = mlp._modifier_forward(hidden_states)
         output.index_add_(0, indices, hidden_states.to(input.dtype))
-    return mlps[0].layer(input) + output
\ No newline at end of file
+    return mlps[0].layer(input) + output

From 9082f0a1cdcfb3575650f89a3a03d49c2547e685 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 18 Oct 2024 12:01:46 -0400
Subject: [PATCH 4/6] black

---
 mttl/models/containers/base.py         | 2 +-
 mttl/models/lightning/expert_module.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 6e7915742..94aaeaddb 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -40,7 +40,7 @@ def __init__(self, config, layer, selector=None):
     @property
     def num_experts(self):
         return len(self.experts)
-    
+
     @property
     def default_expert_name(self):
         return self._default_expert_name
diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index dc274ded8..6d5da6b63 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,14 +24,14 @@
 
 
 class LightningTrainingMixin:
-    
+
     @property
     def experts_names(self):
         return self.model.experts_names
-    
+
     def get_expert_instance(self, name):
         return self.model.get_expert_instance(name)
-    
+
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")
@@ -277,7 +277,7 @@ def load_from_checkpoint(
         )
         model.load_state_dict(ckpt["state_dict"], strict=False)
         return model
-    
+
     def training_step(self, batch, _):
         output, context = self.forward(**batch, return_context=True)
         loss = output.loss

From 8937c66da7ec667981c2579c3b7152bbccbac2bc Mon Sep 17 00:00:00 2001
From: oleksost
Date: Fri, 18 Oct 2024 13:09:21 -0400
Subject: [PATCH 5/6] save and load peer expert

---
 mttl/models/containers/peer_container.py |  4 +---
 projects/modular_llm/eval_library.py     | 12 ++++++++++--
 projects/modular_llm/train_experts.py    |  4 +++-
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index b945eca0e..e2fb29418 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -115,15 +115,13 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
         """
-        'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object that is passed here
+        'initialize_experts' is called from here
         """
         expert_config: PEERConfig = expert.expert_config
         if self._num_experts == expert_config.moe_num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
         self.expert_infos[expert.name] = expert.expert_info
-        if expert.expert_weights:
-            self.load_state_dict(expert.expert_weights)
         self.expert_name = expert.name
 
     def __getitem__(self, name):
diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py
index 8ad43c300..d5152ad76 100644
--- a/projects/modular_llm/eval_library.py
+++ b/projects/modular_llm/eval_library.py
@@ -3,9 +3,9 @@
 from copy import deepcopy
 
 import torch
-import wandb
 from pytorch_lightning import seed_everything
 
+import wandb
 from mttl.arguments import EvaluationConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
     WeightedLinearMergeConfig,
 )
 from mttl.models.lightning.callbacks import LossCallback
-from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
+from mttl.models.lightning.expert_module import (
+    ExpertModule,
+    MoEModule,
+    MultiExpertModule,
+)
 from mttl.models.modifiers.lora import LoRAConfig
 from mttl.utils import remote_login
 
@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
         module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
         module.add_expert_instance(expert, is_default=True)
 
+    elif args.merge_or_route in ["peer"]:
+        module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
+        module.model.model.load_state_dict(an_expert.expert_weights, strict=False)
+
     elif args.merge_or_route == "uniform_lora_after_op":
         # Here we merge the LoRA experts after the outer product we cannot really do it
         # with the lib transform, cause this would require storing large matrices in memory
diff --git a/projects/modular_llm/train_experts.py b/projects/modular_llm/train_experts.py
index bb73ca7fe..025e6e493 100644
--- a/projects/modular_llm/train_experts.py
+++ b/projects/modular_llm/train_experts.py
@@ -7,6 +7,7 @@
 from mttl.arguments import Args, ExpertConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.logging import logger, setup_logging
+from mttl.models.library.expert import Expert
 from mttl.models.library.expert_library import ExpertLibrary
 from mttl.models.lightning.callbacks import (
     DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
         if isinstance(module, MoEModule):
             with expert_library.batched_commit():
expert_library.batched_commit(): for expert_name in module.experts_names: - expert = module.get_expert_instance(expert_name) + expert: Expert = module.get_expert_instance(expert_name) + expert.expert_info.training_config = args expert_library.add_expert(expert, expert_name) elif isinstance(module, ExpertModule): expert = module.as_expert() From 002cb280eac19849ba6514f63a2207e2567d6756 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 18 Oct 2024 13:11:08 -0400 Subject: [PATCH 6/6] comment --- mttl/models/containers/peer_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py index e2fb29418..6be1ae7df 100644 --- a/mttl/models/containers/peer_container.py +++ b/mttl/models/containers/peer_container.py @@ -115,7 +115,7 @@ def add_expert(self, expert: Expert, **kwargs) -> None: def on_add_expert(self, expert: Expert, **kwargs) -> None: """ - 'initialize_experts' is called from here + 'initialize_experts' is called from here in order not to break logic in expert model """ expert_config: PEERConfig = expert.expert_config if self._num_experts == expert_config.moe__num_experts: