Refactor PEER #123

Open · wants to merge 6 commits into base: main
5 changes: 5 additions & 0 deletions mttl/models/containers/base.py
@@ -35,6 +35,11 @@ def __init__(self, config, layer, selector=None):
self.selector = selector or TaskNameSelector()
self._default_expert_name = None
self.expert_infos = {}
self.experts = nn.ModuleDict({})
Review comment (Member): not all containers have "experts"; see my last PR on making LoRA faster.


@property
def num_experts(self):
Review comment (Member): let's use len(self) here.

return len(self.experts)

@property
def default_expert_name(self):
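Taken together, the two comments above suggest dropping the shared `experts` dict from the base class and letting `num_experts` defer to each container's `__len__`. A minimal sketch of that shape (an illustration of the suggestion, not code from this PR; the selector import path is assumed):

```python
from torch import nn

from mttl.models.containers.selectors.base import TaskNameSelector  # assumed import path


class ExpertContainer(nn.Module):
    def __init__(self, config, layer, selector=None):
        super().__init__()
        self.config = config
        self.layer = layer
        self.selector = selector or TaskNameSelector()
        self._default_expert_name = None
        self.expert_infos = {}
        # no shared `experts` ModuleDict here: containers that need one create it themselves

    @property
    def num_experts(self):
        # delegate to each subclass's own __len__, whatever its storage scheme is
        return len(self)

    def __len__(self):
        raise NotImplementedError
```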
54 changes: 42 additions & 12 deletions mttl/models/containers/peer_container.py
@@ -1,3 +1,5 @@
from dataclasses import dataclass

import torch
from torch import nn

@@ -7,14 +9,26 @@
MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
)
from mttl.models.library.expert import Expert
from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
from mttl.models.modifiers.base import Modifier, ModifierConfig

# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier

# different architectures name these layers differently
DOWN_NAMES = ["fc1", "c_fc"]
UP_NAMES = ["fc2", "c_proj"]


class PEERMLPContainer(ExpertContainer):
@dataclass
class PEERConfig(ModifierConfig):
n_heads: int = 8
moe__num_experts: int = 100
Review comment (Member): there are two underscores here (moe__num_experts); this looks like a typo for moe_num_experts.

emb_dim: int = 128
down_proj_layer: str = "fc1"
up_proj_layer: str = "fc2"


@Modifier.register("peer", config_cls=PEERConfig)
class PEERMLPContainer(ExpertContainer, Modifier):
"""
PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)

@@ -33,7 +47,7 @@ def __init__(
**kwargs,
):
super().__init__(config, module)
self.num_experts = 0
self._num_experts = 0
down_names = DOWN_NAMES + [
config.down_proj_layer
] # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@ def __init__(self,
self.dtype = next(self.layer.parameters()).dtype

self.layer = nn.Identity()
self.expert_name = None
self.layer.in_features = self.input_dim
self.experts = PEERModifier(config)

@property
def num_experts(self):
return self._num_experts

def initialize_experts(self, expert_config: PEERConfig) -> None:
self.num_experts = expert_config.moe_num_experts
self._num_experts = expert_config.moe__num_experts
assert (
self.num_experts**0.5
self._num_experts**0.5
).is_integer(), "Number of experts must be a square number"
self.peer_weight_down_embed = nn.Embedding(
num_embeddings=self.num_experts,
num_embeddings=self._num_experts,
embedding_dim=self.input_dim,
dtype=self.dtype,
)
self.peer_weight_up_embed = nn.Embedding(
num_embeddings=self.num_experts,
num_embeddings=self._num_experts,
embedding_dim=self.output_dim,
dtype=self.dtype,
)
@@ -96,10 +114,22 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
return self.on_add_expert(expert, **kwargs)

def on_add_expert(self, expert: Expert, **kwargs) -> None:
"""
'initialize_experts' is called from here so as not to break the logic in the expert model
"""
expert_config: PEERConfig = expert.expert_config
if self.num_experts == expert_config.moe_num_experts:
if self._num_experts == expert_config.moe__num_experts:
raise ContainerFullException()
self.initialize_experts(expert_config)

def __getitem__(self, key):
pass
self.expert_infos[expert.name] = expert.expert_info
self.expert_name = expert.name

def __getitem__(self, name):
if name != self.expert_name:
raise ValueError(
f"Expert with name {name} does not exist in this container."
)
return self

def __len__(self):
return self._num_experts
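For context on what the container's two embedding tables are used for, here is a rough sketch of the PEER mixture step described in the paper, with tables in the same role as `peer_weight_down_embed` and `peer_weight_up_embed`. The selector output shapes are assumptions, and this is not the container's actual forward:

```python
import torch
import torch.nn.functional as F
from torch import nn


def peer_mixture(x, indices, scores, down_embed: nn.Embedding, up_embed: nn.Embedding):
    """Sketch of the PEER expert mixture given pre-selected experts.

    x:       (batch, seq, d_in) hidden states
    indices: (batch, seq, heads, k) expert ids from the product-key selector
    scores:  (batch, seq, heads, k) matching router scores
    """
    w_down = down_embed(indices)  # (b, s, h, k, d_in), one row per singleton expert
    w_up = up_embed(indices)      # (b, s, h, k, d_out)
    # each expert is a single hidden neuron: one scalar activation per expert
    act = torch.einsum("bsd,bshkd->bshk", x, w_down)
    act = F.gelu(act) * F.softmax(scores, dim=-1)
    # weighted sum of the selected experts' up-projection rows, summed over heads
    return torch.einsum("bshk,bshke->bse", act, w_up)
```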
2 changes: 1 addition & 1 deletion mttl/models/expert_model.py
@@ -163,7 +163,7 @@ def experts_containers(self) -> List[ExpertContainer]:
containers = []
for _, module in self.model.named_modules():
for _, child in dict(module.named_children()).items():
if isinstance(child, ExpertContainer) and len(child.experts) > 0:
if isinstance(child, ExpertContainer) and child.num_experts > 0:
containers.append(child)
return containers

10 changes: 9 additions & 1 deletion mttl/models/lightning/expert_module.py
@@ -24,6 +24,14 @@


class LightningTrainingMixin:

@property
def experts_names(self):
return self.model.experts_names

def get_expert_instance(self, name):
Review comment (Member): this is only available for MultiExpert models; I'm not sure we should add this dependency here.

return self.model.get_expert_instance(name)

@property
def _log_pref(self):
return getattr(self.hparams, "logging_prefix", "")
@@ -215,7 +223,7 @@ def load_from_checkpoint(
return model


class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
def __init__(
self,
model_object: PreTrainedModel = None,
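The reordered bases for MoEModule change Python's method resolution order so that LightningTrainingMixin is consulted before LightningEfficientCheckpoint when both define the same attribute. A generic illustration of the effect (toy classes, not the project's):

```python
class Checkpoint:
    def save(self):
        return "checkpoint save"


class TrainingMixin:
    def save(self):
        return "mixin save"


class Before(Checkpoint, TrainingMixin):
    pass


class After(TrainingMixin, Checkpoint):
    pass


print(Before().save())  # "checkpoint save": the checkpoint class wins the MRO
print(After().save())   # "mixin save": the mixin now takes precedence
```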
27 changes: 0 additions & 27 deletions mttl/models/modifiers/mlp.py
@@ -69,30 +69,3 @@ def parallel_linear_forward(
hidden_states = mlp._modifier_forward(hidden_states)
output.index_add_(0, indices, hidden_states.to(input.dtype))
return mlps[0].layer(input) + output


@dataclass
class PEERConfig(ModifierConfig):
n_heads: int = 8
moe_num_experts: int = 100
emb_dim: int = 128
down_proj_layer: str = "fc1"
up_proj_layer: str = "fc2"


@Modifier.register("peer", config_cls=PEERConfig)
class PEERModifier(Modifier):
"""
Peer modifier basically does nothing, the job is done in the container.
"""

def __init__(
self,
config: PEERConfig,
**kwargs,
):
super().__init__()
self.config = config

def __len__(self):
return self.config.moe_num_experts
12 changes: 10 additions & 2 deletions projects/modular_llm/eval_library.py
@@ -3,9 +3,9 @@
from copy import deepcopy

import torch
import wandb
from pytorch_lightning import seed_everything

import wandb
from mttl.arguments import EvaluationConfig
from mttl.datamodule.base import get_datamodule
from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
WeightedLinearMergeConfig,
)
from mttl.models.lightning.callbacks import LossCallback
from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
from mttl.models.lightning.expert_module import (
ExpertModule,
MoEModule,
MultiExpertModule,
)
from mttl.models.modifiers.lora import LoRAConfig
from mttl.utils import remote_login

@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
module.add_expert_instance(expert, is_default=True)

elif args.merge_or_route in ["peer"]:
Review comment (Member): can you explain what you are trying to do here?

module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
module.model.model.load_state_dict(an_expert.expert_weights, strict=False)

elif args.merge_or_route == "uniform_lora_after_op":
# Here we merge the LoRA experts after the outer product we cannot really do it
# with the lib transform, cause this would require storing large matrices in memory
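For the `strict=False` load in the new `peer` branch above, the PyTorch semantics are worth spelling out: only keys present in both the checkpoint and the module are copied, and the mismatches are returned rather than raised. A self-contained example (toy model, not the project's):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
partial = {"0.weight": torch.zeros(8, 4)}  # e.g. only the expert weights

result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # parameters the checkpoint did not provide
print(result.unexpected_keys)  # extra keys in the checkpoint, if any
```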
4 changes: 3 additions & 1 deletion projects/modular_llm/train_experts.py
@@ -7,6 +7,7 @@
from mttl.arguments import Args, ExpertConfig
from mttl.datamodule.base import get_datamodule
from mttl.logging import logger, setup_logging
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import (
DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
if isinstance(module, MoEModule):
with expert_library.batched_commit():
for expert_name in module.experts_names:
expert = module.get_expert_instance(expert_name)
expert: Expert = module.get_expert_instance(expert_name)
expert.expert_info.training_config = args
Review comment (Member): remove this; it's not a good idea to store a complex object in training_config, and it will be transformed to a Dict in the next PR :)

expert_library.add_expert(expert, expert_name)
elif isinstance(module, ExpertModule):
expert = module.as_expert()
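One way to read the comment above: serialize the run arguments into a plain dict before attaching them to expert_info, instead of storing the Args object itself. A hedged sketch of that idea (the helper name is hypothetical and not part of this PR):

```python
from dataclasses import asdict, is_dataclass


def to_plain_config(args) -> dict:
    """Flatten a config/arguments object into a JSON-friendly dict."""
    if is_dataclass(args) and not isinstance(args, type):
        return asdict(args)
    return {k: v for k, v in vars(args).items() if not k.startswith("_")}


# expert.expert_info.training_config = to_plain_config(args)  # instead of the raw object
```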