Review of VQ-BeT
引言
VQ-BeT(Vector-Quantized Behavior Transformer) 是一种结合向量量化(Vector Quantization)和Transformer架构的行为建模方法,专为多模态连续动作的学习与预测而设计。它通过 残差向量量化(Residual VQ) 将连续动作表示离散化,生成紧凑的潜在表示,解决了传统 k-means 方法无法捕捉复杂分布的问题。
VQ-BeT 通过以下核心组件实现高效行为建模:
- 离散化潜在表示:使用多层残差量化逐步逼近输入动作分布,提升模型对细粒度行为的捕捉能力。
- 条件与非条件任务支持:能够处理基于目标条件的任务(如机器人操作)和无条件任务(如多模态行为生成)。
- 损失设计:结合重建损失和向量量化损失,确保离散表示的准确性和复现能力。
在设计优秀的行为与动作建模模型时,我们希望模型能够满足以下特性:
- 建模长期和短期依赖关系:能够同时捕捉行为中的长期模式(如任务级别的目标)和短期细节(如瞬时动作)。
- 从多样化的行为模式中进行捕捉与生成:支持复杂且多模态的行为特性建模。
- 精确地复现学习到的行为:能够在生成阶段保持与真实演示行为一致的精确度。
参考文献(Shafiullah et al., 2022 ; Chi et al., 2023 )指出,利用向量量化(Vector Quantization, VQ)学习离散表示能够很好地满足上述需求。
背景与基础
2.1 行为克隆(Behavior Cloning)
行为克隆是一种基于监督学习的模仿学习方法,通过学习从状态到动作的映射来模仿人类或专家
行为。其关键在于收集高质量的演示数据,并设计适当的模型来捕捉行为模式。然而,在处理复杂多样的行为时,行为克隆的表现可能受到连续动作空间维度和分布的限制。
2.2Behavior Transformers
Behavior Transformers(BeT) 是一种基于 Transformer 的模型,用于行为建模和预测,其设计目标是通过捕捉时序依赖关系来处理长时间跨度的行为。BeT 的核心流程如下:
动作编码与解码:
- 输入的动作序列通过基于 k-means 的编码器和解码器进行离散化。
- 输出可以是离散或连续成分。
局限性:
- k-means 无法捕捉复杂分布中的细微结构特性,导致离散化表示的不足,进而影响模型性能。
2.3 残差向量量化(Residual Vector Quantization, RVQ)
向量量化(Vector Quantization, VQ)
向量量化是将连续动作数据离散化的重要技术,其核心是引入一个 代码本(codebook):
- 代码本表示:
,其中每个 表示一个离散嵌入向量。 - 量化规则: 对于输入向量
,通过最近邻搜索找到最接近的代码本向量 ,并映射为离散表示:
残差向量量化(Residual VQ, RVQ)
RVQ 是对普通 VQ 的扩展,通过多级递归量化逐步减小误差,生成更高精度的离散表示:
- 第一级量化:
- 对残差
进行进一步量化: - 最终的量化表示为:
这种多层次的残差量化能够有效捕捉动作分布中的细节特征。
Vector-Quantized Behavior Transformers, VQ-BeT
3.1 行为数据的序列预测
3.1.1. 动作离散化和序列预测
- 动作离散化的背景:
- 将动作分箱(binning)以生成离散化的动作类别标记(tokenized class),并基于这些类别进行预测已被广泛应用于多模态行为学习任务。
- 方法的局限性:
- k-means 分箱方法存在问题:
- 容易丢失行为序列中的细微特征。
- 无法有效处理高维或复杂的动作分布。
- k-means 分箱方法存在问题:
3.1.2. 提出的方法
为了解决上述问题,提出了以下方法:
- 学习离散潜在嵌入空间(latent embedding space):
- 通过构建离散化潜在空间,对动作或动作块(action chunk)进行建模。
- 与简单的分箱方法不同,这种嵌入空间能够捕捉复杂行为模式,并保留序列中的细微特征。
- 利用离散化方法(例如 VQ-VAE)建模动作潜在空间:
- 这类离散化模型不仅可以应用于动作建模,还被广泛用于生成式建模任务中(如图像、音乐、视频等领域)。
- 文中提到的一些相关研究:
- Bar-Tal 等人(2024):在图像建模中使用离散潜在表示。
- Ziv 等人(2024);Podell 等人(2023):在音乐和视频生成中验证了离散潜在空间的有效性。
3.1.3. 方法的优势
- 直接预测离散化的动作类别:
- 使用离散化潜在空间后,模型能够根据观测序列直接预测动作类别,无需对连续动作进行复杂建模。
- 与目标条件(如目标向量)相关联的观察序列可以更高效地生成行为标记。
- 提升序列建模能力:
- 离散化后的嵌入表示更加紧凑,可以有效减小动作空间的维度。
- 同时保留行为序列中的关键特征,提升模型的预测精度和泛化性能。
3.2 动作块离散化(Action Chunk Discretization via Residual VQ)
Residual Vector Quantization (Residual VQ) 被用于设计一个可扩展的动作离散化器,以解决真实世界中复杂的动作空间问题。如上图所示,具体方法如下:
- 输入数据的量化过程:
- 对动作块(action chunk)
(其中 )进行编码,生成潜在嵌入向量 。 被映射到第一级代码本(codebook)的向量 ,通过最近邻搜索选取: - 计算残差
,并将其递归地传递给后续 层代码本,依次量化: - 最终量化的离散向量表示为:
- 对动作块(action chunk)
- 重建与解码:
- 使用解码器
,通过离散向量 重建原始动作块:
- 使用解码器
- 损失函数的设计:
- 重建损失(Reconstruction Loss):
- 向量量化损失(VQ Loss) :
其中: :停止梯度操作(stop-gradient),避免对代码本直接更新。 :承诺损失的权重。
- 重建损失(Reconstruction Loss):
- 代码本的作用:
- 主要代码(Primary Codes):第一层代码本捕捉动作的粗粒度分布。
- 次要代码(Secondary Codes):后续层捕捉动作中的细粒度信息。
该方法使用
3.3. 重建损失和向量量化损失
损失函数中的重建损失(Reconstruction Loss)和向量量化损失(Vector Quantization Loss, VQ Loss)被设计为优化离散化模型(如 VQ-VAE 或 VQ-BeT)性能的关键组成部分。两者的使用分别针对模型中的不同目标,以下是具体原因和作用:
3.3.1. 重建损失(Reconstruction Loss)
目的:
- 保证模型能够通过离散化后的表示,尽可能准确地重建原始输入(如动作序列)。
作用:
确保离散化表示的有效性:
- 动作序列被离散化到潜在空间后,必须能够通过解码器还原成接近原始动作的序列。如果重建误差较大,说明离散表示未能有效捕捉动作的特征。
指导编码器与解码器协同优化:
- 重建损失直接衡量了输入与输出的差距,能够引导编码器生成对解码器友好的潜在表示,同时优化解码器的重建能力。
提升模型的行为复现能力:
- 对于行为建模任务,精确的行为复现是核心目标。重建损失确保了模型能够在离散化表示的基础上忠实地再现复杂的行为模式。
公式:
重建损失常用
:原始动作序列。 :编码器。 :解码器。 :量化后的离散表示。
3.3.2. 向量量化损失(VQ Loss)
目的:
- 保证离散化过程中,编码器生成的潜在表示能够有效地映射到代码本中的离散向量。
作用:
- 约束潜在表示与代码本的匹配:
- 编码器输出的连续潜在表示必须尽可能接近代码本中的某个向量。通过向量量化损失,将编码器生成的嵌入表示拉近到最优代码本向量,确保离散化过程的有效性。
- 防止代码本崩塌:
- 如果编码器输出无法充分利用代码本中的离散向量,可能导致代码本某些部分从未被使用,降低模型的表示能力。向量量化损失通过惩罚与代码本的距离,促使所有代码本向量被更均匀地利用。
- 信息紧凑化与减少冗余:
- 向量量化过程将连续动作映射到有限数量的离散向量,这种紧凑化表示能够降低输入数据的维度,同时保留行为的关键特征。
公式:
向量量化损失包括两部分:
- 编码器输出与离散向量的距离:
:停止梯度操作(stop-gradient),防止代码本向量在训练中被错误更新。 :离散化后选中的代码本向量。
- 编码器输出与代码本的中心偏移(承诺损失):
- 防止编码器输出偏离代码本的中心,促使编码器生成更紧凑的潜在表示。
最终的向量量化损失为:
3.3.3. 综合两种损失的必要性
互补性:
重建损失:关注模型最终输出的质量,确保通过离散化后的潜在表示能够精确还原原始动作序列。
向量量化损失:关注潜在表示与代码本的匹配,确保离散化过程的有效性和紧凑性。
平衡短期与长期目标:
- 重建损失着眼于当前输入的准确性,而向量量化损失确保模型的长期表现(如代码本的泛化能力和潜在空间的表示能力)。
共同优化离散化模型的性能:
- 两者的结合,使得模型不仅能够从离散化表示中学习有效的行为模式,还能以高效的方式在离散化潜在空间中捕捉复杂行为分布。
3.4 代码预测的加权更新(Weighted Update for Code Prediction)
在完成 Residual VQ 的训练后,使用类似 GPT 的 Transformer 架构来建模观察序列与动作块之间的概率分布。具体流程如下:
- 代码预测的目标:
- 通过观察序列
来预测离散动作代码 。 - 使用 Residual VQ 生成的代码本索引
作为标签,训练代码预测头 。
- 通过观察序列
- 损失函数的设计:
- 使用 Focal Loss(聚焦损失)来应对类别分布不均问题:
其中: :用于调整主要代码(primary code)和次要代码(secondary code)损失之间的权重。 :预测的离散代码索引。
- 使用 Focal Loss(聚焦损失)来应对类别分布不均问题:
- 重建量化行为:
- 离散化后的行为通过解码器
重建: 其中: :第 层代码本的第 个嵌入向量。 :指示是否选择了第 个向量。
- 离散化后的行为通过解码器
- 偏移修正(Offset Adjustment):
- 引入偏移修正项
,调整离散化动作的中心以提高重建精度:
- 引入偏移修正项
- 总损失函数:
- 最终损失结合代码预测损失和偏移修正损失:
- 最终损失结合代码预测损失和偏移修正损失:
3.4 条件与非条件任务形式化(Conditional and Non-Conditional Task Formulation)
这一部分介绍了 VQ-BeT 在多任务学习中的应用,通过条件任务(Conditional Tasks)和非条件任务(Non-Conditional Tasks)两种形式化方法来预测多模态的连续动作序列。
3.4.1. 背景与目标
目标:
- 构建一个通用的行为学习模型,能够在不同任务环境中预测多模态连续动作。
- VQ-BeT 支持以下两种任务形式化:
- 非条件任务(Non-Conditional Tasks):根据观察序列预测动作序列的分布。
- 条件任务(Conditional Tasks):结合目标条件(如目标状态或期望的观察)预测动作序列。
适用场景:
- 评估了 VQ-BeT 在多种模拟和真实环境中的表现,如:
- 模拟任务:PushT(Chi 等人,2023)、Multimodal Ant(Brockman 等人,2016)、BlockPush(Florence 等人,2022)、UR3 BlockPush(Kim 等人,2022)、Franka Kitchen(Gupta 等人,2019)。
- 真实任务:nuScenes 自动驾驶(Caesar 等人,2020)和 Play Kitchen 机器人环境。
3.4.2. 非条件任务形式化(Non-Conditional Formulation)
定义:
- 给定一个数据集
,目标是预测可能的动作序列 的分布,条件是给定的观察序列 。 - 行为策略的数学形式为:
其中: :观察空间。 :动作空间。 :观察序列的长度。 :动作序列的长度。
核心任务:
- 预测基于一段历史观察数据
的未来动作序列的分布。 - 非条件任务不需要额外的目标信息,仅依赖观察序列进行推断。
3.4.3. 条件任务形式化(Conditional Formulation)
定义:
- 针对目标条件任务,扩展非条件任务形式化,额外引入目标条件向量(goal conditioning vector),该条件可以表示为一个或多个额外的观察。
- 给定当前观察序列
和未来观察序列 ,预测条件动作序列分布。 - 行为策略的数学形式为:
其中: :当前观察序列。 :未来观察序列(目标条件)。 :动作序列分布。
核心任务:
- 结合当前状态和目标状态,通过条件化的行为策略生成动作序列。
- 条件任务形式化适用于目标明确的场景(如机器人抓取任务需要达到某个具体位置)。
3.4.4. 示例任务环境
以下是 VQ-BeT 所评估的多种任务环境:
- PushT(Chi 等人,2023):一个推箱子的任务,涉及多种复杂的行为模式。
- Multimodal Ant(Brockman 等人,2016):需要控制多关节机器人完成特定行为。
- BlockPush 和 UR3 BlockPush:机器人抓取与推动任务,动作要求精确。
- Franka Kitchen(Gupta 等人,2019):模拟厨房环境中的复杂操作。
- nuScenes 自动驾驶(Caesar 等人,2020):预测车辆的驾驶动作。
- Play Kitchen:真实世界中厨房操作的任务。
Code [Future Work]
- Title: Review of VQ-BeT
- Author: xiangyu fu
- Created at : 2024-11-26 16:36:47
- Updated at : 2024-11-26 18:16:42
- Link: https://redefine.ohevan.com/2024/11/26/Reviews/vqbet/
- License: This work is licensed under CC BY-NC-SA 4.0.