使用 vLLM 对自定义实现模型进行高速推理

2025-04-05T14:06:34+08:00 | 18分钟阅读 | 更新于 2025-04-05T14:06:34+08:00

Macro Zhao

使用 vLLM 对自定义实现模型进行高速推理

介绍

在本文中,我们将解释如何使用 vLLM 使用视频生成模型推断独特的多模式模型。 vLLM 是一个用于高速推理和服务的 LLM 库,它非常易于使用,支持 Llama 和 Qwen 等知名模型。另一方面,除了官方文档 之外,关于如何合并自定义实现模型的信息非常少,甚至官方文档也没有特别的信息,因此很难上手。

因此,在本文中,我们将仔细解释将您自己的多模态模型合并到 vLLM 中的具体步骤,包括官方文档中未包含的信息。具体来说,我们以视频生成模型为例,详细讲解如何将 Hugging Face Transformers 库中实现的模型适配到 vLLM。我们还将介绍利用 vLLM 可以实现的推理速度的提升以及实际结果。

本文涵盖以下主题:

  • 如何使用 vLLM 实现自己的模型
  • 如何使用 vLLM 处理您自己的多模式数据

什么是 vLLM?

vLLM 是一个开源库,用于加速大型语言模型(LLM)的推理并使模型服务变得简单而高效。近年来,Hugging Face Transformers 被广泛用于 LLM 的训练和实验。然而,Transformers 的实现在推理过程中存在一些效率低下的问题。最值得注意的问题是键值缓存管理效率低下。

键值缓存是 Transformer 模型用来在推理过程中有效重用过去的上下文信息的一种机制。该缓存存储了过去令牌激活的结果,并在生成新令牌时重复使用。然而,在 Transformers 库的标准实现中,这种缓存的管理结构使得其容易出现不必要的内存消耗和处理延迟,从而限制了 LLM 推理的速度。

vLLM旨在通过采用独特的PagedAttention 算法来解决这一问题。 PagedAttention 是一种优化键值缓存分配和管理的机制,可显著减少缓存内存浪费,同时实现快速访问和更新。因此,与传统库相比,vLLM 可以实现更高的推理吞吐量。

此外,我们还做出了各种努力来加快这一过程,包括实现在使用 LLM 服务时有效的连续批处理,以及支持量化(GPTQ、AWQ、INT4/8、FP8)。

vLLM 与 Hugging Face Transformers 高度兼容,对于流行模型来说,不需要任何特殊工作。因此,它在开发速度很重要的项目和原型设计阶段特别有用。

处理 vLLM 中的多模态模型

vLLM 是一个用于高速推理和提供 LLM 的库,但它也可以加速处理图像和音频以及语言等输入的多模式模型的推理。此外,如果实施得当,任何自回归 Transformer 模型都可以变得更快。

通过vLLM包vllm.multimodal 支持多模式模型。用户可以使用字段vllm.inputs.PromptType以及multi_modal_data 文本和令牌提示将多模式输入传递给模型。目前,vLLM 提供对图像和视频数据的内置支持,但可以扩展以处理其他模式。

vLLM 还提供了使用多模态模型进行离线推理的示例。例如,官方文档提供了使用单幅图像输入进行 推理和组合多幅图像进行推理的示例代码。

实现独特的视频生成模型

在这篇博客中,我们考虑使用 vLLM 来加速视频生成模型“Terra”。

该模型使用图像标记器将视频的每个图像帧转换为离散标记序列,然后使用代表图像序列的标记序列作为输入来预测未来图像序列的离散标记序列。然后使用解码器将预测的离散标记序列转换为图像序列以生成视频。

您还可以输入称为动作的向量序列来进行条件反射。该矢量序列是一个 3 x 6 矩阵,由六个三维矢量组成,插入在每个图像帧之间。由于有 576 个离散标记代表一个图像帧,因此在推理过程中,会为图像中的每 576 个离散标记插入一个 6 标记向量。

Hugging Face在Transformers中的实现如下。该模型基于Llama架构的LLM模型,但不同之处在于它包含处理动作向量的机制和可以作为位置编码进行学习的特殊位置编码。

  • 使用 Transformer 实现
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..positional_embedding import LearnableFactorizedSpatioTemporalPositionalEmbedding

class LlamaActionConfig(LlamaConfig):
    model_type = "llama_action"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_spatio_embeddings = kwargs.get("num_spatio_embeddings", 582)
        self.num_temporal_embeddings = kwargs.get("num_temporal_embeddings", 25)
        self.num_action_embeddings = kwargs.get("num_action_tokens", 6)
        self.num_image_patches = kwargs.get("num_image_patches", 576)
        self.action_dim = kwargs.get("action_dim", 3)


class LlamaActionForCausalLM(LlamaForCausalLM):
    config_class = LlamaActionConfig

    def __init__(self, config: LlamaActionConfig):
        super().__init__(config)

        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings

        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
        )

        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if past_key_values is None:
            inputs_embeds_list = torch.split(
                inputs_embeds,
                split_size_or_sections=self.num_image_patches,
                dim=1
            )
            actions_list = torch.split(
                actions,
                split_size_or_sections=self.num_action_embeddings,
                dim=1
            )

            embeddings = []
            if len(inputs_embeds_list) == len(actions_list):
                # 学习时使用的的逻辑,推理时几乎不用
                for inputs_embeds, action_embeds in zip(inputs_embeds_list, actions_list):
                    action_features = self.action_projection(action_embeds)
                    embeddings.append(inputs_embeds)
                    embeddings.append(action_features)
            elif len(inputs_embeds_list) < len(actions_list):
                # 推理使用embeded
                for i, inputs_embeds in enumerate(inputs_embeds_list):
                    embeddings.append(inputs_embeds)
                    if i < len(inputs_embeds_list) - 1:
                        # 最后一帧可能是生成过程中的图像令牌序列,因此不添加动作嵌入。
                        action_embeds = self.action_projection(actions_list[i])
                        embeddings.append(action_embeds)
                if inputs_embeds_list[-1].size(1) == self.num_image_patches:
                    # 如果图像令牌正好输出了一帧,则在添加动作嵌入的基础上,进一步添加用于下一帧的文本令牌。
                    action_embeds = self.action_projection(actions_list[len(inputs_embeds_list) - 1])
                    embeddings.append(action_embeds)
        else:
            past_key_values_length = past_key_values[0][0].size(2)
            embeddings = []
            # image, image, ..., image, action, action, ..., action格式进行输入
            # 由于只生成图像令牌,所以在生成完一帧的时添加动作令牌。
            if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                seq_index = past_key_values_length // self.num_spatio_embeddings + 1
                actions_list = torch.split(
                    actions,
                    split_size_or_sections=self.num_action_embeddings,
                    dim=1
                )
                action_features = self.action_projection(actions_list[seq_index - 1])
                embeddings.append(action_features)
                embeddings.append(inputs_embeds)
            else:
                pass

        if len(embeddings) > 0:
            inputs_embeds = torch.cat(embeddings, dim=1)

        # Insert Spatio Temporal Positional Embedding
        past_key_values_length = past_key_values[0][0].size(2) if past_key_values is not None else 0
        inputs_embeds += self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output).contiguous()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs):
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        n_frames = seq_length // self.num_image_patches
        attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
        if seq_length % self.num_image_patches != 0:
            n_last_frame_tokens = seq_length % self.num_image_patches
            attention_mask_length += n_last_frame_tokens
        else:
            print(f"attempting to generate new frame - frame no: {n_frames + 1}")
        attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(2)
            if input_ids.size(1) > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.size(1) - 1
            input_ids = input_ids[:, remove_prefix_length:]
            seq_length = input_ids.size(1)
            past_key_values_length = past_key_values[0][0].size(2)
            mask_seq_length = seq_length + past_key_values_length
            if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                mask_seq_length += self.num_action_embeddings
            attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "actions": kwargs.get("actions"),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
  • 位置编码详细信息 这里使用的位置编码是空间位置编码(Spatial Positional Encoding)和时间位置编码(Temporal Positional Encoding)的分解,空间位置编码指定同一时间步长内的位置,时间位置编码表达每个时间步长的位置。有 582 个标记具有相同的时间步长,因为有 576 个图像标记和 6 个动作标记。因此,空间位置编码代表同一时间步内的 582 个位置。另一方面,在这个模型中,我们一次最多可以处理 25 帧,因此时间位置编码代表 25 个位置。
import torch
import torch.nn as nn


class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length):
        seq_length = attention_mask.size(1)
        batch_size = attention_mask.size(0)

        if past_key_values_length == 0:
            # [0, 1, 2, ..., num_spatio_embeddings-1, 0, 1, 2, ..., num_spatio_embeddings-1, ...]
            spatio_indices = torch.arange(
                self.num_spatio_embeddings,
                device=attention_mask.device
            ).repeat(self.num_temporal_embeddings).unsqueeze(0).repeat((batch_size, 1))

            # [0, 0, 0, ..., 1, 1, 1, ..., 2, 2, 2, ...]
            temporal_indices = torch.arange(
                self.num_temporal_embeddings,
                device=attention_mask.device
            ).repeat_interleave(self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))

            spatio_indices = spatio_indices[:, :seq_length]
            temporal_indices = temporal_indices[:, :seq_length]
            
        else:
            temporal_index = past_key_values_length // self.num_spatio_embeddings
            spatio_index = past_key_values_length % self.num_spatio_embeddings
            spatio_indices = torch.tensor([[spatio_index]], device=attention_mask.device).repeat((batch_size, 1))
            temporal_indices = torch.tensor([[temporal_index]], device=attention_mask.device).repeat((batch_size, 1))

        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)

转换为 vLLM 模型的策略

在我们将在这里处理的模型中,在正常 LLM 中被视为语言标记序列的是表示图像的离散标记,而多模态输入是动作向量序列。

此外,由于动作向量序列的数量与帧的数量一样多,因此将会有多个多模态输入。这与使用多张图像作为输入的 VLM 推理 的情况类似,如 vLLM 文档中所述。

如上所述,在这种情况下,vLLM 用例的某些方面有些不常见,因此我们将在牢记这些要点的同时考虑如何使用 vLLM 实现它。

vLLM 的数据处理流程在输入处理管道 中介绍,请根据需要参考它以更容易理解。

准备输入标记序列

由于vLLM 模型[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]将输入一系列离散标记,例如

此外,必须提前在需要输入多模式信息的区域填充占位符 。例如,如果您知道要在上面的离散标记序列的开头插入 32 个图像嵌入,那么您将需要输入包含 32 个占位符特殊标记(例如 -1)的输入。所以上面例子中的离散标记序列[-1, -1, ..., -1, 2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]最终将以如下形式输入:

像 LLaVA 这样的著名model是如何准备占位符的? 提供占位符的推荐方式是input_processor实现:例如,LLaVA-1.5 模型input_processor_for_llava()实现了一项功能,该功能扩展了输入提示的特殊标记部分,并插入了等于 LLaVA-1.5 图像输入中使用的标记数量的占位符 ID。

但是这种方式使得数据流变得隐秘且难以理解,所以在这里介绍的实现中,我们在明确输入的令牌字符串中保留占位符,并且input_processor不实现。

如何添加多个多模式输入

由于我们的模型将有多个多模式输入,我们需要在输入标记序列中为每个传入的多模式输入创建带有特殊标记的占位符。

另外,在实现Hugging Face时,该forward方法每生成576个token,就为6个token添加一个动作向量序列,generate()只需调用一次该方法即可生成一个图像帧序列。然而在vLLM实现中,当你提供上下文图像帧序列的离散token序列和动作向量序列时,只会生成下一个图像帧的576个token,而generate()该方法的调用次数是你想要生成的图像帧数量的倍数。

如何输入动作向量序列

在 vLLM 中,除离散标记序列或文本之外的任何输入都被视为多模态数据,并且与离散标记序列/文本分开输入,如下所示:

from PIL import Image
from vllm import LLM, SamplingParams


inputs = {
    "prompt": "Describe the image. Picture 1: <img></img>\n",
    "multi_modal_data": {
        "image":  Image.open("/path/to/image").convert("RGB")
    }
}

llm = LLM(
    model="Qwen/Qwen-VL",
    trust_remote_code=True,
    max_model_len=1024,
    max_num_seqs=2
)
sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=None)
outputs = llm.generate(inputs, sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text
print(generated_text)

在上面的例子中,给出了一张图像作为多模态数据,image并且键表示该数据属于模态图像,因此以vllm.multimodal.image.ImagePlugin 状态输入的图像数据PIL.Image被转换为 PyTorch Tensor 并最终输入到模型中。

同样,如果您想输入自己的模态数据,请执行以下操作:

inputs = {
    "prompt": "...",
    "multi_modal_data": {
        "actions": torch.tensor([[0.0, 0.0, 0.0],
                                [0.0, 1.8, 0.5]])
    }
}

<modality>: <data>按照格式准备数据。这里,Plugin截至 2024 年 12 月,唯一可用于image处理video数据的选项是“处理数据”,因此您需要自己实现。Plugin的实现后面会讲到。

输出处理

对于一般的LLM推理来说,当文本输入到vLLM中时,会自动使用tokenizer将其转换成离散的token序列,再输入到Transformer中生成剩余的token序列,最后将生成的部分去token化并以句子的形式输出。

在该模型中,输入是使用图像标记器预先离散化的标记序列,输出在后期单独生成,因此必须关闭应用标记器的过程。此方法稍后也会描述。

vLLM 模型的实现

根据上面讨论的策略,使用 vLLM 实现了以下内容。

import sys
from array import array
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Mapping
from pathlib import Path

import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig, AutoConfig, AutoModel

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.inputs import InputContext, INPUT_REGISTRY
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_hip

from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

# HF path
sys.path.append(str(Path(__file__).parent.parent))
# HF import
from models.llama_action import LlamaActionConfig, LlamaActionForCausalLM

# HF实装的model在HF登录
AutoConfig.register("llama_action", LlamaActionConfig)
AutoModel.register(LlamaActionConfig, LlamaActionForCausalLM)


# 这是一个用于自定义实现的"action"模态的插件。实际上,它只是将输入的数据原样传输过去。
class ActionsPlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "actions"

    def _default_input_mapper(self, ctx: InputContext, data: object | List[object], **mm_processor_kwargs) -> MultiModalInputs:
        return MultiModalInputs({"actions": data})

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096


MULTIMODAL_REGISTRY.register_plugin(ActionsPlugin())


# 推理Positional Encoding方法
class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, positions: torch.Tensor):
        spatio_indices = positions % self.num_spatio_embeddings
        temporal_indices = positions // self.num_spatio_embeddings
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)


# Precheck
def get_max_action_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlamaActionConfig)
    num_action_tokens = hf_config.num_action_embeddings
    num_frames = hf_config.num_temporal_embeddings - 1
    return num_action_tokens * num_frames


# Precheck
def create_dummy_data(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlamaActionConfig)

    num_frames = hf_config.num_temporal_embeddings
    vocab_size = hf_config.vocab_size
    num_action_tokens = hf_config.num_action_embeddings
    num_image_tokens = hf_config.num_image_patches
    dummy_seq = []
    np.random.seed(0)
    for i in range(num_frames - 1):
        dummy_image_tokens = np.random.randint(0, vocab_size, num_image_tokens).tolist()
        dummy_seq.extend(dummy_image_tokens)
        dummy_action_tokens = [-3] * num_action_tokens
        dummy_seq.extend(dummy_action_tokens)
    seq_data = SequenceData(array("l", dummy_seq))

    action = torch.tensor([
        [0.0, 0.0, 0.0],
        [0.0, 2.0, 0.5],
        [0.0, 4.0, 1.0],
        [0.0, 6.0, 1.5],
        [0.0, 8.0, 2.0],
        [0.0, 10.0, 2.5],
        [0.0, 12.0, 3.0],
        [0.0, 14.0, 3.5],
        [0.0, 16.0, 4.0],
    ])
    actions = []
    for _ in range(num_frames - 1):
        actions.append(action[:num_action_tokens])
    actions = torch.cat(actions, dim=0)
    mm_data = {"actions": actions}
    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_input_mapper(data_type_key="actions")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("actions", get_max_action_tokens)
@INPUT_REGISTRY.register_dummy_data(create_dummy_data)
class VLLMLlamaActionForCausalLM(nn.Module, SupportsMultiModal):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(
        self,
        config: LlamaActionConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings

        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            num_spatio_embeddings=self.num_spatio_embeddings,
            num_temporal_embeddings=self.num_temporal_embeddings,
            embedding_dim=config.hidden_size,
        )

        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=None,
                                prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass for the model.
        input_ids already accounts for the positions of the to-be-inserted action embeddings.
    
        action tokens are represetnted by -3.
        example: [1287, 3342, ..., 6571, -3, ..., -3]
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]
            image_token_indices = (input_ids > 0).nonzero(as_tuple=True)[0]

            image_tokens = input_ids[image_token_indices]
            image_token_embeddings = self.model.get_input_embeddings(image_tokens)

            inputs_embeds = torch.zeros(
                (input_ids.size(0), image_token_embeddings.size(1)), 
                device=input_ids.device, dtype=image_token_embeddings.dtype
            )
            inputs_embeds[image_token_indices] = image_token_embeddings

            actions = kwargs.pop("actions", None)
            if actions is not None:
                assert len(action_token_indices) == actions.size(0) * actions.size(1), "actions must have the same length as the number of action tokens"
                actions = actions.to(dtype=self.action_projection.weight.dtype)
                action_embeddings = self.action_projection(actions)
                inputs_embeds[action_token_indices] = action_embeddings.view(-1, action_embeddings.size(-1))
            input_ids = None
            inputs_embeds += self.pos_embedding_spatio_temporal(positions)
        hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

下面解释一下要点。

注册 HF Model实现

由于 vLLM 在加载权重时引用了 Hugging Face 实现,因此需要在 Hugging Face Transformers 中注册自己的 Hugging Face 模型实现。

在下一节中,我们在 Hugging Face Transformers 中注册了我们自己的自定义模型实现。

import sys

# 将路径传递给HF实现。
sys.path.append(str(Path(__file__).parent.parent))
# 预先导入HF实现。
from models.llama_action import LlamaActionConfig, LlamaActionForCausalLM

# 将HF实现的模型注册到HF。
AutoConfig.register("llama_action", LlamaActionConfig)
AutoModel.register(LlamaActionConfig, LlamaActionForCausalLM)

为您自己的模态准备一个插件

如上所述,为了处理 vLLM 中实现的模式以外的其他模式,image您需要准备自己的实现。video``Plugin

以下部分实现了一种独特的模式"actions",面向。在Plugin这种multi_modal_data情况下,数据以可以直接输入模型的格式准备,因此实现无需任何特殊处理即可返回它。

from vllm.inputs import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin

# 这是一个用于自定义实现的"action"模态的插件。实际上,它只是将输入的数据原样传输过去。
class ActionsPlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "actions"

    def _default_input_mapper(self, ctx: InputContext, data: object | List[object], **mm_processor_kwargs) -> MultiModalInputs:
        return MultiModalInputs({"actions": data})

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096


MULTIMODAL_REGISTRY.register_plugin(ActionsPlugin())

位置编码实现的变更

在 Hugging Face 版本中,Positional Encoding 实现了一个可以同时处理训练和推理的分支,但 vLLM 版本只需要在推理期间考虑,因此消除了该分支。

此外,虽然 Hugging Face Transformers 没有明确使用 ID 来表示位置,但 vLLMpositions会自动生成一个指示每个 token 位置的张量,然后使用该张量创建位置嵌入。

import torch
import torch.nn as nn

class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, positions: torch.Tensor):
        spatio_indices = positions % self.num_spatio_embeddings
        temporal_indices = positions // self.num_spatio_embeddings
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)

实现虚拟输入函数

在初始化模型时,vLLM 具有一种机制,可以自动生成虚拟数据并将其传递给模型以检查其是否正常运行。对于LLM模型来说,这个虚拟数据只需要离散的token序列,但是对于多模态模型来说,还需要准备模拟输入多模态数据的数据。

操作检查需要实现的两个函数是返回用于多模态数据的最大令牌数的函数生成虚拟数据的函数

在这种情况下,一个动作向量序列对应 6 个 token,并且存在于图像的每一帧中,因此6×(25−1)=144。图像中的帧数−1− 1这是因为没有输入与最后一帧图像相对应的动作向量序列。

在下面的实现中,实现了一个函数来返回用于多模式数据的最大令牌数,从配置中读取数字而不是对其进行硬编码。

from vllm.inputs import InputContext

from models.llama_action import LlamaActionConfig

# Precheck
def get_max_action_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlamaActionConfig)
    num_action_tokens = hf_config.num_action_embeddings
    num_frames = hf_config.num_temporal_embeddings - 1
    return num_action_tokens * num_frames
    

此外,还实现了返回虚拟数据的函数,为离散标记序列和动作向量序列生成虚拟数据。如上所述,在 vLLM 中,需要为离散标记序列中将被多模态数据替换的部分准备占位符,因此在这种情况下,− 3 我已经实现它,以便包含的部分被视为占位符。

from typing import Mapping

import numpy as np
import torch
from vllm.inputs import InputContext
from vllm.sequence import SequenceData

from models.llama_action import LlamaActionConfig


def create_dummy_data(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlamaActionConfig)

    num_frames = hf_config.num_temporal_embeddings
    vocab_size = hf_config.vocab_size
    num_action_tokens = hf_config.num_action_embeddings
    num_image_tokens = hf_config.num_image_patches
    dummy_seq = []
    np.random.seed(0)
    # 生成离散令牌序列的虚拟输入
    for i in range(num_frames - 1):
        dummy_image_tokens = np.random.randint(0, vocab_size, num_image_tokens).tolist()  # 随机生成一个图像帧的令牌。

        dummy_seq.extend(dummy_image_tokens)
        dummy_action_tokens = [-3] * num_action_tokens  # 准备动作插入部分的位置 holder。
        dummy_seq.extend(dummy_action_tokens)
    seq_data = SequenceData(array("l", dummy_seq))

    # 动作向量序列的模板。
    action = torch.tensor([
        [0.0, 0.0, 0.0],
        [0.0, 2.0, 0.5],
        [0.0, 4.0, 1.0],
        [0.0, 6.0, 1.5],
        [0.0, 8.0, 2.0],
        [0.0, 10.0, 2.5],
        [0.0, 12.0, 3.0],
        [0.0, 14.0, 3.5],
        [0.0, 16.0, 4.0],
    ])
    # 关于动作向量序列,生成虚拟输入
    actions = []
    for _ in range(num_frames - 1):
        actions.append(action[:num_action_tokens])
    actions = torch.cat(actions, dim=0)
    mm_data = {"actions": actions}
    return seq_data, mm_data

注册自己的插件和虚拟数据相关函数

为了使得输入到vLLM版本LlamaActionForCausalLM实现的VLLMLlamaActionForCausalLM数据使用上面实现的,并且在生成虚拟数据时使用与虚拟数据生成相关的函数,我们注册函数并使用装饰器,如下所示。ActionsPlugin``Plugin

from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_input_mapper(data_type_key="actions")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("actions", get_max_action_tokens)
@INPUT_REGISTRY.register_dummy_data(create_dummy_data)
class VLLMLlamaActionForCausalLM(nn.Module, SupportsMultiModal):
    ... 
    

用多模式数据替换占位符

input_ids在输入中−3这里分配的特殊标记是占位符,因此在输入到 Transformer 之前需要用动作向量替换它们。在 vLLM 的这个实现中,这forward()是在方法中完成的。

如下图−3获取包含部分的索引,并action_projection用通过动作向量序列出来的向量序列替换它。其余部分则用经过该层得到的矢量序列来替换input_idsEmbedding这是在以下地方实现的:

def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """模型的正向传播。
    input_ids已经考虑了将要插入的动作嵌入的位置。
    动作令牌由-3表示。
    示例:[1287, 3342, ..., 6571, -3, ..., -3]
    """
    if intermediate_tensors is not None:
        # 这里是第二次及以后调用forward()时 = 使用缓存时
        input_ids = None
        inputs_embeds = None
    else:
        # 不使用缓存时走这个分支
        # 获取插入占位符的索引
        action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]
        # 获取非占位符令牌的索引
        image_token_indices = (input_ids > 0).nonzero(as_tuple=True)[0]
        image_tokens = input_ids[image_token_indices]
        # 将非占位符令牌转换为Embedding
        image_token_embeddings = self.model.get_input_embeddings(image_tokens)
        # 创建Transformer的输入数据数组
        inputs_embeds = torch.zeros(
            (input_ids.size(0), image_token_embeddings.size(1)), 
            device=input_ids.device, dtype=image_token_embeddings.dtype
        )
        # 将非占位符部分替换为图像令牌的Embedding
        inputs_embeds[image_token_indices] = image_token_embeddings
        actions = kwargs.pop("actions", None)
        if actions is not None:
            assert len(action_token_indices) == actions.size(0) * actions.size(1), "动作的数量必须与动作令牌的数量相同"
            actions = actions.to(dtype=self.action_projection.weight.dtype)
            action_embeddings = self.action_projection(actions)
            # 将占位符部分替换为动作向量的Embedding
            inputs_embeds[action_token_indices] = action_embeddings.view(-1, action_embeddings.size(-1))
        input_ids = None
        

配置学习到的权重名称映射

在 vLLM 中,线性层和注意层有时使用 vLLM 特定的MergedColumnPrallelLinearQKVParallelLinear(例如和)来实现,并且赋予权重的名称在 HF 和 vLLM 实现之间可能有所不同。此时,为了在vLLM实现中加载HF版本中训练的权重,需要定义一个名称映射,并对名称进行更改和加载。

此名称映射vllm.model_executor.models.llama.LlamaForCausalLM 按原样使用基本名称映射。LlamaForCausalLM如下图,wq替换为q_projoutput替换为lm_head

class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    (中略)
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
        

使用 vLLM 模型进行推理

到目前为止,VLLMLlamaActionForCausalLM我们已经实现了推理模型。使用它进行推理的方法与Hugging Face模型略有不同,因此我们将在这里进行解释。

初始化模型

首先,按如下方式加载模型:


import torch
from vllm import LLM, ModelRegistry

# import modeling_llama_action.py
from modeling_llama_action import VLLMLlamaActionForCausalLM


device = torch.device("cuda:0")
path = "/path/to/pretrained_weight"
model = LLM(
    model=path,
    skip_tokenizer_init=True,
    enforce_eager=True,
    max_num_seqs=5,
    device=device,
)

首先,skip_tokenizer_init关于。通常,LLM 带有一个将文本转换为标记的标记器 (Tokenizer)。 vLLM 具有使用包含的 Tokenizer 自动将文本转换为标记的功能。skip_tokenizer_init=False如果将其设置为(默认设置),则包含的 Tokenizer 将从配置文件中加载。另一方面,该模型使用图像标记器 (Image Tokenizer),它将图像而不是文本转换为离散的标记字符串,因此skip_tokenizer_init=True如果不关闭将标记器加载为的功能,将会发生错误。

此外,enforce_eager关于,这是一个必要的选项,以限制 vLLM 的功能以进行离线推理。 vLLM允许使用接收流输入的CUDA Graph 来为 LLM 提供服务。但是,使用流式输入时,某些 PyTorch 操作无法执行。在这个例子中,获取使用占位符标记的部分的索引的计算不支持流式输入,因此enforce_eager=True如果不这样做,则在第一次操作检查时会发生错误。

action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]

vLLM使用一种称为自动前缀缓存的 技术来加快推理速度。这是多轮对话中的一种有效技巧,当新文本添加到先前输入的句子并输入到 LLM 中时,会使用缓存来缓存先前输入的句子部分。

不幸的是,截至 2024 年 12 月,APC 函数不再支持多模态模型,所以我们这次不会使用它,但对于一般的 LLM,enable_prefix_caching=True可以在初始化模型时通过添加选项来使用它。

模型推理

在这个实现中,一旦生成了一个图像帧的标记,就将生成的标记连接到输入提示的末尾,并添加动作向量的占位符,然后再次用作模型的输入,以生成图像帧序列的离散表示。代码展示如下:

import torch
from vllm import SamplingParams


n_context_frames = 3
n_frames = 25
n_frames_to_generate = n_frames - n_context_frames
# 事先使用ImageTokenizer对图像进行令牌化,并在其中插入动作向量对应的占位符。
prompt_tokens = [23126, 12318, ..., 8997, -3, -3, -3, -3, -3, -3]
actions = torch.tensor([[0.2, 2.4, 0.5],
                        [0.4, 5.2, 1.0],
                        ...
                        [0.0, 20.2, 2.5]])  # (n_frames * 6, 3)
inputs = [
    {
        "prompt_token_ids": prompt_tokens,
        "multi_modal_data": {"actions": actions[:6 * n_context_frames]}
    }
]
sampling_params = SamplingParams(temperature=1.0, detokenize=False, max_tokens=576, stop_token_ids=None)
all_outputs = []
for step in range(n_frames_to_generate):
    outputs = model.generate(inputs, sampling_params=sampling_params)[0].outputs[0].token_ids
    all_outputs.append(outputs)
    # 创建一个新的提示,其中包含前一步生成的内容和占位符。
    prompt = torch.cat([prompt, torch.tensor(outputs), torch.ones(6, dtype=torch.long) * -3])
    inputs = [
        {
            "prompt_token_ids": prompt.tolist(),
            "multi_modal_data": {"actions": actions[:6 * (n_context_frames + step + 1)]}
        }
    ]

在这种情况下,detokenize=False我们设置它是因为不需要将标记字符串转换为文本。

测速

这样就完成了vLLM版本的实现。最后我们来测量一下速度。将 HF 实现与 vLLM 实现进行了比较。

速度测试结果如下表所示。

vLLM 版本实现HF版本实现
生成时间 2.2 秒93.53 秒242.68 秒
每帧时间4.25 秒11.03 秒

HF 实现需要 11 秒才能生成一帧,而 vLLM 实现可以在 4.25 秒内生成一帧,速度大约快 2.6 倍。

结论

在本文中,我们以视频生成模型为例,介绍了如何使用 vLLM 推断自定义实现的多模式模型。我希望本文对那些想要使用 vLLM 在自己的模型上练习推理的人有所帮助。

© 2011 - 2025 Macro Zhao的分享站

关于我

如遇到加载502错误,请尝试刷新😄

Hi,欢迎访问 Macro Zhao 的博客。Macro Zhao(或 Macro)是我在互联网上经常使用的名字。

我是一个热衷于技术探索和分享的IT工程师,在这里我会记录分享一些关于技术、工作和生活上的事情。

我的CSDN博客:
https://macro-zhao.blog.csdn.net/

欢迎你通过评论或者邮件与我交流。
Mail Me

推荐好玩(You'll Like)
  • AI 动·画
    • 这是一款有趣·免费的能让您画的画中的角色动起来的AI工具。
    • 支持几十种动作生成。
我的项目(My Projects)
  • 爱学习网

  • 小乙日语App

    • 这是一个帮助日语学习者学习日语的App。
      (当然初衷也是为了自用😄)
    • 界面干净,简洁,漂亮!
    • 其中包含 N1 + N2 的全部单词和语法。
    • 不需注册,更不需要订阅!完全免费!
  • 小乙日文阅读器

    • 词汇不够?照样能读日语名著!
    • 越读积累越多,积跬步致千里!
    • 哪里不会点哪里!妈妈再也不担心我读不了原版读物了!
赞助我(Sponsor Me)

如果你喜欢我的作品或者发现它们对你有所帮助,可以考虑给我买一杯咖啡 ☕️。这将激励我在未来创作和分享更多的项目和技术。🦾

👉 请我喝一杯咖啡

If you like my works or find them helpful, please consider buying me a cup of coffee ☕️. It inspires me to create and share more projects in the future. 🦾

👉 Buy me a coffee