Skip to content

Commit

Permalink
Merge pull request #23 from rainj-me/master
Browse files Browse the repository at this point in the history
Support directly using the AWS SDK
  • Loading branch information
marsupialtail authored Apr 24, 2024
2 parents a1701eb + dcdaea3 commit bc7498a
Show file tree
Hide file tree
Showing 17 changed files with 476 additions and 321 deletions.
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ crate-type = ["cdylib"]

[features]
default = [] #['py']
# default = ["opendal"]
py = ["dep:pyo3", "pyarrow", "dep:pyo3-log"]
pyarrow = ["arrow/pyarrow"]
opendal = ["dep:opendal"]
aws_sdk = ["dep:aws-sdk-s3", "dep:aws-config"]


[dependencies]
pyo3 = { version = "0.20.0", features = [
Expand All @@ -23,7 +27,7 @@ pyo3-log = { version = "0.9.0", optional = true }
arrow = { version = "50.0.0", default-features = false }
tokenizers = { version = "0.15.2", features = ["http"] }
whatlang = "0.16.4"
opendal = "0"
opendal = { version = "0", optional = true }
zstd = "0.13.0" # Check for the latest version of zstd crate
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3" # For serialization and deserialization
Expand Down Expand Up @@ -65,8 +69,8 @@ rand = "0.8.5"
serde_json = "1.0"
uuid = { version = "1.0", features = ["v4", "serde"] }
async-recursion = "1.0.5"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.23.0"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] , optional = true}
aws-sdk-s3 = { version = "1.23.0", optional = true }
bitvector = "0.1.5"
ndarray = { version = "0.15.6", features = ["rayon", "serde"] }
numpy = "0.20.0"
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,9 @@ Rottnest not only supports BM25 indices but also other indices, like regex and v

### Build Python wheel
```bash
maturin develop --features py
maturin develop --features "py,opendal"
```
or
```bash
maturin develop --features "py,aws_sdk"
```
3 changes: 2 additions & 1 deletion src/formats/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod readers;

pub mod parquet;
pub mod io;

pub use parquet::get_parquet_layout;
pub use parquet::read_indexed_pages;
Expand Down
35 changes: 3 additions & 32 deletions src/formats/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,23 @@ use parquet::{
};
use thrift::protocol::TCompactInputProtocol;

use opendal::raw::oio::ReadExt;

use bytes::Bytes;
use std::convert::TryFrom;
use std::io::{Read, SeekFrom};
use std::io::Read;

use futures::stream::{self, StreamExt};
use itertools::{izip, Itertools};
use regex::Regex;
use std::collections::HashMap;

use std::{env, usize};
use tokio::{self};

use crate::{
formats::io::{AsyncReader, FsBuilder, Operators, S3Builder},
formats::readers::{AsyncReader, get_file_size_and_reader},
lava::error::LavaError,
};

use super::io::READER_BUFFER_SIZE;

async fn get_reader_and_size_from_file(file: &str) -> Result<(usize, AsyncReader), LavaError> {
let mut file_name = file.to_string();
let operator = if file.starts_with("s3://") {
file_name = file_name.replace("s3://", "");
let mut iter = file_name.split("/");
let bucket = iter.next().expect("malformed s3 path");
file_name = file_name[bucket.len() + 1..].to_string();

Operators::from(S3Builder::from(file)).into_inner()
} else {
let current_path = env::current_dir().unwrap();
Operators::from(FsBuilder::from(current_path.to_str().expect("no path"))).into_inner()
};

let file_size: usize = operator.stat(&file_name).await?.content_length() as usize;
let reader: AsyncReader = AsyncReader::new(
operator
.clone()
.reader_with(&file_name)
.buffer(READER_BUFFER_SIZE)
.await?,
file_name.clone(),
);

Ok((file_size, reader))
get_file_size_and_reader(file.to_string()).await
}

async fn parse_metadata(
Expand Down
196 changes: 196 additions & 0 deletions src/formats/readers/aws_reader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use bytes::{Bytes, BytesMut};
use std::ops::{Deref, DerefMut};

use async_trait::async_trait;
use aws_sdk_s3::Client;

use crate::lava::error::LavaError;


/// Async reader for a single S3 object, backed by an AWS SDK `Client`.
pub struct AsyncAwsReader {
    // Underlying S3 client; `Deref`/`DerefMut` expose its methods directly.
    reader: Client,
    // Target bucket name (no `s3://` prefix).
    pub bucket: String,
    // Object key within the bucket.
    pub filename: String,
    // Object size in bytes; 0 until populated after a successful `stat()`
    // (see `get_file_size_and_reader`).
    pub file_size: u64,
}

// Deref to the inner `Client` so methods on `AsyncAwsReader` can call SDK
// operations (e.g. `self.head_object()`, `self.get_object()`) directly.
impl Deref for AsyncAwsReader {
    type Target = Client;

    fn deref(&self) -> &Self::Target {
        &self.reader
    }
}

// Mutable counterpart of the `Deref` impl above.
impl DerefMut for AsyncAwsReader {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.reader
    }
}

impl AsyncAwsReader {
pub fn new(reader: Client, bucket: String, filename: String) -> Self {
Self {
reader,
bucket,
filename,
file_size: 0,
}
}

async fn stat(&self) -> Result<u64, LavaError> {
let (bucket, filename) = (&self.bucket, &self.filename);
self.head_object()
.bucket(bucket)
.key(filename)
.send()
.await
.map_err(|e| LavaError::AwsSdk(e.to_string()))
.map(|res| match res.content_length() {
Some(size) if size > 0 => size as u64,
_ => 0,
})
}
}

#[async_trait]
impl super::Reader for AsyncAwsReader {
    /// Fetches the half-open byte range `[from, to)` of the object.
    ///
    /// # Errors
    /// - `Io(InvalidData)` if the range is empty or inverted.
    /// - `Io(Interrupted)` if S3 returned fewer bytes than requested.
    /// - `AwsSdk` for any SDK-level failure.
    async fn read_range(&mut self, from: u64, to: u64) -> Result<Bytes, LavaError> {
        if from >= to {
            return Err(LavaError::Io(std::io::ErrorKind::InvalidData.into()));
        }

        let total = to - from;
        let mut res = BytesMut::with_capacity(total as usize);

        // HTTP Range headers are inclusive on both ends, so the last byte
        // requested must be `to - 1`; using `to` fetched one extra byte.
        let mut object = self
            .get_object()
            .bucket(&self.bucket)
            .key(&self.filename)
            .set_range(Some(format!("bytes={}-{}", from, to - 1)))
            .send()
            .await
            .map_err(|e| LavaError::AwsSdk(e.to_string()))?;

        // Accumulate the streamed body chunks into one contiguous buffer.
        while let Some(chunk) = object
            .body
            .try_next()
            .await
            .map_err(|e| LavaError::AwsSdk(e.to_string()))?
        {
            res.extend_from_slice(&chunk);
        }

        // A short read (e.g. range past EOF) is treated as an error.
        if res.len() < total as usize {
            return Err(LavaError::Io(std::io::ErrorKind::Interrupted.into()));
        }

        Ok(res.freeze())
    }

    /// Reads `n` little-endian u64 values ending at `file_size + offset`
    /// (`offset` is expected to be negative, measured from end of file).
    async fn read_usize_from_end(&mut self, offset: i64, n: u64) -> Result<Vec<u64>, LavaError> {
        let from = self.file_size as i64 + offset;
        let to = from + (n as i64) * 8;
        let bytes = self.read_range(from as u64, to as u64).await?;
        Ok(decode_le_u64s(&bytes))
    }

    /// Reads `n` little-endian u64 values starting at byte `offset`.
    async fn read_usize_from_start(&mut self, offset: u64, n: u64) -> Result<Vec<u64>, LavaError> {
        let from = offset;
        let to = from + n * 8;
        let bytes = self.read_range(from, to).await?;
        Ok(decode_le_u64s(&bytes))
    }
}

/// Decodes a buffer into little-endian u64 values. Any trailing partial
/// chunk (< 8 bytes) is ignored, matching `chunks_exact` semantics.
fn decode_le_u64s(bytes: &[u8]) -> Vec<u64> {
    bytes
        .chunks_exact(8)
        .map(|c| u64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
        .collect()
}

/// Clonable wrapper around the AWS SDK shared configuration.
#[derive(Clone)]
pub struct Config(aws_config::SdkConfig);

impl Config {
    /// Loads SDK configuration (credentials, region, endpoint, ...) from
    /// the environment, per the standard AWS credential-provider chain.
    pub async fn from_env() -> Self {
        let config = aws_config::load_from_env().await;
        Config(config)
    }
}

/// Thin clonable wrapper around an S3 client, constructed from a `Config`.
#[derive(Clone)]
pub struct Operator(aws_sdk_s3::Client);

impl From<Config> for Operator {
    /// Builds an S3 client from the loaded SDK configuration.
    fn from(config: Config) -> Self {
        let client = aws_sdk_s3::Client::new(&config.0);
        Operator(client)
    }
}

impl Operator {
    /// Consumes the wrapper and yields the underlying S3 client.
    fn into_inner(self) -> aws_sdk_s3::Client {
        let Operator(client) = self;
        client
    }
}

/// Opens `file` (an `s3://bucket/key` URI) and returns its size in bytes
/// together with an `AsyncAwsReader` positioned on it.
///
/// # Errors
/// - `Parse` if the URI is not `s3://`, the bucket or key is empty, or
///   the object has zero length.
/// - `AwsSdk` if the HeadObject call fails.
pub async fn get_file_size_and_reader(
    file: String,
) -> Result<(usize, AsyncAwsReader), LavaError> {
    // Only S3 URIs are handled by this reader.
    if !file.starts_with("s3://") {
        return Err(LavaError::Parse("File scheme not supported".to_string()));
    }

    let config = Config::from_env().await;
    let operator = Operator::from(config);

    // Split "s3://bucket/key/with/slashes" into bucket and object key.
    let tokens = file[5..].split('/').collect::<Vec<_>>();
    let bucket = tokens[0].to_string();
    let filename = tokens[1..].join("/");
    // Fail fast on malformed paths instead of surfacing a cryptic AWS error.
    if bucket.is_empty() || filename.is_empty() {
        return Err(LavaError::Parse(format!("malformed s3 path: {}", file)));
    }

    // `bucket`/`filename` are moved into the reader; no clones are needed
    // since neither is used again afterwards.
    let mut reader = AsyncAwsReader::new(operator.into_inner(), bucket, filename);

    // HeadObject provides the size; a zero-length object is an error.
    let file_size = reader.stat().await?;
    if file_size == 0 {
        return Err(LavaError::Parse("File size is zero".to_string()));
    }
    reader.file_size = file_size;

    Ok((file_size as usize, reader))
}

pub async fn get_file_sizes_and_readers(
files: &[String],
) -> Result<(Vec<usize>, Vec<AsyncAwsReader>), LavaError> {
let tasks: Vec<_> = files
.iter()
.map(|file| {
let file = file.clone(); // Clone file name to move into the async block
tokio::spawn(async move {
get_file_size_and_reader(file).await
})
})
.collect();

// Wait for all tasks to complete
let results = futures::future::join_all(tasks).await;

// Process results, separating out file sizes and readers
let mut file_sizes = Vec::new();
let mut readers = Vec::new();

for result in results {
match result {
Ok(Ok((size, reader))) => {
file_sizes.push(size);
readers.push(reader);
}
Ok(Err(e)) => return Err(e), // Handle error from inner task
Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error
}
}

Ok((file_sizes, readers))
}
Loading

0 comments on commit bc7498a

Please sign in to comment.