Note: this article was published some time ago; the content may be outdated, so read with discretion.

Mamba Note

1 Installing Mamba

Installation steps:

bash
# 1. Create a conda virtual environment
conda create -n your_env_name python=3.10.13
conda activate your_env_name

# 2. Install cudatoolkit
# Search for available versions first
conda search cudatoolkit -c nvidia
# Install the matching version
conda install cudatoolkit==11.8 -c nvidia

# 3. Install torch, torchvision, torchaudio
# Official PyTorch index (may require a proxy in mainland China)
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
# Mainland China mirror - SJTU
pip install torch==2.1.1+cu118 torchvision==0.16.1+cu118 torchaudio==2.1.1+cu118 -f https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html
# Mainland China mirror - Aliyun
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118

# 4. Install nvcc
# Search for available versions first
conda search -c nvidia cuda-nvcc
# Install the matching version
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging

# 5. Install the mamba packages
pip install causal-conv1d==1.1.1  # pick a version to suit your setup, or omit the pin to install the latest
pip install mamba-ssm==1.1.3.post1  # pick a version to suit your setup; 1.1 and 1.2 were found to have incompatible functions in testing, and leaving it unpinned installs the latest
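The key constraint running through the steps above is that cudatoolkit, cuda-nvcc, and the PyTorch wheel index all target the same CUDA release (11.8 here). As a minimal sketch (the version strings are simply the ones pinned above, and the normalizer is a hypothetical helper, not part of any of these tools), you can cross-check them before installing:

```python
# Sketch: cross-check that the CUDA versions pinned in the steps above agree.
# The tags below are taken directly from the install commands in this guide.
pins = {
    "cudatoolkit": "11.8",
    "cuda-nvcc": "11.8.0",
    "torch wheel index": "cu118",  # from --index-url .../whl/cu118
}

def cuda_major_minor(tag: str) -> str:
    """Normalize '11.8', '11.8.0', and 'cu118' to a 'major.minor' string."""
    if tag.startswith("cu"):
        digits = tag[2:]                       # 'cu118' -> '118'
        return f"{digits[:-1]}.{digits[-1]}"   # -> '11.8'
    parts = tag.split(".")
    return f"{parts[0]}.{parts[1]}"

versions = {name: cuda_major_minor(tag) for name, tag in pins.items()}
print(versions)  # all three should normalize to '11.8'
assert len(set(versions.values())) == 1, "CUDA versions disagree!"
```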

1. Both causal_conv1d and mamba_ssm can also be downloaded manually:

(1) causal_conv1d download link:

the releases page of Dao-AILab/causal-conv1d (Causal depthwise conv1d in CUDA, with a PyTorch interface)

(2) mamba_ssm download link:

the releases page of state-spaces/mamba (Mamba SSM architecture)

Download the package matching your versions:

For example, causal_conv1d-1.5.0.post8+cu11torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl indicates a wheel built for CUDA 11.x and torch 2.3; cp39 means Python 3.9.x. Be sure to download the build whose name contains FALSE (i.e. cxx11abiFALSE).

Once downloaded, install it with pip install causal_conv1d-1.5.0xxx.whl
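The wheel-naming convention above can be decoded programmatically, which is handy when matching a release asset against your own CUDA / torch / Python versions. A sketch (the regex is my own, written only for filenames shaped like the example above):

```python
import re

# Decode the release wheel filename described above into its version tags.
name = "causal_conv1d-1.5.0.post8+cu11torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl"

m = re.match(
    r"(?P<pkg>\w+)-(?P<version>[\d\.\w]+)"
    r"\+cu(?P<cuda>[\d\.]+)torch(?P<torch>[\d\.]+)cxx11abi(?P<abi>TRUE|FALSE)"
    r"-cp(?P<py>\d+)-cp\d+-(?P<platform>\w+)\.whl",
    name,
)
info = m.groupdict()
print(info)
# info["cuda"]  == "11"    -> built for CUDA 11.x
# info["torch"] == "2.3"   -> built for torch 2.3
# info["py"]    == "39"    -> CPython 3.9.x
# info["abi"]   == "FALSE" -> the cxx11abiFALSE build, as recommended above
```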

2 Mamba Code Tests

2.1 Mamba test

python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 32, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
print('Parameter count', sum(p.numel() for p in model.parameters()))
print(y.shape)
assert y.shape == x.shape
'''
Parameter count 7968
torch.Size([2, 32, 16])
'''

TIP

d_model must match dim, the size of the input's last dimension.
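The in-code comment says the module uses roughly 3 * expand * d_model^2 parameters, yet the printout reads 7968 because at small d_model the state-space terms dominate. Assuming mamba-ssm's standard Mamba(1) block layout (in_proj, depthwise conv1d, x_proj, dt_proj, A_log, D, out_proj, with dt_rank = ceil(d_model / 16)), the count can be reproduced by hand:

```python
import math

# Sketch: reproduce the 7968 parameter count of the Mamba(1) test above,
# assuming the standard block layout used by mamba-ssm.
d_model, d_state, d_conv, expand = 16, 64, 4, 2
d_inner = expand * d_model            # 32
dt_rank = math.ceil(d_model / 16)     # 1

params = (
    d_model * 2 * d_inner              # in_proj (no bias):           1024
    + d_inner * d_conv + d_inner       # depthwise conv1d w + bias:    160
    + d_inner * (dt_rank + 2 * d_state)  # x_proj (no bias):          4128
    + dt_rank * d_inner + d_inner      # dt_proj weight + bias:         64
    + d_inner * d_state                # A_log:                       2048
    + d_inner                          # D:                             32
    + d_inner * d_model                # out_proj (no bias):           512
)
print(params)  # 7968, matching the printout above
```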


2.2 Mamba2 test

python

import torch
from mamba_ssm import Mamba2
batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
print('Parameter count', sum(p.numel() for p in model.parameters()))
print(y.shape)
assert y.shape == x.shape

Output:

Parameter count 6468320
torch.Size([2, 64, 1024])
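At this larger d_model, the "roughly 3 * expand * d_model^2 parameters" comment from the code is a good approximation. A quick check using only the numbers already shown above:

```python
# Sketch: compare the rough estimate 3 * expand * d_model^2 from the code
# comment against the Mamba2 count printed above.
d_model, expand = 1024, 2
estimate = 3 * expand * d_model ** 2   # 6291456
actual = 6468320                       # from the test output above
rel_err = abs(actual - estimate) / actual
print(estimate, f"{rel_err:.1%}")      # the estimate lands within ~3% of the actual count
```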