use std::borrow::{Borrow, Cow};
use arrow_format::ipc::planus::Builder;
use crate::array::*;
use crate::chunk::Chunk;
use crate::datatypes::*;
use crate::error::{Error, Result};
use crate::io::ipc::endianess::is_native_little_endian;
use crate::io::ipc::read::Dictionaries;
use super::super::IpcField;
use super::{write, write_dictionary};
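
/// Compression codec applied to the body buffers of an IPC message.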
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (frame format)
    LZ4,
    /// ZSTD
    ZSTD,
}
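
/// Options controlling how IPC record batches and dictionaries are written.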
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and, if so, with which codec.
    pub compression: Option<Compression>,
}
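
/// Recursively encodes every dictionary contained in `array` (including dictionaries nested in
/// structs, lists, unions and maps), appending one message per new dictionary to
/// `encoded_dictionaries`. Dictionaries already registered in `dictionary_tracker` with the same
/// values are skipped.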
fn encode_dictionary(
    field: &IpcField,
    array: &dyn Array,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> Result<()> {
    use PhysicalType::*;
    match array.data_type().to_physical_type() {
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id.ok_or_else(|| {
                Error::InvalidArgumentError("Dictionaries must have an associated id".to_string())
            })?;

            let emit = dictionary_tracker.insert(dict_id, array)?;

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // Recurse first: the dictionary values may themselves contain dictionaries.
            encode_dictionary(
                field,
                values.as_ref(),
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )?;

            if emit {
                encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
                    dict_id,
                    array,
                    options,
                    is_native_little_endian(),
                ));
            }
            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                return Err(Error::InvalidArgumentError(
                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
                ));
            }
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    encode_dictionary(
                        field,
                        values.as_ref(),
                        options,
                        dictionary_tracker,
                        encoded_dictionaries,
                    )
                })
        }
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            encode_dictionary(
                field,
                values.as_ref(),
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            encode_dictionary(
                field,
                values.as_ref(),
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = &field.fields[0];
            encode_dictionary(
                field,
                values.as_ref(),
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = &field.fields[..];
            if values.len() != fields.len() {
                return Err(Error::InvalidArgumentError(
                    "The number of fields in a union must equal the number of children in IpcField"
                        .to_string(),
                ));
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    encode_dictionary(
                        field,
                        values.as_ref(),
                        options,
                        dictionary_tracker,
                        encoded_dictionaries,
                    )
                })
        }
        Map => {
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = &field.fields[0];
            encode_dictionary(
                field,
                values.as_ref(),
                options,
                dictionary_tracker,
                encoded_dictionaries,
            )
        }
    }
}
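
/// Encodes a [`Chunk`] into IPC bytes, returning the dictionary batches that must be written
/// before the record batch together with the encoded record batch itself.
///
/// A minimal usage sketch, assuming a single non-nested, non-dictionary column (error handling
/// elided; not compiled as a doc-test):
///
/// ```ignore
/// let array = Int32Array::from_slice([1, 2, 3]);
/// let chunk = Chunk::new(vec![Box::new(array) as Box<dyn Array>]);
/// let fields = vec![IpcField { fields: vec![], dictionary_id: None }];
/// let mut tracker = DictionaryTracker {
///     dictionaries: Default::default(),
///     cannot_replace: true,
/// };
/// let (dictionaries, batch) =
///     encode_chunk(&chunk, &fields, &mut tracker, &WriteOptions::default())?;
/// ```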
pub fn encode_chunk(
    chunk: &Chunk<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
    let mut encoded_message = EncodedData::default();
    let encoded_dictionaries = encode_chunk_amortized(
        chunk,
        fields,
        dictionary_tracker,
        options,
        &mut encoded_message,
    )?;
    Ok((encoded_dictionaries, encoded_message))
}
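
/// Like [`encode_chunk`], but writes the encoded record batch into the caller-provided
/// `encoded_message`, reusing its allocations across calls.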
pub fn encode_chunk_amortized(
    chunk: &Chunk<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) -> Result<Vec<EncodedData>> {
    let mut encoded_dictionaries = vec![];
    for (field, array) in fields.iter().zip(chunk.as_ref()) {
        encode_dictionary(
            field,
            array.as_ref(),
            options,
            dictionary_tracker,
            &mut encoded_dictionaries,
        )?;
    }

    chunk_to_bytes_amortized(chunk, options, encoded_message);

    Ok(encoded_dictionaries)
}
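
/// Maps an optional [`Compression`] to the corresponding IPC `BodyCompression` message.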
fn serialize_compression(
    compression: Option<Compression>,
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
    if let Some(compression) = compression {
        let codec = match compression {
            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
            Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
        };
        Some(Box::new(arrow_format::ipc::BodyCompression {
            codec,
            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
        }))
    } else {
        None
    }
}
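
/// Serializes a [`Chunk`] into a record batch message, writing the header into
/// `encoded_message.ipc_message` and the body buffers into `encoded_message.arrow_data`,
/// whose existing allocation is reused.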
fn chunk_to_bytes_amortized(
    chunk: &Chunk<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];

    // Reuse the allocation of the previously written body buffer, if any.
    let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
    arrow_data.clear();

    let mut offset = 0;
    for array in chunk.arrays() {
        write(
            array.as_ref(),
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            &mut offset,
            is_native_little_endian(),
            options.compression,
        );
    }

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: chunk.len() as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    encoded_message.ipc_message = ipc_message.to_vec();
    encoded_message.arrow_data = arrow_data;
}
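
/// Serializes a single dictionary array into an encoded IPC dictionary batch message.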
fn dictionary_batch_to_bytes<K: DictionaryKey>(
    dict_id: i64,
    array: &DictionaryArray<K>,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];

    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_little_endian,
        options.compression,
        false,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: length as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                })),
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    }
}
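
/// Keeps track of the dictionaries that have been written already, so that each one is emitted
/// at most once. When `cannot_replace` is set, replacing an already-written dictionary with
/// different values is an error, since Arrow IPC files only support a single dictionary per
/// field across all batches.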
pub struct DictionaryTracker {
    /// The dictionaries written so far, keyed by dictionary id.
    pub dictionaries: Dictionaries,
    /// Whether replacing an already-written dictionary is an error.
    pub cannot_replace: bool,
}
impl DictionaryTracker {
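    /// Registers the dictionary values of `array` under `dict_id`. Returns `Ok(true)` if the
    /// dictionary must be emitted, `Ok(false)` if identical values were already registered, and
    /// an error if the values changed while `cannot_replace` is set.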
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> Result<bool> {
        let values = match array.data_type() {
            DataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            }
            _ => unreachable!(),
        };

        // If this id was seen before, only re-emit when the values changed; that is an error
        // when replacement is not allowed.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                return Ok(false);
            } else if self.cannot_replace {
                return Err(Error::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
        };

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}
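
/// An encoded IPC message: the serialized `arrow_format::ipc::Message` header and the message
/// body holding the Arrow buffers.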
#[derive(Debug, Default)]
pub struct EncodedData {
    /// The serialized `arrow_format::ipc::Message` header.
    pub ipc_message: Vec<u8>,
    /// The message body: the Arrow buffers to be written.
    pub arrow_data: Vec<u8>,
}
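
/// Returns the number of padding bytes needed to round `len` up to a multiple of 64,
/// e.g. `pad_to_64(3) == 61` and `pad_to_64(64) == 0`.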
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}
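
/// A [`Chunk`] of columns together with the optional [`IpcField`]s that describe them, holding
/// either owned or borrowed data.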
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    columns: Cow<'a, Chunk<Box<dyn Array>>>,
    fields: Option<Cow<'a, [IpcField]>>,
}
impl<'a> Record<'a> {
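    /// Returns the [`IpcField`]s of this record, if any.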
    pub fn fields(&self) -> Option<&[IpcField]> {
        self.fields.as_deref()
    }
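
    /// Returns the columns of this record.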
    pub fn columns(&self) -> &Chunk<Box<dyn Array>> {
        self.columns.borrow()
    }
}
impl From<Chunk<Box<dyn Array>>> for Record<'static> {
    fn from(columns: Chunk<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}
impl<'a, F> From<(Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (Chunk<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
impl<'a, F> From<(&'a Chunk<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a Chunk<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}