Building a Retrieval‑Augmented Generation (RAG) System with Rust and Qdrant
This article walks through building a Retrieval‑Augmented Generation pipeline in Rust: constructing a knowledge base with Qdrant, loading an embedding model and generating embeddings with the candle library, ingesting data, and wiring up a Rust‑based inference service built on mistral.rs. It closes with a comparison of resource usage and a look at common pitfalls.
Retrieval‑Augmented Generation (RAG) enhances large language models (LLMs) by retrieving up‑to‑date information from external sources, improving answer accuracy and allowing knowledge updates without retraining.
The tutorial shows how to build a complete RAG demo in Rust, using Langchain‑like components, Qdrant as a pure‑Rust vector database, and the candle framework for model inference.
Knowledge Base Construction
The knowledge base pairs an embedding model with a vector store. Qdrant, a vector database written in Rust, is selected; the most critical step is generating embeddings for the documents.
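Before any documents can be ingested, a collection sized to the embedding dimension has to exist. A hedged sketch with the qdrant-client crate's builder API (the collection name "docs" and the 1024 dimension, matching an m3e-large-style model, are assumptions; adjust them to your model):

```rust
use qdrant_client::qdrant::{CreateCollectionBuilder, Distance, VectorParamsBuilder};
use qdrant_client::Qdrant;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to a local Qdrant instance (gRPC port 6334 by default).
    let client = Qdrant::from_url("http://localhost:6334").build()?;

    // Cosine distance over unit-normalized embeddings makes the score
    // equivalent to a dot product; 1024 is m3e-large's output dimension.
    client
        .create_collection(
            CreateCollectionBuilder::new("docs")
                .vectors_config(VectorParamsBuilder::new(1024, Distance::Cosine)),
        )
        .await?;
    Ok(())
}
```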
Model Loading
The following Rust code loads a BERT model and its tokenizer from HuggingFace, handling both PyTorch and safetensors weights:
async fn build_model_and_tokenizer(model_config: &ConfigModel) -> Result<(BertModel, Tokenizer)> {
    let device = Device::new_cuda(0)?;
    let repo = Repo::with_revision(
        model_config.model_id.clone(),
        RepoType::Model,
        model_config.revision.clone(),
    );
    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = ApiBuilder::new().build()?;
        let api = api.repo(repo);
        let config = api.get("config.json").await?;
        let tokenizer = api.get("tokenizer.json").await?;
        let weights = if model_config.use_pth {
            api.get("pytorch_model.bin").await?
        } else {
            api.get("model.safetensors").await?
        };
        (config, tokenizer, weights)
    };
    let config = std::fs::read_to_string(config_filename)?;
    let mut config: Config = serde_json::from_str(&config)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let vb = if model_config.use_pth {
        VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
    };
    if model_config.approximate_gelu {
        config.hidden_act = HiddenAct::GeluApproximate;
    }
    let model = BertModel::load(vb, &config)?;
    Ok((model, tokenizer))
}

To avoid repeated loading, the model and tokenizer are stored in static OnceCell globals.
pub static GLOBAL_EMBEDDING_MODEL: OnceCell<(BertModel, Tokenizer)> = OnceCell::const_new();

pub async fn init_model_and_tokenizer() -> Arc<(BertModel, Tokenizer)> {
    let config = get_config().unwrap();
    let (m, t) = build_model_and_tokenizer(&config.model).await.unwrap();
    Arc::new((m, t))
}

Embedding Function
The embedding routine tokenizes input text, runs it through the model, averages token vectors, and normalizes the result:
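For intuition, here is the same mean-pool-and-normalize step written over plain `Vec`s (illustrative only; the real pipeline stays on candle `Tensor`s):

```rust
/// Mean-pool a sequence of token vectors into one sentence vector,
/// then scale it to unit L2 norm.
fn mean_pool_normalize(token_vecs: &[Vec<f64>]) -> Vec<f64> {
    let n_tokens = token_vecs.len() as f64;
    let dim = token_vecs[0].len();
    // Average over the token axis.
    let mut pooled = vec![0.0; dim];
    for tv in token_vecs {
        for (p, v) in pooled.iter_mut().zip(tv) {
            *p += v / n_tokens;
        }
    }
    // L2-normalize so that a dot product equals cosine similarity.
    let norm = pooled.iter().map(|v| v * v).sum::<f64>().sqrt();
    pooled.iter().map(|v| v / norm).collect()
}

fn main() {
    let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    // pooled = [0.5, 0.5]; after normalization both components are 1/sqrt(2).
    println!("{:?}", mean_pool_normalize(&tokens));
}
```

The tensor implementation performs exactly these two steps: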
pub async fn embedding_setence(content: &str) -> Result<Vec<Vec<f64>>> {
    let m_t = GLOBAL_EMBEDDING_MODEL.get().unwrap();
    let tokens = m_t.1.encode(content, true).map_err(E::msg)?.get_ids().to_vec();
    let token_ids = Tensor::new(&tokens[..], &m_t.0.device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    let sequence_output = m_t.0.forward(&token_ids, &token_type_ids)?;
    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
    // Mean-pool over the token axis, then L2-normalize.
    let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;
    let embeddings = normalize_l2(&embeddings)?;
    let encodings = embeddings.to_vec2::<f64>()?;
    Ok(encodings)
}

Data Ingestion
Documents are read from a directory, deserialized into a Doc struct, embedded, and upserted into Qdrant in batches of 100 points:
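One assumption first: the `Doc` type is not shown in the article. Judging from the payload keys it carries, it presumably looks something like this (field names inferred, serde assumed):

```rust
use serde::Deserialize;

// Field set inferred from the payload keys ("content", "title",
// "product", "url"); the real struct may differ.
#[derive(Debug, Deserialize)]
pub struct Doc {
    pub title: String,
    pub content: String,
    pub product: String,
    pub url: String,
}
```

With that in place, the ingestion routine: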
pub async fn load_dir(&self, path: &str, collection_name: &str) {
    let mut points = vec![];
    for entry in WalkDir::new(path).into_iter().filter_map(Result::ok) {
        if let Some(p) = entry.path().to_str() {
            let id = Uuid::new_v4();
            let content = fs::read_to_string(p).unwrap_or_default();
            let doc: Doc = from_str::<Doc>(content.as_str()).unwrap();
            let mut payload = Payload::new();
            payload.insert("content", doc.content);
            payload.insert("title", doc.title);
            payload.insert("product", doc.product);
            payload.insert("url", doc.url);
            let vector_contents = embedding_setence(content.as_str()).await.unwrap();
            let ps = PointStruct::new(id.to_string(), vector_contents[0].clone(), payload);
            points.push(ps);
            // Flush a full batch of 100 points to Qdrant.
            if points.len() == 100 {
                self.client
                    .upsert_points(UpsertPointsBuilder::new(collection_name, points.clone()).wait(true))
                    .await
                    .unwrap();
                points.clear();
            }
        }
    }
    // Upsert the remaining partial batch.
    if !points.is_empty() {
        self.client
            .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true))
            .await
            .unwrap();
    }
}

Inference Service
The inference server is built with mistral.rs, exposing an OpenAI‑compatible API. The model (Qwen2‑7B) is downloaded via a mirror and run with CUDA support:
git clone https://github.com/EricLBuehler/mistral.rs
cd mistral.rs
cargo run --bin mistralrs-server --features cuda -- --port 3333 plain -m /root/Qwen2-7B -a qwen2

A global OpenAI client is configured with a 30‑second timeout:
pub static GLOBAL_OPENAI_CLIENT: Lazy<Arc<OpenAIClient>> = Lazy::new(|| {
    let mut client = OpenAIClient::new_with_endpoint(
        "http://10.0.0.7:3333/v1".to_string(),
        "EMPTY".to_string(),
    );
    client.timeout = Some(30);
    Arc::new(client)
});

The final answer function retrieves relevant chunks from Qdrant, builds a Chinese prompt, and calls the OpenAI‑compatible endpoint:
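The `retriever` helper it relies on is not listed in the article. A plausible sketch against qdrant-client's search API (the collection name "docs" and the shared-client accessor are assumptions; the query must be embedded with the same model used at ingestion time):

```rust
use qdrant_client::qdrant::{SearchPointsBuilder, SearchResponse};

// Hypothetical helper, mirroring the global-model pattern used elsewhere.
pub async fn retriever(question: &str, limit: u64) -> Result<SearchResponse> {
    // Embed the question with the ingestion-time embedding model.
    let vectors = embedding_setence(question).await?;
    // Qdrant's gRPC API expects f32 vectors.
    let query: Vec<f32> = vectors[0].iter().map(|v| *v as f32).collect();
    let search = SearchPointsBuilder::new("docs", query, limit).with_payload(true);
    let client = get_qdrant_client(); // assumed accessor for a shared Qdrant client
    Ok(client.search_points(search).await?)
}
```

With that helper in place, the answer function: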
pub async fn answer(question: &str, max_len: i64) -> Result<String> {
    let retriver = retriever(question, 1).await?;
    let mut context = String::new();
    for sp in retriver.result {
        let payload = sp.payload;
        context.push_str(payload.get("product").and_then(|v| v.as_str()).unwrap_or_default());
        context.push_str(payload.get("title").and_then(|v| v.as_str()).unwrap_or_default());
        context.push_str(payload.get("content").and_then(|v| v.as_str()).unwrap_or_default());
    }
    // Prompt (in Chinese): "You are a cloud technology expert. Use the retrieved
    // Context below to answer the question. Answer in Chinese."
    let prompt = format!(
        "你是一个云技术专家, 使用以下检索到的Context回答问题。用中文回答问题。\nQuestion: {}\nContext: {}",
        question, context
    );
    let req = ChatCompletionRequest::new(
        "".to_string(),
        vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: chat_completion::Content::Text(prompt),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    )
    .max_tokens(max_len);
    let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?;
    Ok(cr.choices[0].message.content.clone())
}

Resource Comparison & Pitfalls
The author compares GPU memory consumption of the embedding model (m3e‑large) and the inference models (Qwen1.5‑1.8B‑Chat, Qwen2‑7B) across the vLLM and mistral.rs runtimes, noting that Qwen2‑7B exceeds the memory budget under vLLM but fits within it under mistral.rs.
A common pitfall is that the default hf‑hub client cannot be pointed at mirror sites (such as those used in mainland China); the solution is to set a custom endpoint via an ApiBuilder::with_endpoint method:
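Once such a setter exists (recent hf-hub releases ship a similar method; the exact signature here is an assumption), redirecting the loader to a mirror is a one-line change in build_model_and_tokenizer:

```rust
// Hypothetical usage: replace the default builder with a mirror endpoint.
let api = ApiBuilder::new()
    .with_endpoint("https://hf-mirror.com")
    .build()?;
```

The patch itself: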
impl ApiBuilder {
    /// Set the endpoint, e.g. 'https://hf-mirror.com'
    pub fn with_endpoint(mut self, endpoint: &str) -> Self {
        self.endpoint = endpoint.to_string();
        self
    }
}

The article concludes with a link to the full project repository and an invitation to join a technical discussion group.
JD Tech Talk
Official JD Tech public account delivering best practices and technology innovation.