最近の生成AIは、Google の Transformer という技術を使っています。
しかし、同じ Transformer でも、計算の仕方が異なるモデルが並列して存在しています。
代表的なものは、
ChatGPT で使われている GPT系列の「デコーダ」系モデルと、
google が2018年に発表した BERT系列の「エンコーダ」系モデルでしょうか。
Hugging Face の文書には、たくさんの「TEXT MODELS」名が列挙されています。

これらのアーキテクチャの違いを意識しないといけないのは、モデルやトークナイザを読み込むときに使用する関数が変わるからです。
一応、自動で判別する関数「Auto」があるのですが、万能ではありません。
「GPT2」は Auto でも読み込み可能で、以下のコードでも実行できますが、
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
text = "What is AI"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=64,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
)
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
系列を指定する場合は、GPT系列の関数を使う必要があります。
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', torch_dtype=torch.float32)
text = "What is AI"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=64,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
)
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
(torch_dtype=torch.float32 を指定しているのは、RuntimeError: “LayerNormKernelImpl” not implemented for ‘Half’ のエラーを回避するためです。エラーが出ないのであれば不要です)
モデルと関数の組み合わせを無視して、以下のように T5用のトークナイザ―を使用すると
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
text = "What is AI"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=64,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
)
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
OSError: Can’t load tokenizer for ‘gpt2’. となり、実行できません。
アーキテクチャはたくさんあるので、どのモデルがどの系列なのかを把握するのは大変です。

コメント