Conversation

@marcusinthesky

We may want to train just a subset of our tokens while leaving the rest frozen. This can be useful when we have tokens for named entities, or special tokens for chat assistants. While it is possible to train only certain tokens by registering a custom gradient hook that zeroes out the gradients of the frozen token embeddings, this still requires storing the full embedding matrix in memory and on disk, which can take up ~30% of the total model size. A more efficient approach would be to wrap the embedding weight deltas for the trainable tokens in a sparse matrix, making merge, unmerge, and delta operations fast.
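
For illustration, here is a minimal PyTorch sketch contrasting the two approaches described above. This is not the code in this PR; the vocabulary size, token indices, and variable names are made up.

```python
# Illustrative sketch only, not the PR's implementation; sizes, indices, and
# names are invented.
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=32000, embedding_dim=768)
trainable_ids = torch.tensor([0, 1, 2])  # e.g. newly added special tokens

# Workaround described above: a gradient hook that zeroes out gradients of the
# frozen rows. The full (32000, 768) matrix still lives in memory and on disk.
def zero_frozen_grads(grad):
    mask = torch.zeros_like(grad)
    mask[trainable_ids] = 1.0
    return grad * mask

embedding.weight.register_hook(zero_frozen_grads)

# Idea proposed here: keep the base weight frozen and learn only an
# (n, embed_dim) delta for the trainable rows; merging is an indexed addition.
delta = nn.Parameter(torch.zeros(len(trainable_ids), embedding.embedding_dim))
with torch.no_grad():
    merged = embedding.weight.detach().clone()
    merged[trainable_ids] += delta  # only these rows ever change
```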

Please see #1462.

ToDo:

  • Tests
  • Documentation
  • Examples

@BenjaminBossan
Member

Thanks for this PR. I have just skimmed it so far. Could you give an example of how to use this adapter? How would we combine it with, say, LoRA?

If you have specific questions regarding tests and documentation, feel free to ask me.

@BenjaminBossan
Member

@marcusinthesky Do you still plan on working on this?

@marcusinthesky
Author

Yes, I do. It may take two weeks.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this Apr 26, 2024
@ds5423

ds5423 commented Jan 22, 2025

@marcusinthesky Thank you for making this PR! I was looking for exactly this: a way to train embeddings just for newly added tokens while freezing the rest. It would be super helpful if this could be integrated into peft.

@marcusinthesky
Author

@ds5423 thanks. I got fairly far, but did not have time for tests, docs, examples and save/load config.

I do think this is the optimal approach, and it would really benefit the community. I would love to hear how it works out for you.

@ds5423

ds5423 commented Feb 6, 2025

@marcusinthesky Hi! Right now I'm just using a gradient hook to freeze the rest of the embedding matrix. I want to see if this can help preserve the model's original behaviors.

@marcusinthesky
Author

@ds5423 I hope it worked out for you. @githubnemo is doing some cool stuff in #2376, which looks exciting.

githubnemo added a commit that referenced this pull request Feb 26, 2025
This change is based on the nifty addition of @marcusinthesky from #1541.

When adding tokens or fine-tuning the representation of specific tokens, we currently have little choice but to retrain the whole embedding matrix, which can be huge and adds to the memory footprint (in RAM but also on disk). This method creates a sparse matrix of shape (n, embed_dim), where n is the number of tokens to be customized, and only trains these few values.

This change introduces two ways of using it:

```python
from peft import TrainableTokensConfig, get_peft_model

# `model` is assumed to be a transformers model loaded beforehand.
peft_config = TrainableTokensConfig(target_modules=['embed_tokens'], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
```

and with LoRA

```python
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    target_modules='all-linear',
    trainable_token_indices={'embed_tokens': [0, 1, 2]},
)
peft_model = get_peft_model(model, peft_config)
```
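
As a brief, hypothetical follow-up to the snippets above (assuming `peft_model` comes from `get_peft_model` and the output path is arbitrary), only the adapter weights and the selected token rows end up trainable and saved:

```python
# Hypothetical usage sketch; `peft_model` is from the snippet above and the
# output directory is arbitrary.
peft_model.print_trainable_parameters()          # LoRA weights plus the selected token rows
peft_model.save_pretrained("./peft-checkpoint")  # saves the small adapter, not the full embedding matrix
```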

Adding this feature to adapters other than LoRA should be relatively easy, mostly adding the `trainable_token_indices` config option and some debugging.

To make this change it was necessary to modify the `modules_to_save` infrastructure, as combining this feature with LoRA works quite similarly. This refactoring entailed moving most of the basic functionality of `ModulesToSave` to the `AuxiliaryTrainingWrapper` class. It also changes the logic of how `modules_to_save` is loaded/saved from the state dict, so there could still be bugs here.

This implementation does not entail support for weight-tied layers yet. This will follow in a future change.

---

Notable commits in this squash:

* Use unload_and_optionally_merge_module protocol

With `AuxiliaryTrainingWrapper` as an abstraction it is probably a good idea to
support `unload_and_optionally_merge_module` as well.

Since the wrapper is more akin to a PEFT layer than a model, the name's semantics
are fine and it does basically the same job.

* Trainable tokens are also trained in certain adapters

Before, the assumption was that `modules_to_save` was the only thing trained
alongside an adapter's parameters. Now there are also the token_adapter delta
tokens via `NewTokensWrapper`.

* Remove old modules_to_save handling

This is now all handled via the `AuxiliaryTrainingWrapper`.

* Fix modules_to_save module overwriting

The state dict implementation of ModulesToSaveWrapper was incorrect in that
it did not include its own parameters, just the parameters it needs to overwrite
in the end. I.e., if layer `lin1` is wrapped by modules_to_save,
`lin1.{weight,bias}` is saved and overwritten but `lin1.modules_to_save.<adapter_name>.[...]`
is not saved.

* Introduce a load key map for aux. train wrapper

Before this change it was only possible to remove a key prefix from the wrapper's
state dict (e.g., `modules_to_save.default.weight` -> `weight`); now it is possible
to restore such a reduced key by mapping it back
(i.e., `weight` -> `modules_to_save.default.weight`).
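
A tiny illustrative sketch of such a mapping (the dict names are made up and are not the actual PEFT attributes):

```python
# Hypothetical illustration of the key mapping described above; the names are
# invented and do not reflect the real PEFT data structures.
save_key_map = {"modules_to_save.default.weight": "weight"}  # strip the prefix on save
load_key_map = {value: key for key, value in save_key_map.items()}  # restore it on load
```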

* Replace sparse matrix with dense + index_copy

This change is mostly because sparse matrices are not that beneficial in this case
(at least not from what we can see right now), and because they do not solve the problem
of having to change the new tokens in place to avoid outdated deltas when new token
vectors are initialized randomly after the deltas are loaded.
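
A rough sketch of what a dense-delta-plus-index_copy merge could look like in plain PyTorch (illustrative only; this is not the PEFT implementation, and shapes, indices, and names are invented):

```python
# Illustrative only; not the PEFT implementation.
import torch

vocab_size, embed_dim = 32000, 768
base_weight = torch.randn(vocab_size, embed_dim)  # stands in for the frozen embedding matrix

token_indices = torch.tensor([0, 1, 2])
# Dense (n, embed_dim) buffer holding the trained values for just these tokens.
trained_rows = torch.randn(len(token_indices), embed_dim)

# Merging copies the trained rows over the corresponding rows of the base
# weight in a single indexed operation.
merged = base_weight.index_copy(0, token_indices, trained_rows)
```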

* Make peft_config.layers_to_transform optional

Before this change the base tuner class was forcing this attribute
to be present on the config class even though the attribute is not
specified in the base config.

* Implement missing key logic in `_set_trainable`

Before this change, it was not checked whether the module targeted by `modules_to_save` or `trainable_token_indices`
existed (when used in conjunction with a PEFT method). Now an error message similar to the `inject_adapter`
error is raised when no module is found.

---------

Co-authored-by: Marcus Gawronsky <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025