はじめに
こんにちは!
この記事では、テキストをベクトルに変換(エンコード)にLLMを用いる際に有効なLLM2Vecという手法を紹介します。 合わせて、LLM2Vecにおける日本語ドメイン適応として、LLM2Vecの処理を日本語で行った場合とLLMの継続事前学習を日本語で行った場合について実験を行ったため、これを紹介します。
LLM2Vecとは
LLM2Vecは、"LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders"で提案された手法です。 Llamaなどで有名なLLMでは、テキストをベクトルにエンコードする際は、入力したテキストの末尾のベクトルを用いることが多いです。 これは、LLMがテキストの次の単語を予測することで事前学習を行っていることに関連があります。 人間は、一部のテキストからその続きの部分を予測・生成していく場合に、これまでのテキストを前提にします。 このとき、続きの部分の情報が事前に手に入ることは通常ありません。 LLMもこれと全く同様に学習されており、LLMがテキストの続きを生成していく能力は、入力されたテキストの次の単語を予測しその正誤を学習する言語モデルとしての学習によって会得されています。 また、LLMはテキストの各単語(実際はトークン)をベクトルに変換し、ベクトルを最後に単語に変換ことで、学習を行っています。 そのため、入力テキストの全情報が含まれる一番末尾のベクトルを用いることが有効と考えられます。
しかしながら、テキストをベクトルに変換するという用途を考えると、ベクトルにしたいテキスト全体の全情報を使うほうが有効と考えられます。 実際、LLMの基本要素であるTransformerを用いた言語モデルにおいては、その単語の近くの情報のほうが反映されやすい傾向にあると言われており、 末尾のベクトルだけでは、全情報の反映に十分とは言えません。また、そのために各単語のベクトル全部を平均したベクトルを用いるという方法も 考えられますが、この場合は、各単語のベクトルがすべて、テキスト全体を反映できていたほうがよりよくテキストを表現するベクトルを得られると考えられます。 このような事前学習も以前から取り組まれており、BERTなどのモデルは、テキストの一部の単語をマスクし、その単語を予測し、その正誤を学習することで作成されています。
とはいえ、改めて事前学習するのは、計算資源確保の観点からも非常に労力がかかります。そのため、LLMのように前から後ろを予測する単方向の言語モデルから、 BERTのように、入力テキスト全体を前後の双方向から参照する言語モデルへと変換する学習を省計算資源で行う手法がLLM2Vecとなります。
下図は、LLM2Vecの概要です。次節にて詳細を説明します。
LLM2Vecの詳細
近年の言語モデルの殆どがTransformerというアーキテクチャーを使っています。Transformerは自己注意機構 (Self-Attention)とフィードフォワードネットワークを組み合わせたものです。
単方向言語モデルと双方向言語モデルからへの変更は、Self-Attentionを用います。単方向言語モデルではSelf-Attentionにマスクをかけることによって、 先々の情報を使わないようにしていますが、このマスクを用いないことによって、双方向言語モデルに変更します。
しかし、単にSelf-Attentionの使い方を変えるだけでは、学習時と異なる状態になります。この差異を埋めるために、Masked Next Token Predicton(MNTP)というタスクで学習を行います。MNTPは、マスクしたトークンの一つ後ろで、マスクされたトークンを予測して、その正誤を学習していくタスクです。 LLM2Vecの上手い点はこのMNTPにあると私は考えています。単方向言語モデルの次のトークンを予測するという性質を保ちながら、双方向の情報を使うタスクにすることによって、単方向言語モデルからの差分を小さくしながら、双方向言語モデルへの変換を行うことを可能にしています。
LLM2Vecでは最後に既存手法であるSimCSEを用いることによって、テキスト全体のベクトル表現としてより洗練されたものにしています。
実際のコード
では実際にどのような実装になっているのかを、単方向言語モデルと双方向言語モデルのSelf-Attentionの変更方法について見ていきます。 まず、Attentionの変更方法についてです。Llamaを代表に見ていきます。
Llm2Vecのレポジトリでは、Llamaの双方向言語モデルはLlamaBiModelとして実装されています。大きな変更はAttentionが双方向として実装されている点とマスクの仕方です。 Attentionの双方向化はis_causalをFalseにする変更がなされています。
この変更は、実際にはFlashAttentionの中で使用されています。
もう一つのマスクの仕方は、_update_causal_maskメソッドのオーバーライドによってなされています。中心となるのは、以下の部分です。
対比として、単方向言語モデルのLlamaの実装で同様の処理を行っている部分を以下に掲載します。
LLM2Vecのコードでは、causal_maskがで初期化されています。attentionのマスクはattentionの重みと足し合わせて使用するため、となり、マスクをしないということになります。 一方、単方向言語モデルでは、データ型最小の値でtorch.fullによって初期化されています。最小値をと解釈すれば、であるため、マスクしてということになります。 また、torch.triuによって左下と対角成分をにしているため、右上の成分のみがマスクされることになります。Attentionは最終的には、として使用されるので、各単語は自分よりも後方にある単語の情報を混合出来なくなるため、先々の情報を使わない学習が達成されます。
なお通常のAttentionにおけるマスク適用の実装はこちらです。
余談1:上記コードを調べていたら4D maskがサポートされていることに気づきました。 上記コードの周辺attention_maskの次元が4の場合の処理が書かれていますが、4D maskを使う場合であると思います。こちらもそのうち調査できたらと思います。
余談2: causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
は、cacheしている部分まではmaskをしないように補正していると思われます。ただし、LLM2Vecのコードでは、初期化の値が0なので、causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
は何をしてもで、どのような意味があるのかは不明でした。。。
論文の実験結果
それでは、論文に掲載されている結果について見ていきます。MNTPと SimCSEの学習はどちらもWikipediaを用いています。これは、モデルの変化をみるために、実験に使用したモデルはすでに学習していると想定されるWikipediaを用いることで、新しい知識を与えないようにするためです。学習はとちらもLoRAを使ってファインチューニングしています。バッチサイズはMNTPが32, SimCSEが128を用いており、ともに1000 ステップ学習しています。LLM2Vecでのプーリングは平均を用いています。
それでは結果です。以下は、MTEBというベンチマークを用いて評価しています。LLM2VeCが最も精度が高いことがわかります。また、LLM2Vec(w/oSimCSE)=MNTPだけ行った場合でもUni + w.Meanという単方向言語モデルで平均プーリングを用いた場合より精度が高くなっています。逆に、Bi + Meanという、MNTPを行わずに単純に双方向にだけにした場合は、大きく精度が下がっいます。双方向言語モデルにするためにMNTPが効果的であることがわかります。 既存の手法と比較すると、EcoEmbeddingやBERTのようなEncoder-onlyモデルに対してSimCSEを行うよりもLLM2Vecが高い精度をもたらしています。
なお、元論文中には、LlaMA-1.3B、LlaMA-2-7B、Mistral-7Bでの実験結果も掲載されています。気になる方は、元論文を確認ください。
次に教師ありの場合を見ます。E5と呼ばれるいくつかのデータセットを合わせた教師データを用いて学習を行った結果です。こちらも学習はLoRAを用いており、バッチサイズ512で1000step実行しています。下の表よりLLM2Vecが最も良い精度となっています。双方向言語モデルは教師ありの場合も効果的であることがわかります。ただし、SimCSEの効果は限定的であり、LlaMA-3-8Bの場合はわずかながら低下しています。
LlaMA-2-7B、Mistral-7Bでの実験結果も掲載されています。気になる方は、元論文を確認ください。
次に、各ステップごとでの精度を比べます。すると、SimCSEを用いたLLM2Vecが最も早く収束を迎えていることがわかります。SimCSEを用いることで、少ない学習データでも高い精度となることがわかります。ただし、SimCSEの効果はモデルの性質か教師ありの方法に用いる負例のどちらがより効いていたかは定かではありません。
日本語での実験:継続事前学習に効果はあるか?
日本語で実験を行いました。日本語では、英語ほど教師データが潤沢ではありません。そのため、LLM2Vecの教師なしで高精度・少ない教師データで高精度といった性質が有効な場面が多いと考えます。 また、ドメインを絞った場合に教師データはより少なくなります。そこで、ドメイン適応の仮想的な実験として、以下の2つの実験を行いました。
- LLM2Vecに用いるデータを日本語にした場合、どの程度の効果があるか?
- ベースにする言語モデル自体も日本語でドメイン適応した場合、より効果があるか?
1の実験としてLlama2-7bを用いて、日本語wikipediaでMNTPとSimCSEを行いました (以下、Llama2-Llm2vec-jpnと呼ぶ)。 また、2の実験として、Llama2-7bから日本語コーパスで継続して言語モデルの事前学習を行った、Swallow-7bを用いて同様に日本語WikipediaでMNTPとSimCSEを行いました (以下、Swallow-Llm2vec-jpnと呼ぶ)。 ベンチマークとしてMTEBのうち日本語を含むデータを用いています。なお、一部データは権利の都合上除いています。
それでは以下が結果です。各タスク区分はMTEBの論文を参照ください。ベースラインとして、Llama2-7bに対して、著者ら用いたが英語Wikipediaで行った結果 (以下、Llama2-Llm2vec-engと呼ぶ)も示しています。*1
Classification | Clustering | PairClassification | Reranking | BitextMining | Retrieval | STS | 平均 | |
---|---|---|---|---|---|---|---|---|
Llama2-Llm2vec-eng | 0.527 | 0.258 | 0.501 | 0.217 | 0.275 | 0.296 | 0.765 | 0.406 |
Llama2-Llm2vec-jpn | 0.570 | 0.365 | 0.510 | 0.349 | 0.478 | 0.417 | 0.795 | 0.498 |
Swallow-Llm2vec-jpn | 0.621 | 0.391 | 0.510 | 0.338 | 0.475 | 0.491 | 0.832 | 0.523 |
平均的には、Swallow-jpnが最もよい結果となりました。また、 Llama2-jpnもLlama2-engを上回っています。これらは、LLM2Vecの手続きを日本語で行うだけでも効果があることがわかります。 さらに、SwallowはLlama2を日本語で継続的に学習した言語モデルなので、言語モデル自体を日本語で学習することも埋め込みに効果をもたらすことがわかります。 このことから、同言語内のドメイン適応としても対象とするドメインで継続的な学習やLLM2Vecを行うことが効果的がある可能性があります。今後機会があれば、実験してみたいところです。
各タスクの詳細を見ると、PairClassification、Reranking、BitextMiningはSwallow-jpnとLlama2-jpnでほとんど変わらないまたはSwallow-jpnで少し精度が悪化しています。 RerankingはRetrievalと同様の傾向を示すと考えていましたので、意外でした。これが、データセットの性質の違いなのかRerankingとRetreivalの差なのかは、気になるところです。 こちらも、今後の課題としたいところです。
次に英語のタスクで確認します。こちらはLLM2Vecの著者らの挙げているタスクを用いています。なお、こちらも一部データは権利の都合上除いています。 以下が結果です。
Classification | Clustering | Pair_Classification | Reranking | Retrieval | STS | 平均 | |
---|---|---|---|---|---|---|---|
Llama2-Llm2vec-eng | 0.709 | 0.386 | 0.780 | 0.588 | 0.329 | 0.723 | 0.586 |
Llama2-Llm2vec-jpn | 0.722 | 0.428 | 0.785 | 0.594 | 0.371 | 0.717 | 0.603 |
Swallow-Llm2vec-jpn | 0.695 | 0.385 | 0.751 | 0.576 | 0.318 | 0.710 | 0.572 |
平均的には、Llama2-jpnの評価値が高くなり、Swallow-jpnが最も低くなりました。 まず、SwallowがLlama2の評価結果が低いことについては、Llama2から日本語で継続事前学習を行ったSwallowが、英語タスクの精度が悪化するという傾向と一貫しています。 一方で、驚いたことに、Llama2-jpnの方がLlama2-engより良い結果となっています。 ClusteringやClassificationのようなタスクは、SimCSEのようにベクトルの等方性を促す処理は必ずしも効果的ではないことが指摘されています。 実際、ClassificationやClusteringでは、Llama2-LLM2Vec-jpnがLlama2-LLM2Vec-engを上回っています。 また、SimCSEが有効とされるSTSでは、Llama2-LLM2Vec-engがLlama2-LLM2Vec-jpnを上回っており、先行研究の指摘を支持しています。 しかしながら、先程の日本語の実験ではこのような傾向は見られませんでした。 一貫して説明するためには、さらなる研究が必要と思われます。
まとめ
本記事では、LLM2Vecという単方向言語モデルを双方向言語モデルに変換する手法を紹介しました。また、双方向言語モデルに変換することで、埋め込み表現の精度が向上することを確認しました。 さらに、日本語での実験を行い、日本語コーパスでLLM2Vecのプロセスを実行すると精度が向上すること、継続事前学習により、日本語でさらに精度を向上させられることを確認しました。 一方で、英語による実験では、対象の言語でLLM2Vecを行うよりも、異なる言語でLLM2Vecを行うほうがよいという異なる傾向を得ました。
日本語での実験で学習したLoraの重みは以下で公開しています。
- Llama2-Llm2vec-eng
- Llama2-Llm2vec-jpn
- Swallow-Llm2vec-jpn
補足
実験条件詳細
- 使用GPU: A6000x2
- 学習フレームワーク: DeepSpeedZero2を使用
MNTP
元論文からの変更
- bf16を有効にしました
- GradientCheckintPointを有効にしました
- バッチサイズを64に変更しました
SimCSE
元論文からの変更
- bf16を有効にしました
- バッチサイズを256に変更しました
- GradientCheckintPointを有効にしました
実験結果詳細
日本語
参考値として、JMTEBのe5-largeの結果を載せています。STSを除き、教師ありで学習したmultilingual-e5-largeの方が良い結果になる傾向が見られます。
genre | task\model | Llama2-Llm2vec-eng | Llama2-Llm2vec-jpn | Swallow-Llm2vec-jpn | multilingual-e5-large |
---|---|---|---|---|---|
BITEXT_MINING | BibleNLPBitextMining | 0.464 | 0.711 | 0.941 | - |
FloresBitextMining | 0.135 | 0.279 | 0.180 | - | |
NTREXBitextMining | 0.226 | 0.445 | 0.304 | - | |
CLASSIFICATION | AmazonCounterfactualClassification | 0.593 | 0.628 | 0.699 | 0.707 |
MassiveIntentClassification | 0.600 | 0.636 | 0.696 | 0.756 | |
MassiveScenarioClassification | 0.639 | 0.679 | 0.743 | 0.886 | |
SIB200Classification | 0.558 | 0.661 | 0.685 | - | |
WRIMEClassification | 0.245 | 0.243 | 0.280 | - | |
CLUSTERING | LivedoorNewsClustering | 0.230 | 0.324 | 0.367 | 0.571 |
MewsC16JaClustering | 0.323 | 0.428 | 0.455 | 0.453 | |
SIB200ClusteringS2S | 0.222 | 0.343 | 0.352 | - | |
PAIR_CLASSIFICATION | PawsXPairClassification | 0.501 | 0.510 | 0.510 | 0.621 |
RERANKING | MIRACLReranking | 0.179 | 0.340 | 0.280 | - |
VoyageMMarcoReranking | 0.255 | 0.358 | 0.396 | - | |
RETRIEVAL | BelebeleRetrieval | 0.314 | 0.632 | 0.815 | - |
JaGovFaqsRetrieval | 0.335 | 0.453 | 0.534 | 0.703 | |
JaQuADRetrieval | 0.306 | 0.423 | 0.481 | - | |
MintakaRetrieval | 0.130 | 0.171 | 0.241 | - | |
MultiLongDocRetrieval | 0.054 | 0.150 | 0.226 | - | |
NLPJournalAbsIntroRetrieval | 0.358 | 0.478 | 0.515 | 0.860 | |
NLPJournalTitleAbsRetrieval | 0.481 | 0.610 | 0.700 | 0.947 | |
NLPJournalTitleIntroRetrieval | 0.182 | 0.242 | 0.278 | 0.725 | |
XPQARetrieval | 0.506 | 0.595 | 0.635 | - | |
STS | JSICK | 0.806 | 0.824 | 0.848 | 0.784 |
JSTS | 0.724 | 0.767 | 0.816 | 0.810 |
英語
参考値として、 元論文中のLlama2-7b-chat-hfの結果を載せています。実験時の差分として、指示文を用いていない点があります。指示文を用いたほうが精度がよいようですが、今回はすべて指示チューニングをしていないモデルでの検証を行ったため、除きました。
genre | task | Llama2-Llm2vec-eng | Llama2-Llm2vec-jpn | Swallow-Llm2vec-jpn | Llama2-Llm2vec-eng(元論文) |
---|---|---|---|---|---|
CLASSIFICATION | AmazonCounterfactualClassification | 0.740 | 0.736 | 0.718 | 0.769 |
AmazonPolarityClassification | 0.719 | 0.754 | 0.717 | 0.791 | |
Banking77Classification | 0.847 | 0.841 | 0.828 | 0.847 | |
EmotionClassification | 0.488 | 0.445 | 0.470 | 0.466 | |
ImdbClassification | 0.638 | 0.692 | 0.647 | 0.757 | |
MTOPDomainClassification | 0.921 | 0.937 | 0.911 | 0.943 | |
MTOPIntentClassification | 0.731 | 0.780 | 0.686 | 0.795 | |
MassiveIntentClassification | 0.713 | 0.733 | 0.687 | 0.738 | |
MassiveScenarioClassification | 0.765 | 0.782 | 0.742 | 0.792 | |
ToxicConversationsClassification | 0.649 | 0.673 | 0.661 | 0.718 | |
TweetSentimentExtractionClassification | 0.589 | 0.567 | 0.580 | 0.572 | |
CLUSTERING | ArxivClusteringP2P | 0.413 | 0.466 | 0.440 | 0.478 |
ArxivClusteringS2S | 0.347 | 0.396 | 0.366 | 0.405 | |
MedrxivClusteringP2P | 0.302 | 0.313 | 0.318 | 0.309 | |
MedrxivClusteringS2S | 0.286 | 0.273 | 0.286 | 0.280 | |
RedditClustering | 0.407 | 0.451 | 0.386 | 0.428 | |
RedditClusteringP2P | 0.502 | 0.596 | 0.504 | 0.601 | |
StackExchangeClustering | 0.591 | 0.649 | 0.549 | 0.651 | |
StackExchangeClusteringP2P | 0.298 | 0.326 | 0.294 | 0.336 | |
TwentyNewsgroupsClustering | 0.324 | 0.381 | 0.318 | 0.308 | |
PAIR_CLASSIFICATION | SprintDuplicateQuestions | 0.909 | 0.894 | 0.882 | 0.876 |
TwitterSemEval2015 | 0.607 | 0.645 | 0.574 | 0.651 | |
TwitterURLCorpus | 0.823 | 0.815 | 0.796 | 0.809 | |
RERANKING | AskUbuntuDupQuestions | 0.556 | 0.553 | 0.550 | 0.556 |
SciDocsRR | 0.745 | 0.775 | 0.748 | 0.776 | |
StackOverflowDupQuestions | 0.461 | 0.453 | 0.431 | 0.478 | |
RETRIEVAL | ArguAna | 0.456 | 0.517 | 0.445 | 0.471 |
CQADupstack | 0.236 | 0.282 | 0.207 | 0.308 | |
FiQA2018 | 0.187 | 0.232 | 0.166 | 0.246 | |
NFCorpus | 0.226 | 0.257 | 0.214 | 0.268 | |
QuoraRetrieval | 0.826 | 0.821 | 0.810 | 0.862 | |
SCIDOCS | 0.101 | 0.129 | 0.099 | 0.100 | |
SciFact | 0.476 | 0.565 | 0.488 | 0.645 | |
Touche2020 | 0.123 | 0.163 | 0.118 | 0.102 | |
STS | BIOSSES | 0.834 | 0.832 | 0.825 | 0.824 |
SICK-R | 0.707 | 0.727 | 0.686 | 0.718 | |
STS12 | 0.732 | 0.650 | 0.705 | 0.654 | |
STS13 | 0.782 | 0.772 | 0.774 | 0.793 | |
STS14 | 0.733 | 0.719 | 0.718 | 0.730 | |
STS15 | 0.826 | 0.804 | 0.813 | 0.827 | |
STS16 | 0.799 | 0.804 | 0.783 | 0.810 | |
STS17 | 0.853 | 0.855 | 0.837 | 0.867 | |
STS22 | 0.615 | 0.616 | 0.631 | 0.635 | |
STSBenchmark | 0.777 | 0.785 | 0.752 | 0.783 | |
SummEval | 0.297 | 0.321 | 0.285 | 0.314 |
*1:著者らのLlama2-7bの実験は指示学習・選考学習後のものを使用してます。