您的位置:首頁 > 軟件教程 > 教程 > 視覺語言模型詳解

視覺語言模型詳解

來源:好特整理 | 時間:2024-04-30 08:46:41 | 閱讀:188 |  標簽: 視覺   | 分享到:

視覺語言模型可以同時從圖像和文本中學習,因此可用于視覺問答、圖像描述等多種任務(wù)。本文,我們將帶大家一覽視覺語言模型領(lǐng)域: 作個概述、了解其工作原理、搞清楚如何找到真命天“模”、如何對其進行推理以及如何使用最新版的 trl 輕松對其進行微調(diào)。 什么是視覺語言模型? 視覺語言模型是可以同時從圖像和文本中

視覺語言模型可以同時從圖像和文本中學習,因此可用于視覺問答、圖像描述等多種任務(wù)。本文,我們將帶大家一覽視覺語言模型領(lǐng)域: 作個概述、了解其工作原理、搞清楚如何找到真命天“模”、如何對其進行推理以及如何使用最新版的 trl 輕松對其進行微調(diào)。

什么是視覺語言模型?

視覺語言模型是可以同時從圖像和文本中學習的多模態(tài)模型,其屬于生成模型,輸入為圖像和文本,輸出為文本。大視覺語言模型具有良好的零樣本能力,泛化能力良好,并且可以處理包括文檔、網(wǎng)頁等在內(nèi)的多種類型的圖像。其擁有廣泛的應(yīng)用,包括基于圖像的聊天、根據(jù)指令的圖像識別、視覺問答、文檔理解、圖像描述等。一些視覺語言模型還可以捕獲圖像中的空間信息,當提示要求其檢測或分割特定目標時,這些模型可以輸出邊界框或分割掩模,有些模型還可以定位不同的目標或回答其相對或絕對位置相關(guān)的問題,F(xiàn)有的大視覺語言模型在訓練數(shù)據(jù)、圖像編碼方式等方面采用的方法很多樣,因而其能力差異也很大。

視覺語言模型詳解

開源視覺語言模型概述

Hugging Face Hub 上有很多開放視覺語言模型,下表列出了其中一些佼佼者。

  • 其中有基礎(chǔ)模型,也有可用于對話場景的針對聊天微調(diào)的模型。
  • 其中一些模型具有“接地 (grounding )”功能,因此能夠減少模型幻覺。
  • 除非另有說明,所有模型的訓練語言皆為英語。
模型 可否商用 模型尺寸 圖像分辨率 其它能力
LLaVA 1.6 (Hermes 34B) ? 34B 672x672
deepseek-vl-7b-base ? 7B 384x384
DeepSeek-VL-Chat ? 7B 384x384 聊天
moondream2 ? ~2B 378x378
CogVLM-base ? 17B 490x490
CogVLM-Chat ? 17B 490x490 接地、聊天
Fuyu-8B ? 8B 300x300 圖像中的文本檢測
KOSMOS-2 ? ~2B 224x224 接地、零樣本目標檢測
Qwen-VL ? 4B 448x448 零樣本目標檢測
Qwen-VL-Chat ? 4B 448x448 聊天
Yi-VL-34B ? 34B 448x448 雙語 (英文、中文)

尋找合適的視覺語言模型

有多種途徑可幫助你選擇最適合自己的模型。

視覺競技場 (Vision Arena) 是一個完全基于模型輸出進行匿名投票的排行榜,其排名會不斷刷新。在該競技場上,用戶輸入圖像和提示,會有兩個匿名的不同的模型為其生成輸出,然后用戶可以基于他們的喜好選擇一個輸出。這種方式生成的排名完全是基于人類的喜好的。

視覺語言模型詳解

開放 VLM 排行榜 提供了另一種選擇,各種視覺語言模型按照所有指標的平均分進行排名。你還可以按照模型尺寸、私有或開源許可證來篩選模型,并按照自己選定的指標進行排名。

視覺語言模型詳解

VLMEvalKit 是一個工具包,用于在視覺語言模型上運行基準測試,開放 VLM 排行榜就是基于該工具包的。

還有一個評估套件是 LMMS-Eval ,其提供了一個標準命令行界面,你可以使用 Hugging Face Hub 上托管的數(shù)據(jù)集來對選定的 Hugging Face 模型進行評估,如下所示:

accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-7b" --tasks mme,mmbench_en --batch_size 1 --log_samples --log_samples_suffix llava_v1.5_mme_mmbenchen --output_path ./logs/

視覺競技場和開放 VLM 排行榜都僅限于提交給它們的模型,且需要更新才能添加新模型。如果你想查找其他模型,可以在 image-text-to-text 任務(wù)下瀏覽 hub 中的 模型 。

在排行榜中,你會看到各種不同的用于評估視覺語言模型的基準,下面我們選擇其中幾個介紹一下。

MMMU

針對專家型 AGI 的海量、多學科、多模態(tài)理解與推理基準 (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI,MMMU) 是評估視覺語言模型的最全面的基準。它包含 11.5K 個多模態(tài)問題,這些問題需要大學水平的學科知識以及跨學科 (如藝術(shù)和工程) 推理能力。

MMBench

MMBench 由涵蓋超過 20 種不同技能的 3000 道單選題組成,包括 OCR、目標定位等。論文還介紹了一種名為 CircularEval 的評估策略,其每輪都會對問題的選項進行不同的組合及洗牌,并期望模型每輪都能給出正確答案。

另外,針對不同的應(yīng)用領(lǐng)域還有其他更有針對性的基準,如 MathVista (視覺數(shù)學推理) 、AI2D (圖表理解) 、ScienceQA (科學問答) 以及 OCRBench (文檔理解)。

技術(shù)細節(jié)

對視覺語言模型進行預(yù)訓練的方法很多。主要技巧是統(tǒng)一圖像和文本表征以將其輸入給文本解碼器用于文本生成。最常見且表現(xiàn)最好的模型通常由圖像編碼器、用于對齊圖像和文本表征的嵌入投影子模型 (通常是一個稠密神經(jīng)網(wǎng)絡(luò)) 以及文本解碼器按序堆疊而成。至于訓練部分,不同的模型采用的方法也各不相同。

例如,LLaVA 由 CLIP 圖像編碼器、多模態(tài)投影子模型和 Vicuna 文本解碼器組合而成。作者將包含圖像和描述文本的數(shù)據(jù)集輸入 GPT-4,讓其描述文本和圖像生成相關(guān)的問題。作者凍結(jié)了圖像編碼器和文本解碼器,僅通過給模型饋送圖像與問題并將模型輸出與描述文本進行比較來訓練多模態(tài)投影子模型,從而達到對齊圖像和文本特征的目的。在對投影子模型預(yù)訓練之后,作者把圖像編碼器繼續(xù)保持在凍結(jié)狀態(tài),解凍文本解碼器,然后繼續(xù)對解碼器和投影子模型進行訓練。這種預(yù)訓練加微調(diào)的方法是訓練視覺語言模型最常見的做法。

視覺語言模型詳解

視覺語言模型詳解

再舉一個 KOSMOS-2 的例子,作者選擇了端到端地對模型進行完全訓練的方法,這種方法與 LLaVA 式的預(yù)訓練方法相比,計算上昂貴不少。預(yù)訓練完成后,作者還要用純語言指令對模型進行微調(diào)以對齊。還有一種做法,F(xiàn)uyu-8B 甚至都沒有圖像編碼器,直接把圖像塊饋送到投影子模型,然后將其輸出與文本序列直接串接送給自回歸解碼器。

大多數(shù)時候,我們不需要預(yù)訓練視覺語言模型,僅需使用現(xiàn)有的模型進行推理,抑或是根據(jù)自己的場景對其進行微調(diào)。下面,我們介紹如何在 transformers 中使用這些模型,以及如何使用 SFTTrainer 對它們進行微調(diào)。

在 transformers 中使用視覺語言模型

你可以使用 LlavaNext 模型對 Llava 進行推理,如下所示。

首先,我們初始化模型和數(shù)據(jù)處理器。

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model.to(device)

現(xiàn)在,將圖像和文本提示傳給數(shù)據(jù)處理器,然后將處理后的輸入傳給 generate 方法。請注意,每個模型都有自己的提示模板,請務(wù)必根據(jù)模型選用正確的模板,以避免性能下降。

from PIL import Image
import requests

url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] \nWhat is shown in this image? [/INST]"

inputs = processor(prompt, image, return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=100)

調(diào)用 decode 對輸出詞元進行解碼。

print(processor.decode(output[0], skip_special_tokens=True))

使用 TRL 微調(diào)視覺語言模型

我們很高興地宣布,作為一個實驗性功能, TRL 的 SFTTrainer 現(xiàn)已支持視覺語言模型!這里,我們給出了一個例子,以展示如何在 llava-instruct 數(shù)據(jù)集上進行 SFT,該數(shù)據(jù)集包含 260k 個圖像對話對。

llava-instruct 數(shù)據(jù)集將用戶與助理之間的交互組織成消息序列的格式,且每個消息序列皆與用戶問題所指的圖像配對。

要用上 VLM 訓練的功能,你必須使用 pip install -U trl 安裝最新版本的 TRL。你可在 此處 找到完整的示例腳本。

from trl.commands.cli_utils import SftScriptArguments, TrlParser

parser = TrlParser((SftScriptArguments, TrainingArguments))
args, training_args = parser.parse_args_and_config()

初始化聊天模板以進行指令微調(diào)。

LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""

現(xiàn)在,初始化模型和分詞器。

from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
import torch

model_id = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer = tokenizer

model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)

建一個數(shù)據(jù)整理器來組合文本和圖像對。

class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            messages = example["messages"]
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

data_collator = LLavaDataCollator(processor)

加載數(shù)據(jù)集。

from datasets import load_dataset

raw_datasets = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]

初始化 SFTTrainer ,傳入模型、數(shù)據(jù)子集、PEFT 配置以及數(shù)據(jù)整理器,然后調(diào)用 train() 。要將最終 checkpoint 推送到 Hub,需調(diào)用 push_to_hub()

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text", # need a dummy field
    tokenizer=tokenizer,
    data_collator=data_collator,
    dataset_kwargs={"skip_prepare_dataset": True},
)

trainer.train()

保存模型并推送到 Hugging Face Hub。

trainer.save_model(training_args.output_dir)
trainer.push_to_hub()

你可在 此處 找到訓得的模型。你也可以通過下面的頁面試玩一下我們訓得的模型??。

視覺語言模型詳解

致謝

我們感謝 Pedro Cuenca、Lewis Tunstall、Kashif Rasul 和 Omar Sanseviero 對本文的評論和建議。


英文原文: https://hf.co/blog/vlms
原文作者: Merve Noyan,Edward Beeching
譯者: Matrix Yao (姚偉峰),英特爾深度學習工程師,工作方向為 transformer-family 模型在各模態(tài)數(shù)據(jù)上的應(yīng)用及大規(guī)模模型的訓練推理。

小編推薦閱讀

好特網(wǎng)發(fā)布此文僅為傳遞信息,不代表好特網(wǎng)認同期限觀點或證實其描述。

相關(guān)視頻攻略

更多

掃二維碼進入好特網(wǎng)手機版本!

掃二維碼進入好特網(wǎng)微信公眾號!

本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請發(fā)郵件[email protected]

湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2025 haote.com 好特網(wǎng)