GPTで長いテキストの埋め込み実装を解説【OpenAI-cookbookの実装】

OpenAI cookbookの公式的な実装方法について解説していきたいと思います。

https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb

OpenAIの埋め込みモデルは、最大長を超えるテキストを埋め込むことはできません。最大長はモデルによって異なり、文字列の長さではなくトークンで測定されます。

この記事では、モデルの最大コンテキスト長を超えるテキストをどのように取り扱うかを解説します。text-embedding-ada-002からの埋め込みを使用してデモンストレーションを行いますが、同様の考え方は他のモデルやタスクにも適用できます。

1. モデルのコンテキスト長

環境:google colab

まずは必要ライブラリのインストール

!pip install openai
!pip install tiktoken
!pip insstall retry

モデルを選択し、APIから埋め込みを取得する関数を定義します。

import openai
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type

EMBEDDING_MODEL = 'text-embedding-ada-002'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'

@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return openai.Embedding.create(input=text_or_tokens, model=model)["data"][0]["embedding"]

text-embedding-ada-002モデルは、cl100k_baseエンコーディングを使用した場合、8191トークンのコンテキスト長を持ちます。この限度を超えるとエラーが発生することが確認できます。

long_text = 'AGI ' * 5000
try:
    get_embedding(long_text)
except openai.InvalidRequestError as e:
    print(e)

特に大量の埋め込みをプログラムで処理する場合、このようなエラーは避けたいところです。しかし、最大コンテキスト長を超えるテキストが存在する可能性があります。以下では、これらの長いテキストを取り扱う主なアプローチを説明し、レシピを提供します。それは、(1) テキストを最大許容長に単純に切り詰める、および (2) テキストを分割し、各チャンクを個別に埋め込む、の2つの方法です。

1. 入力テキストの切り詰め

最もシンプルな解決策は、入力テキストを最大許容長に切り詰めることです。コンテキスト長はトークンで測定されるため、テキストを切り詰める前にまずトークン化する必要があります。APIはテキスト形式またはトークン形式の入力を受け付けるため、適切なエンコーディングを使用している限り、トークンを再び文字列形式に変換する必要はありません。以下にそのような切り詰め関数の例を示します。

import tiktoken

def truncate_text_tokens(text, encoding_name=EMBEDDING_ENCODING, max_tokens=EMBEDDING_CTX_LENGTH):
    """指定されたエンコーディングに従って `max_tokens` 内になるように文字列を切り詰めます。"""
    encoding = tiktoken.get_encoding(encoding_name)
    return encoding.encode(text)[:max_tokens]

先ほどの例は、エラーなしで動作します。

truncated = truncate_text_tokens(long_text)
len(get_embedding(truncated))
# 1536

2. 入力テキストの分割

切り詰めが機能する一方で、潜在的に関連するテキストを破棄することは明らかな欠点です。もう一つのアプローチは、入力テキストをチャンクに分割し、各チャンクを個別に埋め込むことです。その後、チャンクの埋め込みを別々に使用するか、あるいはそれらを何らかの方法で組み合わせることができます(例えば、各チャンクのサイズによって重み付けされた平均を取るなど)。

ここでは、シーケンスをチャンクに分割するPythonの独自の関数を取り入れます。

from itertools import islice

def batched(iterable, n):
    """データを長さ n のタプルにバッチ処理します。 最後のバッチは短くなる可能性があります。"""
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch

次に、文字列をトークンにエンコードし、それをチャンクに分割する関数を定義します。

def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator

最後に、入力テキストが最大コンテキスト長を超える場合でも、入力トークンをチャンクに分割し、各チャンクを個別に埋め込むことで、埋め込みリクエストを安全に処理する関数を書くことができます。averageフラグをTrueに設定すると、チャンクの埋め込みの重み付き平均を返し、Falseに設定すると、チャンクの埋め込みのリストをそのまま返します。

import numpy as np

def len_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))

    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
average_embedding_vector = len_safe_get_embedding(long_text, average=True)
chunks_embedding_vectors = len_safe_get_embedding(long_text, average=False)

print(f"Setting average=True gives us a single {len(average_embedding_vector)}-dimensional embedding vector for our long text.")
print(f"Setting average=False gives us {len(chunks_embedding_vectors)} embedding vectors, one for each of the chunks.")

コード解説

上記のコードは、大きなテキストを小さなチャンクに分割し、それらのチャンクごとにベクトル表現(エンベッディング)を生成するためのものです。各関数について詳しく見てみましょう。

1. batched()関数

from itertools import islice
 
def batched(iterable, n):
    """データを長さ n のタプルにバッチ処理します。 最後のバッチは短くなる可能性があります。"""
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch

この関数は、イテラブル(繰り返し可能な)オブジェクトを取り、その要素を一定のサイズnのタプルにまとめます。この関数はイテレータ(次の要素を順次返すオブジェクト)を使用して、最初のn要素を取得し、それらをタプルとして返します。これを全ての要素が処理されるまで繰り返します。

itertools.islice()について

itertools.islice()はPythonのitertoolsモジュールに含まれる関数で、イテラブル(リスト、文字列、辞書など)から指定された範囲の要素を返すために使用されます。

islice()関数は3つの引数を取ります:イテラブル、開始インデックス、終了インデックスです。開始インデックスはオプションで、デフォルトは0(最初の要素)です。終了インデックスは必須で、このインデックスの要素は結果に含まれません。

例を見てみましょう:

from itertools import islice

# リストから要素をスライスする
numbers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(list(islice(numbers, 2, 6)))  # [2, 3, 4, 5]

この例では、リストnumbersの2番目から5番目までの要素を取得しています。islice()関数はイテレータを返すので、結果をリストとして表示するためにlist()関数を使用しています。

islice()は “lazy”であるため、特に大きなイテラブルに対してメモリ効率が良いです。これは、全ての要素をメモリに読み込むのではなく、必要な要素のみを一度に取得するためです。

通常のスライスと何が違うのかというと、
Pythonの通常のスライス操作とitertools.islice()の主な違いは、”メモリ効率”と”イテレータのサポート”です。

メモリ効率: 通常のスライス操作は、新たなリストを作成してそれを返すため、元のリストの大きな部分をコピーする場合、メモリ消費量が大きくなります。一方、islice()はイテレータを返すため、メモリに全てをロードする必要はありません。これは大規模なデータを扱う場合に特に重要な違いとなります。

イテレータのサポート: 通常のスライス操作はリストや文字列のようなシーケンス型に対してのみ動作しますが、islice()は全てのイテレータに対して動作します。これにより、リストや文字列だけでなく、ファイルの行や生成器など、任意のイテレータから部分を抽出することができます。

例えば、以下のコードは非常に大きなファイルから最初の10行を読み取ることができます:

with open('large_file.txt') as f:
    for line in islice(f, 10):
        print(line)

上記の処理では開かれたファイルはディスク上に展開され、メモリ上には10個ずつ乗ります。これによりメモリを効率的に利用することが可能となります。

iter()関数について

このコードでは、iter()関数はiterableからイテレータを作成するために使用されています。このイテレータitは、その後islice(it, n)に渡され、各ループでn個の要素が取り出されます。

なぜiter()関数が必要なのかと言うと、それはislice()がイテレータの状態(つまり次にどの要素を取り出すべきか)を保存するからです。イテレータは次にどの要素を取り出すべきかを覚えており、それを次々と提供します。islice()はその性質を利用し、イテレータから連続するn個の要素を提供します。

もしiter()を使用せずに直接islice(iterable, n)を呼び出すと、islice()は毎回新たなイテレータを作成し、その都度最初からn個の要素を取り出そうとします。そのため、各ループで新たなバッチが生成されず、同じバッチが繰り返し生成されることになります。

したがって、iter()関数はiterableからイテレータを一度だけ作成し、その後islice()がそのイテレータの状態を利用できるようにするために必要です。

下記コードについて

    while (batch := tuple(islice(it, n))):
        yield batch

このコード中の:=記号はPythonの「代入式」(別名「ウォルラス演算子」)を示しています。これはPython 3.8以降で利用可能な新機能で、式の中で変数への代入を行うことができます。

このコードのwhile (batch := tuple(islice(it, n))):部分では、「islice(it, n)から得られる結果をbatchに代入し、そのbatchが空(つまりislice(it, n)がイテレータからこれ以上要素を取り出せない)でない限りループを続ける」という意味になります。

これにより、whileループの条件式であると同時に、ループ内で使用するbatchの値を設定することができます。これはコードを短く、読みやすくするのに役立ちます。

2. chunked_tokens()関数

def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator

この関数は、テキストをトークン化し、そのトークンを一定の長さchunk_lengthのチャンクに分割します。具体的には、まずtiktoken.get_encoding(encoding_name)を使用してエンコーディングを取得し、そのエンコーディングを使用してテキストをトークン化します。その後、先ほど定義したbatched()関数を使用してトークンをチャンクに分割します。

3. len_safe_get_embedding()関数

import numpy as np
 
def len_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))
 
    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings

この関数は、与えられたテキストに対してエンベッディングを生成します。まず、テキストをchunked_tokens()関数でチャンクに分割し、各チャンクに対してget_embedding()関数(このコードスニペットには定義が含まれていませんが、おそらくモデルを使ってエンベッディングを取得する関数)を呼び出してエンベッディングを計算します。そして、各チャンクの長さを保存します。

最後に、averageパラメータがTrueの場合、エンベッディングをチャンクの長さで重み付けした平均を計算し、結果のエンベッディングの長さを1に正規化します。この平均化と正規化は、異なる長さのチャンクから得られるエンベッディングを統一的に扱うために行われます。

正規化について

このコードの正規化部分は、ベクトルの長さ(大きさ)を1にするために行われます。具体的には、chunk_embeddings / np.linalg.norm(chunk_embeddings)という部分で、chunk_embeddingsベクトルをそのノルム(長さ)で除算しています。

ベクトルの正規化は、データのスケールに依存しない特性を確保するためによく行われます。正規化されたベクトルは、その方向だけを示し、大きさは一定(この場合は1)になります。

このコードの文脈では、各チャンクの埋め込みが異なる長さ(単語数)を持つ可能性があります。長いチャンクはより多くの情報を含む可能性があり、その影響力が大きくなりすぎることを防ぐために、正規化によって各埋め込みベクトルの大きさを一定に保ちます。これにより、全てのチャンクの埋め込みが等しく寄与することを確保します。

具体的に

具体的な例を使って、正規化を行った場合のベクトルの値を説明します。例えば、以下の3次元のエンベッディングベクトルがあるとします。

v = [1, 2, 3]

まず、ベクトルの長さ(ノルム)を計算します。

ノルム = √(1^2 + 2^2 + 3^2) = √14

次に、ベクトルの各次元の値をノルムで除算して正規化します。

正規化されたベクトル = [1/√14, 2/√14, 3/√14] ≈ [0.267, 0.534, 0.801]

このように、正規化を行うと、元のベクトルの各次元の値が変わり、正規化されたベクトルの長さが1になります。この正規化されたベクトルは、同じスケールで比較しやすくなり、テキスト間の類似性や関連性を計算しやすくなります。

正規化のデメリット

正規化にはデメリットも存在しますが、主なものは以下の通りです。

1. 情報の損失: 正規化では、各特徴量のスケールを変更して同一の尺度に合わせます。これにより、元のデータの特徴が失われることがあります。例えば、ある特徴量がもともと大きな範囲を持っていた場合、正規化によってその範囲が縮小され、他の特徴量との関係が変わってしまいます。

2. 計算コストの増加: 正規化は、データセット全体に対して適用される前処理手法であるため、計算コストが増加します。特に、大規模なデータセットに対して正規化を実施する場合、処理にかかる時間が長くなることがあります。

3. 可読性の低下: 正規化によって、特徴量の単位や尺度が変わることがあります。これにより、データの解釈が難しくなることがあります。例えば、正規化された特徴量の値が0.5である場合、元のデータの尺度に戻すことなく、その値が高いのか低いのかを判断することが難しくなります。

4. 不適切な正規化手法: データセットによっては、特徴量のスケールや分布が異なるため、適切な正規化手法を選択することが重要です。不適切な正規化手法を選択すると、モデルの性能が低下することがあるため、注意が必要です。

5. 疎なデータの扱い: 疎なデータ(ほとんどの要素がゼロであるデータ)の場合、正規化が不適切な結果をもたらすことがあります。例えば、L2正規化を適用すると、すべての非ゼロ要素が非常に小さい値になり、モデルの性能が低下する可能性があります。

これらのデメリットを考慮して、適切な正規化手法を選択し、データの前処理を行うことが重要です。