Merge pull request #24 from rainj-me/master
add reader type flag to switch between opendal and awssdk
marsupialtail authored Apr 24, 2024
2 parents bc7498a + ba11b6a commit 55bb33e
Showing 12 changed files with 221 additions and 183 deletions.
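
In short: this commit removes the opendal and aws_sdk Cargo features, compiles both I/O backends unconditionally, and introduces a ReaderType enum so the backend is selected at runtime rather than at build time. The new flag is threaded through the Parquet entry points, and paths that do not start with s3:// always fall back to the OpenDAL reader.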
9 changes: 3 additions & 6 deletions Cargo.toml
@@ -10,11 +10,8 @@ crate-type = ["cdylib"]
 
 [features]
 default = [] #['py']
-# default = ["opendal"]
 py = ["dep:pyo3", "pyarrow", "dep:pyo3-log"]
 pyarrow = ["arrow/pyarrow"]
-opendal = ["dep:opendal"]
-aws_sdk = ["dep:aws-sdk-s3", "dep:aws-config"]
 
 
 [dependencies]
@@ -27,7 +24,7 @@ pyo3-log = { version = "0.9.0", optional = true }
 arrow = { version = "50.0.0", default-features = false }
 tokenizers = { version = "0.15.2", features = ["http"] }
 whatlang = "0.16.4"
-opendal = { version = "0", optional = true }
+opendal = { version = "0" }
 zstd = "0.13.0" # Check for the latest version of zstd crate
 serde = { version = "1.0", features = ["derive"] }
 bincode = "1.3" # For serialization and deserialization
@@ -69,8 +66,8 @@ rand = "0.8.5"
 serde_json = "1.0"
 uuid = { version = "1.0", features = ["v4", "serde"] }
 async-recursion = "1.0.5"
-aws-config = { version = "1.1.7", features = ["behavior-version-latest"] , optional = true}
-aws-sdk-s3 = { version = "1.23.0", optional = true }
+aws-config = { version = "1.1.7", features = ["behavior-version-latest"]}
+aws-sdk-s3 = { version = "1.23.0" }
 bitvector = "0.1.5"
 ndarray = { version = "0.15.6", features = ["rayon", "serde"] }
 numpy = "0.20.0"
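With both backends always compiled, downstream code no longer needs cfg gates; the build-time switch becomes a runtime value. A minimal sketch of the new call pattern (hypothetical caller, using the crate-internal paths shown in the diffs below):

```rust
use crate::formats::readers::{get_file_size_and_reader, ReaderType};
use crate::lava::error::LavaError;

// Hypothetical caller: the backend is now chosen per call, not per build.
async fn open_one(path: &str, backend: &str) -> Result<usize, LavaError> {
    // "aws" maps to ReaderType::AwsSdk; any other string falls back to Opendal.
    let reader_type = ReaderType::from(backend.to_string());
    let (size, _reader) = get_file_size_and_reader(path.to_string(), reader_type).await?;
    Ok(size)
}
```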
25 changes: 17 additions & 8 deletions src/formats/parquet.rs
@@ -36,9 +36,11 @@ use crate::{
     lava::error::LavaError,
 };
 
-async fn get_reader_and_size_from_file(file: &str) -> Result<(usize, AsyncReader), LavaError> {
-    get_file_size_and_reader(file.to_string()).await
-}
+use super::readers::ReaderType;
+
+// async fn get_reader_and_size_from_file(file: &str) -> Result<(usize, AsyncReader), LavaError> {
+//     get_file_size_and_reader(file.to_string()).await
+// }
 
 async fn parse_metadata(
     reader: &mut AsyncReader,
@@ -184,16 +186,18 @@ fn read_page_header<C: ChunkReader>(
     Ok((tracked.1, header))
 }
 
-async fn parse_metadatas(file_paths: &Vec<String>) -> HashMap<String, ParquetMetaData> {
+async fn parse_metadatas(file_paths: &Vec<String>, reader_type: ReaderType) -> HashMap<String, ParquetMetaData> {
     let iter = file_paths.iter().dedup();
 
     let handles = stream::iter(iter)
         .map(|file_path: &String| {
             let file_path = file_path.clone();
+            let reader_type = reader_type.clone();
 
             tokio::spawn(async move {
                 let (file_size, mut reader) =
-                    get_reader_and_size_from_file(&file_path).await.unwrap();
+                    get_file_size_and_reader(file_path.clone(), reader_type).await.unwrap();
+
                 let metadata = parse_metadata(&mut reader, file_size as usize)
                     .await
                     .unwrap();
@@ -231,8 +235,9 @@ pub struct ParquetLayout {
 pub async fn get_parquet_layout(
     column_name: &str,
     file_path: &str,
+    reader_type: ReaderType,
 ) -> Result<(arrow::array::ArrayData, ParquetLayout), LavaError> {
-    let (file_size, mut reader) = get_reader_and_size_from_file(file_path).await?;
+    let (file_size, mut reader) = get_file_size_and_reader(file_path.to_string(), reader_type).await?;
     let metadata = parse_metadata(&mut reader, file_size as usize).await?;
 
     let codec_options = CodecOptionsBuilder::default()
@@ -410,6 +415,7 @@ pub async fn read_indexed_pages_async(
     page_offsets: Vec<u64>,
     page_sizes: Vec<usize>,
     dict_page_sizes: Vec<usize>, // 0 means no dict page
+    reader_type: ReaderType,
 ) -> Result<Vec<ArrayData>, LavaError> {
     // current implementation might re-read dictionary pages, this should be optimized
     // we are assuming that all the files are either on disk or cloud.
@@ -418,7 +424,7 @@
         .set_backward_compatible_lz4(false)
         .build();
 
-    let metadatas = parse_metadatas(&file_paths).await;
+    let metadatas = parse_metadatas(&file_paths, reader_type.clone()).await;
 
     let iter = izip!(
         file_paths,
@@ -456,11 +462,12 @@
             let mut codec = create_codec(compression_scheme, &codec_options)
                 .unwrap()
                 .unwrap();
+            let reader_type = reader_type.clone();
 
             let handle = tokio::spawn(async move {
                 debug!("tokio spawn thread: {:?}", std::thread::current().id());
                 let (_file_size, mut reader) =
-                    get_reader_and_size_from_file(&file_path).await.unwrap();
+                    get_file_size_and_reader(file_path.clone(), reader_type).await.unwrap();
                 let mut pages: Vec<parquet::column::page::Page> = Vec::new();
                 if dict_page_size > 0 {
                     let start = dict_page_offset.unwrap() as u64;
@@ -560,6 +567,7 @@ pub async fn read_indexed_pages(
     page_offsets: Vec<u64>,
     page_sizes: Vec<usize>,
     dict_page_sizes: Vec<usize>, // 0 means no dict page
+    reader_type: ReaderType,
 ) -> Result<Vec<ArrayData>, LavaError> {
     read_indexed_pages_async(
         column_name,
@@ -568,6 +576,7 @@
         page_offsets,
         page_sizes,
         dict_page_sizes,
+        reader_type,
     )
     .await
 }
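Every public entry point in this module now takes a trailing reader_type argument. A sketch of an updated call site (the column name and path are placeholders):

```rust
use arrow::array::ArrayData;
use crate::formats::parquet::{get_parquet_layout, ParquetLayout};
use crate::formats::readers::ReaderType;
use crate::lava::error::LavaError;

// Hypothetical call site. For a non-s3:// path, the readers module
// reroutes to the default Opendal backend regardless of this flag.
async fn layout_via_aws_sdk(path: &str) -> Result<(ArrayData, ParquetLayout), LavaError> {
    get_parquet_layout("body", path, ReaderType::AwsSdk).await
}
```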
36 changes: 1 addition & 35 deletions src/formats/readers/aws_reader.rs
@@ -132,7 +132,7 @@ impl Operator {
     }
 }
 
-pub async fn get_file_size_and_reader(
+pub(crate) async fn get_reader(
     file: String,
 ) -> Result<(usize, AsyncAwsReader), LavaError> {
     // Extract filename
@@ -160,37 +160,3 @@ pub async fn get_file_size_and_reader
 
     Ok((file_size as usize, reader))
 }
-
-pub async fn get_file_sizes_and_readers(
-    files: &[String],
-) -> Result<(Vec<usize>, Vec<AsyncAwsReader>), LavaError> {
-    let tasks: Vec<_> = files
-        .iter()
-        .map(|file| {
-            let file = file.clone(); // Clone file name to move into the async block
-            tokio::spawn(async move {
-                get_file_size_and_reader(file).await
-            })
-        })
-        .collect();
-
-    // Wait for all tasks to complete
-    let results = futures::future::join_all(tasks).await;
-
-    // Process results, separating out file sizes and readers
-    let mut file_sizes = Vec::new();
-    let mut readers = Vec::new();
-
-    for result in results {
-        match result {
-            Ok(Ok((size, reader))) => {
-                file_sizes.push(size);
-                readers.push(reader);
-            }
-            Ok(Err(e)) => return Err(e), // Handle error from inner task
-            Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error
-        }
-    }
-
-    Ok((file_sizes, readers))
-}
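With the duplicated batch helper deleted, this backend exposes only the crate-private get_reader; the concurrent fan-out over multiple files now lives once in readers/mod.rs (next file). The opendal_reader.rs change further down mirrors this one.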
120 changes: 68 additions & 52 deletions src/formats/readers/mod.rs
@@ -8,10 +8,7 @@ use zstd::stream::read::Decoder;
 
 use crate::lava::error::LavaError;
 
-#[cfg(feature = "opendal")]
 mod opendal_reader;
-
-#[cfg(feature = "aws_sdk")]
 mod aws_reader;
 
 #[async_trait]
@@ -77,65 +74,84 @@ impl AsyncReader {
     }
 }
 
-pub async fn get_file_sizes_and_readers(
-    files: &[String],
-) -> Result<(Vec<usize>, Vec<AsyncReader>), LavaError> {
-    #[cfg(feature = "opendal")]
-    {
-        let (file_sizes, readers) = opendal_reader::get_file_sizes_and_readers(files).await?;
-        let async_readers = readers
-            .into_iter()
-            .map(|reader| {
-                let filename = reader.filename.clone();
-                AsyncReader::new(Box::new(reader), filename)
-            })
-            .collect();
-
-        Ok((file_sizes, async_readers))
-    }
-
-    #[cfg(feature = "aws_sdk")]
-    {
-        let (file_sizes, readers) = aws_reader::get_file_sizes_and_readers(files).await?;
-        let async_readers = readers
-            .into_iter()
-            .map(|reader| {
-                let filename = reader.filename.clone();
-                AsyncReader::new(Box::new(reader), filename)
-            })
-            .collect();
-        Ok((file_sizes, async_readers))
-    }
-
-    #[cfg(not(any(feature = "opendal", feature = "aws_sdk")))]
-    {
-        let _ = files;
-        Err(LavaError::Unsupported("Must set either opendal or aws_sdk feature.".to_string()))
-    }
-}
+#[derive(Debug, Clone, Default)]
+pub enum ReaderType {
+    #[default]
+    Opendal,
+    AwsSdk,
+}
+
+impl From<String> for ReaderType {
+    fn from(value: String) -> Self {
+        match value.to_lowercase().as_str() {
+            "opendal" => ReaderType::Opendal,
+            "aws" => ReaderType::AwsSdk,
+            _ => Default::default(),
+        }
+    }
+}
+
+pub async fn get_file_sizes_and_readers(
+    files: &[String], reader_type: ReaderType
+) -> Result<(Vec<usize>, Vec<AsyncReader>), LavaError> {
+    let tasks: Vec<_> = files
+        .iter()
+        .map(|file| {
+            let file = file.clone();
+            let reader_type = reader_type.clone();
+            tokio::spawn(async move {
+                get_file_size_and_reader(file, reader_type).await
+            })
+        })
+        .collect();
+
+    // Wait for all tasks to complete
+    let results = futures::future::join_all(tasks).await;
+
+    // Process results, separating out file sizes and readers
+    let mut file_sizes = Vec::new();
+    let mut readers = Vec::new();
+
+    for result in results {
+        match result {
+            Ok(Ok((size, reader))) => {
+                file_sizes.push(size);
+                readers.push(reader);
+            }
+            Ok(Err(e)) => return Err(e), // Handle error from inner task
+            Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error
+        }
+    }
+
+    Ok((file_sizes, readers))
+}
 
 pub async fn get_file_size_and_reader(
-    file: String,
+    file: String, reader_type: ReaderType
 ) -> Result<(usize, AsyncReader), LavaError> {
-    #[cfg(feature = "opendal")]
-    {
-        let (file_size, reader) = opendal_reader::get_file_size_and_reader(file).await?;
-        let filename = reader.filename.clone();
-        let async_reader = AsyncReader::new(Box::new(reader), filename);
-        Ok((file_size, async_reader))
-    }
-    #[cfg(feature = "aws_sdk")]
-    {
-        let (file_size, reader) = aws_reader::get_file_size_and_reader(file).await?;
-        let filename = reader.filename.clone();
-        let async_reader = AsyncReader::new(Box::new(reader), filename);
-        Ok((file_size, async_reader))
-    }
-
-    #[cfg(not(any(feature = "opendal", feature = "aws_sdk")))]
-    {
-        let _ = file;
-        Err(LavaError::Unsupported("Must set either opendal or aws_sdk feature.".to_string()))
-    }
+    // always choose opendal for none s3 file
+    let reader_type = if file.starts_with("s3://") {
+        reader_type
+    } else {
+        Default::default()
+    };
+
+    let (file_size, reader) = match reader_type {
+        ReaderType::Opendal => {
+            let (file_size, reader) = opendal_reader::get_reader(file).await?;
+            let filename = reader.filename.clone();
+            let reader = AsyncReader::new(Box::new(reader), filename);
+            (file_size, reader)
+        }
+        ReaderType::AwsSdk => {
+            let (file_size, reader) = aws_reader::get_reader(file).await?;
+            let filename = reader.filename.clone();
+            let async_reader = AsyncReader::new(Box::new(reader), filename);
+            (file_size, async_reader)
+        }
+    };
+
+    Ok((file_size, reader))
+}
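The string conversion and the s3:// fallback are the two behaviors callers need to know. A small sketch exercising both (hypothetical paths):

```rust
use crate::formats::readers::{get_file_sizes_and_readers, ReaderType};
use crate::lava::error::LavaError;

// Hypothetical demo of the conversion and fallback rules above.
async fn demo() -> Result<(), LavaError> {
    // Case-insensitive parse: "AWS" -> AwsSdk; unknown strings -> Opendal.
    let backend: ReaderType = String::from("AWS").into();

    // The s3:// object honors the requested backend; the local file is
    // reset to Default::default() (Opendal) inside get_file_size_and_reader.
    let files = vec![
        "s3://bucket/a.parquet".to_string(), // placeholder object
        "/tmp/b.parquet".to_string(),        // placeholder local file
    ];
    let (sizes, _readers) = get_file_sizes_and_readers(&files, backend).await?;
    assert_eq!(sizes.len(), 2);
    Ok(())
}
```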
36 changes: 1 addition & 35 deletions src/formats/readers/opendal_reader.rs
@@ -154,7 +154,7 @@ impl From<FsBuilder> for Operators {
     }
 }
 
-pub async fn get_file_size_and_reader(
+pub(crate) async fn get_reader(
     file: String,
 ) -> Result<(usize, AsyncOpendalReader), LavaError> {
     // Determine the operator based on the file scheme
@@ -187,37 +187,3 @@
 
     Ok((file_size as usize, reader))
 }
-
-pub async fn get_file_sizes_and_readers(
-    files: &[String],
-) -> Result<(Vec<usize>, Vec<AsyncOpendalReader>), LavaError> {
-    let tasks: Vec<_> = files
-        .iter()
-        .map(|file| {
-            let file = file.clone(); // Clone file name to move into the async block
-            tokio::spawn(async move {
-                get_file_size_and_reader(file).await
-            })
-        })
-        .collect();
-
-    // Wait for all tasks to complete
-    let results = futures::future::join_all(tasks).await;
-
-    // Process results, separating out file sizes and readers
-    let mut file_sizes = Vec::new();
-    let mut readers = Vec::new();
-
-    for result in results {
-        match result {
-            Ok(Ok((size, reader))) => {
-                file_sizes.push(size);
-                readers.push(reader);
-            }
-            Ok(Err(e)) => return Err(e), // Handle error from inner task
-            Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error
-        }
-    }
-
-    Ok((file_sizes, readers))
-}
4 changes: 0 additions & 4 deletions src/lava/error.rs
@@ -6,9 +6,7 @@ pub enum LavaError {
     Bincode(#[from] bincode::Error),
     Compression(String),
     Arrow(#[from] arrow::error::ArrowError),
-    #[cfg(feature = "opendal")]
     OpenDAL(#[from] opendal::Error),
-    #[cfg(feature = "aws_sdk")]
     AwsSdk(String),
     Parse(String),
     Parquet(#[from] parquet::errors::ParquetError),
@@ -27,9 +25,7 @@ impl Display for LavaError {
             LavaError::Bincode(err) => write!(f, "Bincode error: {}", err),
             LavaError::Compression(err) => write!(f, "Compression error: {}", err),
             LavaError::Arrow(err) => write!(f, "Arrow error: {}", err),
-            #[cfg(feature = "opendal")]
             LavaError::OpenDAL(err) => write!(f, "OpenDAL error: {}", err),
-            #[cfg(feature = "aws_sdk")]
             LavaError::AwsSdk(err) => write!(f, "AWS SDK error: {}", err),
             LavaError::Parse(err) => write!(f, "Parse error: {}", err),
             LavaError::Unknown => write!(f, "Unkown error"),
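A side effect of making both backends unconditional: the matching LavaError variants and their Display arms no longer need cfg gates, which is the whole of this file's change.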
(Diff truncated: the remaining six changed files are not shown here.)