Someone's Guide to SMT

This is labeled as a "guide," but it is intended more as an introduction to SMT, a report of some general findings and thoughts, and a discussion of potential shortcomings of my implementation of SMT. All testing, except where noted, was performed on a single A100 80GB.

What is SMT?

Sparse Matrix Tuning (SMT) is a fine-tuning method described in the paper Sparse Matrix in Large Language Model Fine-tuning. Essentially, SMT performs a number of warm-up steps against the target dataset, then uses the gradient information collected during those steps to find the sub-matrices most relevant to the domain-specific task. It does this by splitting the target modules into 256x256 (or some other size) blocks and then selecting the X blocks with the highest averaged gradients.
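To make the selection step concrete, here is a minimal sketch of gradient-aware block selection as I understand it from the paper. This is my own illustration, not code from the paper or from SMFT; `grads` is assumed to be a dict mapping each target weight name to the gradient accumulated during the warm-up steps.

import torch

def select_blocks(grads, block_size=256, num_blocks_to_keep=64):
    # Score every block of every target weight by its mean absolute gradient.
    scores = []  # (score, weight_name, block_row, block_col)
    for name, g in grads.items():
        rows, cols = g.shape
        for r in range(0, rows, block_size):
            for c in range(0, cols, block_size):
                block = g[r:r + block_size, c:c + block_size]
                scores.append((block.abs().mean().item(), name, r, c))
    # Keep only the globally highest-scoring blocks; these are the sub-matrices
    # that remain trainable while everything else stays frozen.
    scores.sort(key=lambda s: s[0], reverse=True)
    return scores[:num_blocks_to_keep]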

SMT has the potential to be superior to LoRA. According to the paper, SMT outperforms LoRA when using the same number of trainable parameters. It's memory efficient because gradients are only computed for the selected sub-matrices. Additionally, the paper claims that SMT overcomes the "performance plateau" of LoRA, meaning that increasing the number of trainable parameters improves SMT's performance more than it does LoRA's.

SMFT

Having read the SMT paper, I have done my best to adapt it into a usable PEFT implementation as well as a stand-alone script. I was working on this implementation before realizing that the official code had been published not long ago. Still, that implementation is not as easy to use as mine and is in some ways less complete, which is understandable, since it appears to be undergoing review for ICLR 2025. I mention this to note that my implementation diverges from the official one in numerous ways, some out of necessity and some with the intention of simplifying the process. For example, my implementation collects gradients through the model rather than injecting the collection into the Trainer, as the official implementation does. Any deficiencies in my implementation therefore shouldn't be construed as deficiencies in SMT, the fine-tuning method, itself.

Findings

This section is incomplete and will be updated as more testing is done to verify the claims made below.

Recall the three claims of SMT: better performance, greater memory efficiency, and a higher performance plateau than LoRA. Having done more than a dozen runs comparing SMT to LoRA and FFT, I believe these claims to be true, though perhaps exaggerated in the paper.

Memory Consumption

Model                                      GPU Memory Allocated (%)
llama-3.2-1B-fft                           96.28
llama-3.2-1B-lora-R64-A128-D0.01-Qt        94.20
llama-3.2-1B-lora-R64-A16-D0.01-Qt         94.20
llama-3.2-1B-lora-R64-A16-D0.01            78.59
llama-3.2-1B-lora-R64-A128-D0.01           78.58
llama-3.2-1B-smt-SR0.50-WUP100-BS256-GW    74.46
llama-3.2-1B-smt-SR0.30-WUP100-BS256-GW    74.04
llama-3.2-1B-smt-SR0.10-WUP100-BS256-GW    73.55
llama-3.2-1B-smt-SR0.05-WUP100-BS256-GW    73.25
llama-3.2-1B-smt-SR0.01-WUP100-BS256-GW    72.79
llama-3.2-1B-smt-SR0.10-WUP100-BS256-MW    72.14
llama-3.2-1B-smt-SR0.05-WUP100-BS256-MW    72.09
llama-3.2-1B-smt-SR0.01-WUP100-BS256-MW    71.28

SMT appears to reduce the percentage of GPU memory allocated compared to LoRA and FFT. All models were trained with the same hyperparameters, with the exception of learning rate (more on that later). Counterintuitively, the LoRAs trained on the quantized model (denoted Qt) consumed more memory; I'm not sure why. The SMT models trained with magnitude-based selection (denoted MW) were also quantized and, as expected, consumed less memory. Increasing the sparsity ratio has a negligible effect on the percentage of GPU memory allocated relative to the increase in trainable parameter count.
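For reference, numbers of this kind can be obtained from torch's CUDA statistics. This is a minimal sketch of one way to measure it; the exact metric used for the table above (e.g. a logging framework's system metric) may differ.

import torch

def gpu_memory_allocated_percent(device=0):
    # Percentage of the device's total memory currently allocated by tensors.
    total = torch.cuda.get_device_properties(device).total_memory
    return torch.cuda.memory_allocated(device) / total * 100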

Learning Rate

SMT seems to benefit from higher learning rates than usual, depending on a confluence of factors such as the model's total parameter count and the sparsity ratio. For sparsity ratios at or around 0.01 to 0.05, I recommend a learning rate of at least 2e-3. At 0.10 or higher, 2e-4 is acceptable.

[Figure: loss curves for the same model, same hyperparameters, same sparsity ratio (0.03), with different learning rates. The red curve uses 2e-3, the blue curve 2e-4.]
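As a concrete starting point, the recommendation above looks roughly like this in Hugging Face TrainingArguments. All values other than the learning rate are hypothetical placeholders, not the settings used for my runs.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama-3.2-1B-smt",
    learning_rate=2e-3,              # drop to ~2e-4 for sparsity ratios >= 0.10
    per_device_train_batch_size=8,   # placeholder value
    num_train_epochs=1,              # placeholder value
    bf16=True,
    logging_steps=10,
)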

The Problem with Quantization

The two methods of sparse sub-matrix selection described in the SMT paper are gradient-aware selection (GW) and activation-aware selection (AW). Both require warm-up steps against the target dataset. The SMFT implementation introduces an additional selection method, magnitude-based selection (MW), which doesn't require warm-up steps and is therefore less computationally expensive than the other two. Gradient-aware selection is by far the best and consistently outperforms the other two.

The "problem" with quantization is that quantized tensors aren't optimized for training and can't require gradients, meaning that Gradient-aware selection isn't an option when training models in k-bit quantization. This leaves you with Activation-aware selection and Magnitude-based selection. Understand that using quantization and either selection method will lead to worse performance than what's achievable using Gradient-aware selection.

Evaluations

Model                 PEFT method  Params (%)  Average  IFEval  BBH   MATH  GPQA  MUSR  MMLU-PRO
LLaMA-3.2-1B-SFT      LoRA         0.55        4.90     18.87   4.94  0.23  1.68  2.08  1.60
LLaMA-3.2-1B-SFT-AW   SMT (AW)     0.49        5.43     22.99   3.46  0.30  1.45  3.07  1.28
LLaMA-3.2-1B-SFT-GW   SMT (GW)     0.49        5.36     18.63   3.37  0.23  0.78  7.67  1.47
LLaMA-3.2-1B-SFT-MW   SMT (MW)     0.49        4.11     17.51   3.02  0.15  1.23  1.43  1.34

The finetune script used to train these models can be found here. The evaluations were performed using EleutherAI's LM Evaluation Harness (lm_eval). The command used is as follows:

lm_eval --model hf \
  --model_args pretrained=<your_model>,load_in_4bit=True,dtype="bfloat16",attn_implementation="flash_attention_2" \
  --tasks leaderboard \
  --device cuda:0 \
  --batch_size auto:4 \
  --system_instruction "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." \
  --apply_chat_template \
  --output_path "results" \
  --fewshot_as_multiturn

Installation and Use

You can install the SMFT package as follows:

pip install git+https://github.com/HeroMines/SMFT.git

or use

git clone https://github.com/HeroMines/SMFT.git
cd SMFT
python setup.py develop  # or "pip install .", but this way is recommended

Creating an SMT model

You can prepare a model for SMT as follows:

from transformers import AutoModelForCausalLM
from peft import get_peft_model, SMTConfig

model_name_or_path = "meta-llama/Llama-2-7b-hf"

# Select 5% of the parameters in 256x256 blocks using magnitude-based selection.
peft_config = SMTConfig(
    peft_type="SMT",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "k_proj"],
    sparsity_ratio=0.05,
    block_size=256,
    selection_method="MW",
)

model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
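If the SMFT fork keeps the standard PEFT interface, you should also be able to call model.print_trainable_parameters() at this point to verify that the trainable parameter count matches the chosen sparsity ratio.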

The standalone version can be used as well, for example:

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import SMT

model_name_or_path = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers have no pad token by default

dataset = load_dataset("<your_dataset>", split="train")

# Tokenize the dataset
dataset = dataset.map(
    lambda x: tokenizer(x["text"], truncation=True, padding=True, max_length=1024),
    batched=True,
)
dataset = dataset.with_format("torch")
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Initialize SMT; the warm-up steps run against the dataloader
smt = SMT(model, dataloader, sparsity_ratio=0.01, warmup_steps=100)

# You can print the actual number of trainable parameters in SMT
smt.print_trainable_params()

Colab Notebook

For your convenience, a Colab notebook has been created and can be found here.

Citation

@misc{he2024sparsematrixlargelanguage,
      title={Sparse Matrix in Large Language Model Fine-tuning}, 
      author={Haoze He and Juncheng Billy Li and Xuan Jiang and Heather Miller},
      year={2024},
      eprint={2405.15525},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.15525}, 
}