Mamba Note
1 Installing Mamba
IMPORTANT
Reference blogs (CSDN):
1. "Linux 下 Mamba 环境安装踩坑问题汇总(重置版)" (a summary of pitfalls when installing the Mamba environment on Linux)
2. "mambassm和causal-conv1d安装教程" (installing mamba-ssm and causal-conv1d for different torch versions)
3. "Anaconda虚拟环境中安装cudatoolkit和cudnn包并配置tensorflow-gpu" (installing cudatoolkit and cudnn in an Anaconda virtual environment)
Installation steps:
```bash
# 1. Create a conda virtual environment
conda create -n your_env_name python=3.10.13
conda activate your_env_name
# 2. Install cudatoolkit
# Search the available versions first
conda search cudatoolkit -c nvidia
# Install the matching version
conda install cudatoolkit==11.8 -c nvidia
# 3. Install torch, torchvision, torchaudio
# Via proxy (outside China)
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
# From within China: SJTU mirror
pip install torch==2.1.1+cu118 torchvision==0.16.1+cu118 torchaudio==2.1.1+cu118 -f https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html
# From within China: Aliyun mirror
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118
# 4. Install nvcc
# Search the available versions first
conda search -c nvidia cuda-nvcc
# Install the matching version
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
# 5. Install the Mamba packages
pip install causal-conv1d==1.1.1   # pick the version that fits your setup, or omit it to install the latest
pip install mamba-ssm==1.1.3.post1 # pick the version that fits your setup; 1.1 and 1.2 were found to have incompatible functions, omitting the version installs the latest
```

Alternatively, causal_conv1d and mamba_ssm can be downloaded and installed manually:
(1) causal_conv1d: the releases page of Dao-AILab/causal-conv1d (Causal depthwise conv1d in CUDA, with a PyTorch interface)
(2) mamba_ssm: the releases page of state-spaces/mamba (Mamba SSM architecture)
Download the wheel that matches your environment. For example,
causal_conv1d-1.5.0.post8+cu11torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl means CUDA 11.x, torch 2.3, and cp39 means Python 3.9.x; pick the wheel whose name contains cxx11abiFALSE. After downloading, install it with
pip install causal_conv1d-1.5.0xxx.whl
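The compatibility tags in those release wheel names can also be checked mechanically. Below is a minimal sketch (the `parse_mamba_wheel` helper is ours, not part of either package) that pulls the CUDA, torch, ABI, and Python tags out of a filename:

```python
import re

def parse_mamba_wheel(filename):
    """Split a causal_conv1d / mamba_ssm release wheel name into its tags.

    Illustrative helper; the releases follow the naming scheme
    <pkg>-<ver>+cu<CUDA>torch<TORCH>cxx11abi<TRUE|FALSE>-cp<PY>-cp<PY>-<platform>.whl
    """
    m = re.match(
        r"(?P<pkg>\w+)-(?P<ver>[\d.]+(?:\.post\d+)?)"
        r"\+cu(?P<cuda>\d+)torch(?P<torch>[\d.]+)cxx11abi(?P<abi>TRUE|FALSE)"
        r"-cp(?P<py>\d+)-cp\d+-(?P<platform>.+)\.whl$",
        filename,
    )
    return m.groupdict() if m else None

tags = parse_mamba_wheel(
    "causal_conv1d-1.5.0.post8+cu11torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl"
)
print(tags["cuda"], tags["torch"], tags["py"], tags["abi"])
# prints: 11 2.3 39 FALSE  -> CUDA 11.x, torch 2.3, Python 3.9, cxx11abiFALSE
```

Comparing these tags against your `torch.__version__` and Python version before installing saves one failed `pip install` round trip.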
2 Mamba Code Tests
2.1 Mamba test
```python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 32, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
print('params', sum(p.numel() for p in model.parameters()))
print(y.shape)
assert y.shape == x.shape
'''
params 7968
torch.Size([2, 32, 16])
'''
```
TIP
d_model must match dim.
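The printed 7968 can be reproduced by hand from Mamba-1's layer shapes. The sketch below assumes the defaults of the mamba_ssm implementation (no bias on in_proj, x_proj, or out_proj; a bias on conv1d and dt_proj; dt_rank = ceil(d_model/16)), and the helper name is ours:

```python
import math

def mamba1_params(d_model, d_state=64, d_conv=4, expand=2):
    # Bookkeeping of Mamba-1's learnable tensors under mamba_ssm defaults
    # (conv1d and dt_proj carry a bias, the linear projections do not).
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    return (
        d_model * 2 * d_inner                 # in_proj: d_model -> 2*d_inner
        + d_inner * d_conv + d_inner          # depthwise conv1d weight + bias
        + d_inner * (dt_rank + 2 * d_state)   # x_proj: produces dt, B, C
        + dt_rank * d_inner + d_inner         # dt_proj weight + bias
        + d_inner * d_state                   # A_log
        + d_inner                             # D
        + d_inner * d_model                   # out_proj
    )

print(mamba1_params(16, d_state=64))  # 7968, matching the output above
```

This also shows why the "roughly 3 * expand * d_model^2" rule of thumb is far off here: with d_model=16 the d_state-sized tensors (x_proj, A_log) dominate the count.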
2.2 Mamba2 test
```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
print('params', sum(p.numel() for p in model.parameters()))
print(y.shape)
assert y.shape == x.shape
```
Output:
params 6468320
torch.Size([2, 64, 1024])
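The 6468320 figure can likewise be reconstructed from Mamba-2's shapes. This sketch assumes the defaults of the mamba_ssm implementation (headdim=64, ngroups=1, bias-free projections, a gated RMSNorm of size d_inner, and one A/dt/D scalar per head); the helper name is ours:

```python
def mamba2_params(d_model, d_state=64, d_conv=4, expand=2, headdim=64, ngroups=1):
    # Bookkeeping of Mamba-2's learnable tensors under mamba_ssm defaults.
    d_inner = expand * d_model
    nheads = d_inner // headdim
    d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads  # z, x, B, C, dt
    conv_dim = d_inner + 2 * ngroups * d_state                # x, B, C share the conv
    return (
        d_model * d_in_proj             # in_proj
        + conv_dim * d_conv + conv_dim  # depthwise conv1d weight + bias
        + nheads                        # dt_bias
        + nheads                        # A_log (one scalar per head)
        + nheads                        # D
        + d_inner                       # gated RMSNorm weight
        + d_inner * d_model             # out_proj
    )

print(mamba2_params(1024, d_state=64))  # 6468320, matching the output above
```

Note how much leaner this is per head than Mamba-1: A, dt, and D shrink to one scalar per head, so almost all parameters sit in in_proj and out_proj.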