はじめに
こんにちは!
株式会社ユーザベース スピーダ事業の飯田です。 普段はベクトル検索用の埋め込みモデルの学習及びデプロイを行っています。
大きなモデルを低リソースで学習する手法としてQLoRAとFSDPが広く用いられています。 この記事では、SentenceTransformersにおいてFSDP+QLoRAを行うときの注意点をご紹介します! 特に、開発データをつかって、学習過程をモニタリングするケースについて紹介します。
前置き
FSDP (Full Sharded Data Parallel) は、パラメータを各GPUで分割して保持し、順伝播・逆伝播時に必要なパラメータを一時的に集約することによって、より少ないメモリ量で深層学習のモデルを学習する手法です。 深層学習において、GPU数に応じて処理できるデータ量をスケールさせる学習手法として、DDP(Distributed Data Parallel)がありますが、これは各GPUがモデルの完全なコピーを持つので、大きなメモリ量が必要でした。 FSDPでは、より無駄なくGPUメモリを使用することができるため、大規模なモデルを学習することが可能になります。 詳細は公式チュートリアルなどを参照してください。
QLoRA
本題に入る前に、気になる方向けにQLoRAについて、簡単に触れます。まずLoRAとは、学習対象のモデルにモデルよりも小さいパラメータ数のネットワーク(アダプタと呼ばれる)を結合し、その部分のみを学習させることで、学習に必要なGPUのメモリ容量等を大幅に減らす手法です。 LoRAでは元々のモデルの部分、つまりアダプタ部分以外の部分は学習で更新されません。そこで、その部分を量子化 (Quantization) して、さらに必要なGPUメモリを減らすことができます。この量子化とLoRAを合わせたんものがQLoRAです。LoRAの詳細はこちら、Quantizationの詳細はこちらを参照してください。
本題
それでは、SentenceTransformersの話に移ります。深層学習では、学習が順調に進んでいるか、モデルが過学習などを起こしていないかを、学習データの一部をサンプルして得た開発データを用いて、一定ステップごとに確認します。
SentenceTransformersでは、huggingfaceのaccelerateを用いることで、FSDP学習ができるのですが、残念ながら開発データによる評価ができませんでした。
また、公式ドキュメントが、チェックポイントをFULL_STATE_DICT
で保存するよう推奨しています。これは、各GPUに分散されたパラメータを再集結して、通常のモデルの様に保存するモードです。
しかし、我々の環境ではFULL_STATE_DICT
とするとチェックポイント保存時にハングしてしまい、学習がすすめられませんでした。
そのため、パラメータを分散したまま保存する、SHARDED_STATE_DICT
を使用する必要がありました。
以下をこれらに対する対処方法です。SentenceTransformerのv4.0.2以降で可能です。
Evaluatorの書き換え
公式にも記載されている通り、FSDPで学習を行うとEvaluatorが動作しません。 この要因の一つは、SentenceTransformerのencodeメソッドを使っているためです。例えば、Evaluatorの一つである、SentenceEvaluatorでは以下の箇所が該当します。
これでエラーになる理由ですが、SentenceTransformerのencodeメソッドでは、前処理をしたのちself.forwardでテキストをベクトルにして後処理をしています。
FSDPでは、外からモデルをWrapしてforwardメソッドを上書いています。これにより、分散したパラメータを計算時に集約する計算を外から隠しています。 しかし、内側のSentenceTransformerでself.forwardを呼び出すと、FSDPのforwardメソッドが呼ばれず、もともとのSentenceTransformerのforwardが呼び出されます。 そのため、分散したパラメータを計算時に集約することができずに、エラーで落ちます。(エラーは、embeddingのshapeが合わないと行った趣旨で出てきます)
これに対処するために、Evaluatorは自作する必要があります。Evaluatorはやりたい処理をcallメソッドに書き込めば良いです。
その際、modelが渡されます
*1 。このモデルはFSDPでWrapされたモデルなので、tokenizeなどのテキストの前処理を自分で行い、model.forward(preprocessed_texts)
としましょう。
チェックポイントの読み込み
FSDPのチェックポイント保存はFULL_STATE_DICT
とSHARDED_STATE_DICT
が代表的なものとしてあります。
FULL_STATE_DICTを選択すると、チェックポイントの保存のたびに分散された重みを集約する必要があります。そのため遅くなったり、最悪の場合Out-of-Memory(OOM)で落ちます。
さらに、われわれは保存時にハングしてしまったため、SHARDED_STATE_DICT
を用いる必要がありました。
また、開発データで評価できたのであれば、最後に最も良かったモデル(チェックポイント)を読み込みたいですよね!? しかし、SentenceTransformerで実装されている_load_best_modelメソッドでは、SHARDED_STATE_DICTで保存されたチェックポイントの読み込みができません。 そのため、以下の様にtorch.distributed.checkpoint.loadを用います。
def init_qlora_model(model_name, instruction, peft_config, bnb_config): prompt = f"<instruct>{instruction}\n<query>" model = SentenceTransformer( model_name, prompts={ "anchor": prompt, }, tokenizer_kwargs={"padding_side": "right"}, model_kwargs={"quantization_config": bnb_config, "torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2"}, ) model[0].auto_model = get_peft_model( model[0].auto_model, peft_config=peft_config, ) return model def fsdp_load_model_from_sharded_checkpint(model, ckp_path): model_state_dict = model.state_dict() torch.distributed.checkpoint.load( {"model": model_state_dict}, checkpoint_id=ckp_path ) model.load_state_dict(model_state_dict, strict=False) return model ... model = init_qlora_model(model, instruct, peft_config, bnb_config) ... trainer.train(model) ... if os.getenv("LOCAL_RANK", "0") == "0": best_model = init_qlora_model( model_name, instruction, max_seq_length, peft_config, bnb_config ) check_point_model_dir = os.path.join( trainer.state.best_model_checkpoint, "pytorch_model_fsdp_0" ) best_model = fsdp_load_model_from_sharded_checkpint( best_model, check_point_model_dir )
ポイントは以下の通りです。
- fsdp_load_model_from_sharded_checkpoint関数内の、torch.distributed.checkpoint.load実行時に、
{"model": model_state_dict}
とする点です。チェックポイントはFSDPのモデルを保存しています。一方、読み込みに使うモデルは通常のモデルです。 そのため、state_dictのkeyが異なります。読み込みに使うモデルのkeyの冒頭にmodel
を付与することで、state_dictのkeyを同じにすることが可能です。どうやら、FSDPのcheckpoint保存時に、modelというkeyが冒頭に付与されるようです。 - fsdp_load_model_from_sharded_checkpint関数内の、model.load_state_dict実行時に、strict=Falseにしましょう。なぜか量子化部分のkeyが合わないようです。量子化部分は学習していないので、読み込めなくても問題ありません。
- 途中評価を行っていれば、trainerのargsでload_best_model=Trueと指定なくても、best_modelの保存先をtrainerが保持しています。trainer.state.best_model_checkpointがそれに該当するため、これを用います。
- 上記のコードは、学習後のモデルの読み込みを1プロセスだけで行うため、
os.getenv("LOCAL_RANK", "0") == "0"
としています。FSDPでは複数のプロセスが同時に走っているため、これによりディスクへの読み込み負荷を軽減しています。
Adapterだけの読み書きをしたい場合
本節は余談です。先程の方法では、モデルの全体が保存されるため、無駄が生じています。しかし、FSDPのモデルにおいて、Lora部分だけを保存する方法がSentenceTransformerでは用意されていません。そのため、SentenceTransformerTrainerにおいて、以下のメソッドをオーバーライドする必要があります。
def _save_optimizer_and_scheduler(self, output_dir): if self.is_fsdp_enabled: self.model = self.model[0].auto_model try: super()._save_optimizer_and_scheduler(output_dir) finally: if self.is_fsdp_enabled: self.model = self.model_wrapped
この操作は、FSDPを使用している場合に、SentenceTransformerから中身のTransformerの部分を取り出して元に戻すという操作をしています。 中身のモデルにすることで、TransformerのTrainerがLora用の保存方法を適用し、Adapterのみを保存することが可能です(最終的にはaccelerateで処理されます)。 TransformersのTrainerでは、fsdp利用時に、self.model_wrappedにFSDPでwrapしたモデルを退避しています。そのため、保存がおわったら、self.modelに代入し、self.modelを元の状態に戻しています。
保存されるのが、Adapterのみになったため、読み込みも以下のように変える必要があります。
from peft import set_peft_model_state_dict, get_peft_model_state_dict def fsdp_load_model_from_sharded_checkpint(model, ckp_path): model_state_dict = get_peft_model_state_dict(model, adapter_name=model.active_adapter) torch.distributed.checkpoint.load({"model": model_state_dict}, checkpoint_id=ckp_path) set_peft_model_state_dict(model, model_state_dict, adapter_name=model.active_adapter) return model model = init_lora_model(model_name, instruction, max_seq_length, peft_config) peft_model = model[0].auto_model ckp_path = "/path/to/model" peft_model = fsdp_load_model_from_sharded_checkpint(peft_model, ckp_path)
変更点としては、Transformer部分のAdapterのみ保存されたため、keyが合うように読み込みするときも中身のモデル (peft_modelの部分) で行う必要があります。 もう一つは、state_dictの取り出しと設定に、peft用のものを用いる必要があります。
上記を鑑みると、CustomTrainerを用意する必要があり、若干エンジニアリング的なポータビリティが下がるため、ディスク容量で困った場合などに検討ください。
まとめ
今回は、SentenceTransformersで、QLORA-FSDPを用いる際に必要な変更点などを紹介しました。 実は、もともとはFSDP自体が動作しなかったのですが、こちらは公式になおしてもらいました。 なお、今回の方法では、FSDP2には対応できません。また、CachedMultinegativeRankingLossの学習も未だ成功していません。こちらも、なんらかの知見が得られたら、また紹介したいとおもいます。
*1:SentenceTransformersのEvlauatorの大半はリンク先と同様のcallメソッドを持っています