Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 147 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ use datafusion::prelude::*;
pub const VERTEX_ID: &str = "id";
pub const EDGE_SRC: &str = "src";
pub const EDGE_DST: &str = "dst";
pub const EDGE_COL: &str = "edge";
pub const SRC_VERTEX: &str = "src_vertex";
pub const DST_VERTEX: &str = "dst_vertex";

#[derive(Debug, Clone)]
pub struct GraphFrame {
vertices: DataFrame,
edges: DataFrame,
pub vertices: DataFrame,
pub edges: DataFrame,
}

impl GraphFrame {
Expand Down Expand Up @@ -43,13 +46,107 @@ impl GraphFrame {
)?;
Ok(df.select(vec![col(EDGE_SRC).alias(VERTEX_ID), col("out_degree")])?)
}

/// Generates a DataFrame containing "triplets" by combining information from edges and vertices.
///
/// This method aggregates data about source vertices, edges, and destination vertices,
/// producing a combined representation of these relationships as triplets.
/// It constructs structured representations of edges and vertices, then performs
/// joins to associate source and destination vertices with their respective edges.
///
/// # Returns
///
/// Returns a `Result<DataFrame>` which can either:
/// - Contain the `DataFrame` representing the triplets (source vertex, edge, destination vertex).
/// - Return an error if an operation (e.g., selection or join) fails during the process.
///
/// Output `DataFrame` contains the following columns:
/// - `SRC_VERTEX` - struct with all the columns of vertices, associated with a source of the triple
/// - `EDGE_COL` - struct with all the columns of edges, associated with an edge
/// - `DST_VERTEX` - struct with all the columns of vertices, associated with a destination of the triplet
///
/// # Errors
///
/// This method will return an error if:
/// - Either the source vertices or destination vertices cannot be joined with edges due to schema mismatches.
/// - Any selection or transformation process internally fails due to invalid queries.
///
/// # Example
///
/// ```
/// use datafusion::dataframe;
/// use graphframes_rs::{GraphFrame, VERTEX_ID, EDGE_SRC, EDGE_DST};
/// let vertices = dataframe!(
/// VERTEX_ID => vec![1i64, 2i64, 3i64],
/// "attr" => vec!["a", "b", "c"]
/// ).unwrap();
/// let edges = dataframe!(
/// EDGE_SRC => vec![1i64, 2i64, 3i64],
/// EDGE_DST => vec![3i64, 1i64, 2i64],
/// "attr" => vec!["d", "j", "h"]
/// ).unwrap();
///
/// let graph = GraphFrame { vertices, edges };
/// let triplets = graph.triplets();
/// ```
/// // Assuming `edges_df` and `vertices_df` are initialized DataFrames for
pub async fn triplets(&self) -> Result<DataFrame> {
let edges_struct = self.edges.clone().select(vec![
col(EDGE_SRC),
col(EDGE_DST),
named_struct(
self.edges
.clone()
.schema()
.fields()
.iter()
.map(|field| field.name())
.flat_map(|name| vec![lit(name), col(name)])
.collect(),
)
.alias(EDGE_COL),
])?;
let vertices_struct = self.vertices.clone().select(vec![
col(VERTEX_ID),
named_struct(
self.vertices
.clone()
.schema()
.fields()
.iter()
.map(|field| field.name())
.flat_map(|name| vec![lit(name), col(name)])
.collect(),
)
.alias("_vertex_struct"),
])?;
edges_struct
.join_on(
vertices_struct.clone().select(vec![
col(VERTEX_ID),
col("_vertex_struct").alias(SRC_VERTEX),
])?,
JoinType::Left,
vec![col(EDGE_SRC).eq(col(VERTEX_ID))],
)?
.select(vec![col(SRC_VERTEX), col(EDGE_DST), col(EDGE_COL)])?
.join_on(
vertices_struct.select(vec![
col(VERTEX_ID),
col("_vertex_struct").alias(DST_VERTEX),
])?,
JoinType::Left,
vec![col(EDGE_DST).eq(col(VERTEX_ID))],
)?
.select(vec![col(SRC_VERTEX), col(EDGE_COL), col(DST_VERTEX)])
}
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::arrow::array::{Int64Array, RecordBatch, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use std::collections::HashMap;
use std::sync::Arc;

Expand Down Expand Up @@ -183,4 +280,51 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_triplets() -> Result<()> {
let vertices =
dataframe!(VERTEX_ID => vec![1i64, 2i64, 3i64], "attr" => vec!["a", "b", "c"])?;
let edges = dataframe!(EDGE_SRC => vec![1i64, 2i64, 3i64], EDGE_DST => vec![3i64, 1i64, 2i64], "attr" => vec!["d", "j", "h"])?;

let graph = GraphFrame { vertices, edges };
let triplets = graph.triplets().await?;

// Check schema
let schema = triplets.schema();
assert_eq!(schema.fields().len(), 3);
assert_eq!(schema.field(0).name(), SRC_VERTEX);
assert_eq!(schema.field(1).name(), EDGE_COL);
assert_eq!(schema.field(2).name(), DST_VERTEX);
assert!(
schema
.field(0)
.data_type()
.eq(&DataType::Struct(Fields::from(vec![
Field::new(VERTEX_ID, DataType::Int64, true),
Field::new("attr", DataType::Utf8, true)
])))
);
assert!(
schema
.field(1)
.data_type()
.eq(&DataType::Struct(Fields::from(vec![
Field::new(EDGE_SRC, DataType::Int64, true),
Field::new(EDGE_DST, DataType::Int64, true),
Field::new("attr", DataType::Utf8, true),
])))
);
assert!(
schema
.field(2)
.data_type()
.eq(&DataType::Struct(Fields::from(vec![
Field::new(VERTEX_ID, DataType::Int64, true),
Field::new("attr", DataType::Utf8, true)
])))
);

Ok(())
}
}
Loading