diff --git a/src/lib.rs b/src/lib.rs
index 301cfc1..5de45c1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,11 +10,19 @@ use datafusion::prelude::*;
 pub const VERTEX_ID: &str = "id";
 pub const EDGE_SRC: &str = "src";
 pub const EDGE_DST: &str = "dst";
+/// Name of the struct column holding the edge in [`GraphFrame::triplets`] output.
+pub const EDGE_COL: &str = "edge";
+/// Name of the struct column holding the source vertex in [`GraphFrame::triplets`] output.
+pub const SRC_VERTEX: &str = "src_vertex";
+/// Name of the struct column holding the destination vertex in [`GraphFrame::triplets`] output.
+pub const DST_VERTEX: &str = "dst_vertex";
 
 #[derive(Debug, Clone)]
 pub struct GraphFrame {
-    vertices: DataFrame,
-    edges: DataFrame,
+    /// Vertex table; `triplets` matches its `VERTEX_ID` column against edge endpoints.
+    pub vertices: DataFrame,
+    /// Edge table; `triplets` reads its `EDGE_SRC` and `EDGE_DST` columns as join keys.
+    pub edges: DataFrame,
 }
 
 impl GraphFrame {
@@ -43,13 +51,102 @@ impl GraphFrame {
         )?;
         Ok(df.select(vec![col(EDGE_SRC).alias(VERTEX_ID), col("out_degree")])?)
     }
+
+    /// Builds the "triplets" view of the graph: one row per edge, with the edge and
+    /// both of its endpoint vertices packed into struct columns.
+    ///
+    /// Every column of `edges` is packed into the `EDGE_COL` struct, and every column
+    /// of `vertices` into the `SRC_VERTEX` / `DST_VERTEX` structs. Vertices are
+    /// attached with two left joins (`EDGE_SRC = VERTEX_ID`, then
+    /// `EDGE_DST = VERTEX_ID`), so an edge whose endpoint has no matching vertex is
+    /// kept, with a null struct on that side.
+    ///
+    /// Nothing is awaited or collected here; the method only composes `DataFrame`
+    /// projections and joins.
+    ///
+    /// # Returns
+    ///
+    /// A `DataFrame` with exactly three columns, in this order:
+    /// - `SRC_VERTEX` - struct with all the columns of the source vertex
+    /// - `EDGE_COL` - struct with all the columns of the edge
+    /// - `DST_VERTEX` - struct with all the columns of the destination vertex
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if any projection or join fails, for example when `edges`
+    /// lacks the `EDGE_SRC` / `EDGE_DST` columns or `vertices` lacks `VERTEX_ID`.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use datafusion::dataframe;
+    /// use graphframes_rs::{GraphFrame, VERTEX_ID, EDGE_SRC, EDGE_DST};
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() -> datafusion::error::Result<()> {
+    /// let vertices = dataframe!(
+    ///     VERTEX_ID => vec![1i64, 2i64, 3i64],
+    ///     "attr" => vec!["a", "b", "c"]
+    /// )?;
+    /// let edges = dataframe!(
+    ///     EDGE_SRC => vec![1i64, 2i64, 3i64],
+    ///     EDGE_DST => vec![3i64, 1i64, 2i64],
+    ///     "attr" => vec!["d", "j", "h"]
+    /// )?;
+    ///
+    /// let graph = GraphFrame { vertices, edges };
+    /// let triplets = graph.triplets().await?;
+    /// assert_eq!(triplets.schema().fields().len(), 3);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn triplets(&self) -> Result<DataFrame> {
+        // Keep the join keys alongside the packed edge struct.
+        let edges = self.edges.clone().select(vec![
+            col(EDGE_SRC),
+            col(EDGE_DST),
+            Self::columns_as_struct(&self.edges).alias(EDGE_COL),
+        ])?;
+        // The same vertex projection is used twice, differing only in the output name.
+        let src_vertices = self.vertices.clone().select(vec![
+            col(VERTEX_ID),
+            Self::columns_as_struct(&self.vertices).alias(SRC_VERTEX),
+        ])?;
+        let dst_vertices = self.vertices.clone().select(vec![
+            col(VERTEX_ID),
+            Self::columns_as_struct(&self.vertices).alias(DST_VERTEX),
+        ])?;
+
+        edges
+            .join_on(src_vertices, JoinType::Left, vec![col(EDGE_SRC).eq(col(VERTEX_ID))])?
+            // Only the destination side may contribute `VERTEX_ID` to the next join.
+            .select(vec![col(SRC_VERTEX), col(EDGE_DST), col(EDGE_COL)])?
+            .join_on(dst_vertices, JoinType::Left, vec![col(EDGE_DST).eq(col(VERTEX_ID))])?
+            .select(vec![col(SRC_VERTEX), col(EDGE_COL), col(DST_VERTEX)])
+    }
+
+    /// Packs every column of `df` into one struct expression whose field names are the
+    /// column names.
+    ///
+    /// Columns are referenced with `ident` instead of `col` because the names come
+    /// straight from the schema and must not be re-parsed as SQL identifiers (which
+    /// would lower-case them or split them on dots).
+    fn columns_as_struct(df: &DataFrame) -> Expr {
+        let args: Vec<Expr> = df
+            .schema()
+            .fields()
+            .iter()
+            .flat_map(|field| [lit(field.name().as_str()), ident(field.name().as_str())])
+            .collect();
+        named_struct(args)
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use datafusion::arrow::array::{Int64Array, RecordBatch, StringArray};
-    use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+    use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
     use std::collections::HashMap;
     use std::sync::Arc;
 
@@ -183,4 +280,54 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_triplets() -> Result<()> {
+        let vertices =
+            dataframe!(VERTEX_ID => vec![1i64, 2i64, 3i64], "attr" => vec!["a", "b", "c"])?;
+        let edges = dataframe!(EDGE_SRC => vec![1i64, 2i64, 3i64], EDGE_DST => vec![3i64, 1i64, 2i64], "attr" => vec!["d", "j", "h"])?;
+
+        let graph = GraphFrame { vertices, edges };
+        let triplets = graph.triplets().await?;
+
+        // Check schema
+        let schema = triplets.schema();
+        assert_eq!(schema.fields().len(), 3);
+        assert_eq!(schema.field(0).name(), SRC_VERTEX);
+        assert_eq!(schema.field(1).name(), EDGE_COL);
+        assert_eq!(schema.field(2).name(), DST_VERTEX);
+        assert!(
+            schema
+                .field(0)
+                .data_type()
+                .eq(&DataType::Struct(Fields::from(vec![
+                    Field::new(VERTEX_ID, DataType::Int64, true),
+                    Field::new("attr", DataType::Utf8, true)
+                ])))
+        );
+        assert!(
+            schema
+                .field(1)
+                .data_type()
+                .eq(&DataType::Struct(Fields::from(vec![
+                    Field::new(EDGE_SRC, DataType::Int64, true),
+                    Field::new(EDGE_DST, DataType::Int64, true),
+                    Field::new("attr", DataType::Utf8, true),
+                ])))
+        );
+        assert!(
+            schema
+                .field(2)
+                .data_type()
+                .eq(&DataType::Struct(Fields::from(vec![
+                    Field::new(VERTEX_ID, DataType::Int64, true),
+                    Field::new("attr", DataType::Utf8, true)
+                ])))
+        );
+
+        // Check data: the joins must yield exactly one triplet per edge.
+        assert_eq!(triplets.clone().count().await?, 3);
+
+        Ok(())
+    }
 }