23 Rust与大模型:用 Candle 做一个聊天机器人
你好,我是Mike。今天我们来聊一聊如何用Rust做一个基于大语言模型的聊天机器人。
大语言模型(LLM)是2023年最火的领域,没有之一。这波热潮是由OpenAI的ChatGPT在今年上半年发布后引起的,之后全世界的研究机构和科技公司都卷入了大模型的竞争中。目前业界应用大模型训练及推理的主要语言是Python和C/C++。Python一般用来实现上层框架,而C/C++一般起底层高性能执行的作用,比如著名的框架 PyTorch,它的上层业务层面是用Python写的,下层执行层面由C执行,因为GPU加速的部分只能由C/C++来调用。
看起来大语言模型好像和Rust没什么关系,必须承认,由于历史积累的原因,在AI这一块儿Rust的影响力还非常小。但从另一方面来讲呢,Rust是目前业界除Python、C++ 外,唯一有潜力在未来20年的AI 发展中发挥重要作用的语言了。为什么这么说呢?
首先Rust的性能与C/C++一致,并且在调用GPU能力方面也同样方便;其次,Rust强大的表达能力,不输于Python,这让人们使用Rust做业务并不难;然后,Rust的cargo编译成单文件的能力,以及对WebAssembly完善的支持,部署应用的时候非常方便,这比Py + C/C++组合需要安装的一堆依赖和数G的库方便太多。
目前Rust生态中其实已经有很多AI相关的基础设施了,你可以从我给出的链接里找到。
世界上最大的机器学习模型仓库平台 HuggingFace(机器学习领域的Github) 推出了 Rust 机器学习框架 Candle。在这个官方代码仓库上,HuggingFace上解释了为什么要做一个Rust的机器学习框架。
- Candle上云端的 Serverless 推理可行。PyTorch那一套体积太大,安装完得几个G,而Candle编译后的可执行文件才十几M到几十M。
- Candle可以让你避免Python的 GIL,从而提高性能。
- HuggingFace已经用Rust写了不少基础工具了,比如 safetensors 和 tokenizers。
Elon Musk的 x.ai 发布后,页面上也有一段对Rust的溢美之词:Rust被证明是一个理想的选择,用于构建可扩展的、可靠的、可维护性的基础设施。它提供了高性能、丰富的生态系统,它能阻止大部分错误,这些错误在分布式系统中经常会碰到。由于我们的团队规模很小,基础设施的可靠性就显得至关重要,否则维护工作会浪费大量创新的时间。Rust给我们提供了信心,任何的代码修改或重构都可以产出可工作的程序,并且在最小监管下可以持续运行好几个月(而不会出问题)。
Musk甚至说Rust语言是未来构建 AGI 的语言。
在使用Rust尝试做聊天机器人之前,我们先来了解一下相关的背景知识。
大语言模型背景知识
这节课我们还是主要讲Rust及Rust的应用,所以相关的背景知识我们就简单概括一下。
机器学习(Machine Learning,ML)泛指用计算机对数据集按照一定的算法进行数据处理、分析、聚类、回归等。在执行前,人往往不知道结果是什么,所以叫机器学习。机器学习可以用于自动提取信息,自动或辅助人类做决策。
而神经网络(Neural Network)是机器学习的一类算法,它是模拟人的大脑神经元和连接的一种算法。目前整个业界投入资源最多的就是在这类算法上面,在这个类别中的创新也是最多的,有种观点认为神经网络算法是通向真正的AI最可能正确的路径。
而深度学习(Deep Learning),其实就是层次很多很深的神经网络。你可以想象,层次越深,节点数(图里那些圈)越多,就越能模拟大脑。但是深度学习带来的问题就是,层次越深,节点数越多,则那些线(就是权重值)就会越多,呈指数级增长,那么计算量就会越来越大。其实之前十几、二十年AI进展不大,主要就是因为这个计算能力限制了。
通过无数人的探索,我们发现可以把深度学习的计算并行化,优化了很多工程上的算法。这样就能够在大家熟悉的主要用来玩游戏的GPU显卡上运行,大大提高了计算速度。这个视角,二十多年前Nvidia的老黄就看明白了,早早地提供好了基础设施,CUDA、CUDNN等。所以现在你就看到了,一卡难求,已经严重影响到了游戏玩家的生存。
ChatGPT与LLaMA
那么ChatGPT是什么呢?英文是Chat Generative Pre-trained Transformer,是OpenAI提供的在线AI对话服务,具有令人惊叹的理解能力和回答能力。你可以把ChatGPT理解成一个为对话调优的预训练转换模型。GPT 3 有 1750 亿参数,有传言GPT 4 有 1.76万亿参数。参数是什么,就是图中节点之间的连线,你可以想象一下1750亿根线的场景。
OpenAI搞出了ChatGPT,让全世界惊掉了下巴,但是它是闭源的。于是Meta公司(前Facebook)的Yann LeCun(图灵奖得主)团队,搞了一个开源版本的GPT,叫 LLaMA。它也是在巨量的源数据集上进行的训练,生成了 70亿(7B),130亿(13B),650亿(65B)三个版本参数的大模型文件,可以供业界做学习和研究使用。现在已推出LLaMA 2,正在做LLaMA 3。
LLaMA搞出来后,掀起了LLM界的狂欢。全世界的团队在LLaMA的基础上,继续调优,推出了各种各样的大模型。比如国内清华的 ChatGLM, 零一万物的 Yi。国外的 Mistral、OpenChat、Starling 等。这块儿非常卷,大家都在争相推出自己调优后的版本,每天早上睡醒起来,都发现又推出了几个新的LLM。
这些训练好后的文件一般从几个G到几十G不等,也有几百G的。要运行它们,得有非常强大的机器才行。比如7B 的 LLaMA 2 文件,每个权重为一个 f16 浮点数,占两个字节,所以可以估算出要运行 LLaMA 2 模型,起码得有 14G 的内存或显存。内存还好,显存超过14G的个人用户真不多。并且,LLaMA2的模型文件是PyTorch导出的,只能由PyTorch框架来运行。
llama.cpp与量子化方法
车到山前必有路,大神 Georgi Gerganov 搞了一个项目:llama.cpp。它是一个用C/C++重新实现引擎的版本,不需要安装PyTorch,就可以运行LLaMA 2模型文件。最关键的是,它提出了一种量子化(quantization)方法,可以将权重从 16 位量子化到8位、6位、5位、4位,甚至2位。这样,就相当于等比缩小了占用内存的规模。比如,一个4位量子化版本的LLaMA 2 7B模型,就只需要不到4G的内存/显存就能运行。这样,就能适配大多数的个人计算机了。
这种量子化方法是个重大创新,它直接促进了LLM生态的进一步繁荣。现在HuggingFace上有大量量子化后的模型,比如 openchat_3.5.Q4_K_M.gguf 就是一个OpenChat的4位量子化的版本。我们下载的时候,直接下载这些量子化后的模型文件就可以了。
请注意,这些文件是训练后的成品,我们下载它是用来做推理(infer)的,而不是训练(train)的。当然,我们可以在这些成品模型上运行调优(fine tune)。
大模型文件格式
目前HuggingFace上有几种常见的LLM文件格式。
- bin格式:Pytorch导出的模型文件格式
- safetensors格式:HuggingFace定义的一种新的模型文件格式,有可能成为未来的主流格式。HuggingFace用Rust实现safetensors格式的解析,并导出为Py接口,请参见链接。
- ggml格式:llama.cpp 项目量子化模型的前期模型格式。
- gguf格式:llama.cpp项目量子化模型的后期模型格式,也是现在主流的量子化LLM格式。
Rust的机器学习框架
Rust生态现在有几个比较不错的ML框架,最好的两个是:Candle 和 burn。后续,我们以Candle为例来介绍。
Candle介绍
据Candle官网介绍,它是一个极小主义机器学习框架,也就是没什么依赖,不像Pytorch那样装一堆东西,部署起来很麻烦。但其实它也能用来训练。
它有下面这些特性:
- HuggingFace出品。近水楼台先得月,Candle几乎能支持HuggingFace上所有的模型(有的需要经过转换)。
- 语法简单,跟PyTorch差不多。
- CPU、Cuda、Metal的支持。
- 让serverless和快速部署成为可能。
- 模型训练。
- 分布式计算(通过NCCL)。
- 开箱即用的模型支持,LLaMA、Whisper、 Falcon 等等。
Candle不仅仅是大模型深度学习框架,它还是一个机器学习框架,因此它也支持其他的机器学习算法和强化学习(reinforcement learning)。下面我们就来看看如何利用Candle框架做一个聊天机器人。
注:这节课的代码适用于 candle_core v0.3 版本。
使用Candle做一个聊天机器人
下载模型文件
我对一些大模型进行了测试,发现OpenChat的对话效果比较好,所以下面我们用OpenChat LLM来进行展示。我们会用 quantized 8bit 的版本。其实 4bit 的版本也是可以的,效果也非常好。
Candle 官方的示例比较复杂,我为这个课程定制了一个更简单的独立运行的示例。你可以将这个仓库克隆下来,进入目录。在运行代码之前,下载模型文件和tokenizer.json文件。
与代码目录同级的位置,创建一个目录
mkdir hf_hub
进入这个目录,请下载
https://huggingface.co/TheBloke/openchat_3.5-GGUF/blob/main/openchat_3.5.Q8_0.gguf
和
https://huggingface.co/openchat/openchat_3.5/blob/main/tokenizer.json
将这个 tokenizer.json 重命名为 openchat_3.5_tokenizer.json
目录结构:
运行演示
然后,进入 23-candle_chat/ 运行:
出现如下界面,就可以聊天了。
下面是我问的一个问题,bot的回答好像有点问题,这个模型用英文问的效果会好一些。
代码讲解
你可以看一下代码。
#![allow(unused)]
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use tokenizers::Tokenizer;
use candle_core::quantized::gguf_file;
use candle_core::utils;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as quantized_model;
use anyhow::Result;
mod token_output_stream;
use token_output_stream::TokenOutputStream;
struct Args {
tokenizer: String,
model: String,
sample_len: usize,
temperature: f64,
seed: u64,
repeat_penalty: f32,
repeat_last_n: usize,
gqa: usize,
}
impl Args {
fn tokenizer(&self) -> Result<Tokenizer> {
let tokenizer_path = PathBuf::from(&self.tokenizer);
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> Result<PathBuf> {
Ok(std::path::PathBuf::from(&self.model))
}
}
fn main() -> anyhow::Result<()> {
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
utils::with_avx(),
utils::with_neon(),
utils::with_simd128(),
utils::with_f16c()
);
let args = Args {
tokenizer: String::from("../hf_hub/openchat_3.5_tokenizer.json"),
model: String::from("../hf_hub/openchat_3.5.Q8_0.gguf"),
sample_len: 1000,
temperature: 0.8,
seed: 299792458,
repeat_penalty: 1.1,
repeat_last_n: 64,
gqa: 8,
};
// load model
let model_path = args.model()?;
let mut file = File::open(&model_path)?;
let start = std::time::Instant::now();
// This is the model instance
let model = gguf_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
}
println!(
"loaded {:?} tensors ({}bytes) in {:.2}s",
model.tensor_infos.len(),
total_size_in_bytes,
start.elapsed().as_secs_f32(),
);
let mut model = quantized_model::ModelWeights::from_gguf(model, &mut file)?;
println!("model built");
// load tokenizer
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
// left for future improvement: interactive
for prompt_index in 0.. {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
let prompt_str = format!("User: {prompt} <|end_of_turn|> Assistant: ");
print!("bot: ");
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let prompt_tokens = tokens.get_ids();
let mut all_tokens = vec![];
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens, &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = "<|end_of_turn|>";
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let to_sample = args.sample_len.saturating_sub(1);
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle_core::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
}
Ok(())
}
我们分段讲解这100多行代码。
第19行,定义了Args参数,这是模型必要的参数配置定义。第42~48行,看看有哪些CPU特性支持。第50~59行,实例化Args,这些参数都是硬编码进去的,除了两个文件的路径外,需要LLM相关知识才能理解。
第62~81行,加载大模型文件,并生成模型对象。我们这个模型是GGUF格式的,因此需要用gguf_file模块来读。Tensor是LLM中的重要概念,它是一个多维数组,可以在CPU和GPU上计算,在GPU上还可以并行计算。一个大模型由很多的Tensor组成。我们这个模型中,加载进来了 291 个 Tensor。
第84行,加载tokenizer.json文件,并生成 tokenizer 实例。tokenizer 用于将输入和输出的文本转化为Tensor,变成大模型可理解的数据。第85行创建token输出流实例。第89~99行,建立问答界面,输入提示符为一个 > 号,输出为 bot:
开头。第101~104行将输入的问答转化为 tokens。这个tokens就是Tensor实例。
第106~122行,是用于处理输入,大模型对输入的token做一处理,你可以理解成大模型对你的输入问题先要进行一下理解,然后后面才能做出对应的回答。
LogitsProcessor 是一个用于修改模型输出概率分布的工具。我们可以看到,这个过程中使用的设备写死了,用的CPU。
第124行,"<|end_of_turn|>"
是 OpenChat 模式定义的一轮对话结束的标志。第125~156行,就是对问题的回答。用的设备仍然为CPU,我们可以猜测应该会很慢。它这里面还有对penalty机制的处理。细节也需要去查阅大模型NLP相关的知识。
第158~167行,是对本次处理性能的一个汇总。在我的电脑上,纯用CPU计算的话,只能达到1秒一个多token的速度,非常卡。有GPU加持的话,会快很多。
你可能发现了,第87行有个循环,它是用来实现 interactive 交互效果的,问完一句,回答完,还可以问下一句。
添加命令行参数
前面的simple示例,我们所有的参数都是写死在代码里面的。这样方便理解,但不方便使用。我们可以尝试为它添加命令行参数功能。
在Rust中,写一个命令行非常简单,直接用clap,改几行代码就可以了。将上面示例中的Args结构体的定义变成下面这样就可以了,然后在调用的时候使用 Args::parse()
生成 Args 实例。
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long, default_value = "../hf_hub/openchat_3.5_tokenizer.json")]
tokenizer: String,
#[arg(long, default_value = "../hf_hub/openchat_3.5.Q8_0.gguf")]
model: String,
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
#[arg(long, default_value_t = 299792458)]
seed: u64,
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value_t = 8)]
gqa: usize,
}
fn main() {
// ...
let args = Args::parse();
// ...
}
经过升级的命令有了下面这些参数:
$ cargo run --release --bin cli -- --help
Finished release [optimized] target(s) in 0.04s
Running `target/release/cli --help`
avx: false, neon: false, simd128: false, f16c: false
Usage: cli [OPTIONS]
Options:
--tokenizer <TOKENIZER> [default: ../hf_hub/openchat_3.5_tokenizer.json]
--model <MODEL> [default: ../hf_hub/openchat_3.5.Q8_0.gguf]
-n, --sample-len <SAMPLE_LEN> [default: 1000]
--temperature <TEMPERATURE> [default: 0.8]
--seed <SEED> [default: 299792458]
--repeat-penalty <REPEAT_PENALTY> [default: 1.1]
--repeat-last-n <REPEAT_LAST_N> [default: 64]
--gqa <GQA> [default: 8]
-h, --help Print help
-V, --version Print version
是不是很方便?谁说Rust生产力不行的!
点击这里可以找到源文件,你可以本地试试跑起一个大模型对话机器人。另外,Candle官方仓库中的示例功能更强大,但也更复杂,你可以继续深入研究。
小结
这节课我们一起探索了使用Rust利用Candle机器学习框架开发一个大模型聊天机器人的应用。Rust目前在AI界虽然还不够有影响力,但是未来是相当有潜力的,这也是为什么HuggingFace带头出一个Rust机器学习框架的原因。
不过这节课我们只是讲了怎么用起来,而如果要深入下去的话,机器学习、深度学习的基础知识就必不可少了。如果你有时间精力的话,你可以深入下去好好补充一下这方面的学术知识,毕竟未来几十年,AI是一个主要问题,也会是一个主要机会。目前AI发展速度太快了,有种学习跟不上业界发展速度的感觉。但是不管怎样,学好基础,永远不会过时。
另外,Rust AI这块儿,虽然已经小有起色,但是作为生态来讲,空白处还很多。所以这也正是学好Rust的机会,Rust可以在AI基建这块做大量的工作,这些工作可以服务于Rust社区,也可以服务于Python乃至整个AI社区。
思考题
你可以在我的示例上继续捣鼓,添加GPU的支持,在Linux、Windows、macOS多种平台上测试一下。欢迎你在评论区贴出你自己的代码,也欢迎你把这节课分享给其他朋友,我们下节课再见!
- eriklee 👍(6) 💬(1)
"首先 Rust 的性能与 C/C++ 一致,并且在调用 GPU 能力方面也同样方便" rust目前应该还不能直接访问cuda吧?毕竟cuda是c接口
2024-01-13 - 听雨 👍(2) 💬(1)
老师,我下载到了window下的子Linux中运行,报OS error。请问是不是不支持在这里面运行啊
2023-12-23 - tan 👍(1) 💬(2)
wsl: Error: No such file or directory (os error 2) . 处理:simple.rs中 tikenizer和model的参数写全[/mnt/xx/xx/openchat_3.5_tokenizer.json]
2024-01-15 - zhuxiufenghust 👍(0) 💬(2)
avx: false, neon: false, simd128: false, f16c: false Error: unknown magic 0x6f64213c 这个错误要怎么解决
2024-03-25 - 不忘初心 👍(0) 💬(1)
老师, AI 电销, 需要什么样的技术栈?
2024-03-21 - xl000 👍(0) 💬(2)
candle-core 0.4版本, 只需要改两处 tensor.ggml_dtype.blck_size() 改为 tensor.ggml_dtype.block_size() from_gguf(model, &mut file)那行改为 let device = Device::cuda_if_available(0)?; let mut model = quantized_model::ModelWeights::from_gguf(model, &mut file, &device)?;
2024-03-20 - Geek_3b58b9 👍(0) 💬(1)
Candle 要是支持 ROCm 就更好了
2024-02-04 - tan 👍(0) 💬(1)
WSL: linker `cc` not found 处理方式: sudo apt update && sudo apt install build-essential
2024-01-15 - eriklee 👍(0) 💬(1)
老师能对比下candle和burn吗? 另外,感觉rust优势是边缘端推理,毕竟边缘侧资源紧张. 服务器端推理,毕竟还是比不过python生态
2024-01-13 - 鸠摩智 👍(0) 💬(1)
老师,请问一下,可以下载别人训练好的模型,通过candle 来实现根据需求描述自动生成测试用例的这种功能吗?
2024-01-06 - superggn 👍(0) 💬(1)
mac 2021 款 intel 芯片, 系统版本 Montery 12.5.1 会卡住, model built 之后输入一个 hello 就不动了, 要过 5 分钟以上才会有回复
2024-01-02 - Geek_54dac1 👍(0) 💬(1)
该模型在 Candle 中还暂时不支持在 GPU 上运行,因为 Quantized models on Cuda 还不支持,参考:https://github.com/huggingface/candle/issues/1250;避免大家挖坑
2023-12-15 - 无限可能 👍(0) 💬(3)
老师好,大模型小白有几个问题。通过 huggingface-cli scan-cache,扫描了一下 huggingface 下载过的文件,都是 7G 或者更大: 1. 这些模型文件里主要是什么内容,是包括了数据么?为啥会这么大? 2. 类似这么大的文件,一定要下载到本地么,是否可以云部署之类,可以用完销毁。 3. 我理解,我想要运行本节课的demo,这 7G 的空间就是要长期规划出去了吧
2023-12-15 - 学水 👍(0) 💬(1)
这张感觉主要是给对rust有兴趣的mle,中间一些代码如果不是因为用过pytorch,完全不知道在干啥
2023-12-15 - 刘丹 👍(0) 💬(1)
老师,你好,请问能否推荐几个和 ChatGPT 交互的 Rust SDK ? 最好也能支持其它 LLM 。
2023-12-15