diff --git a/benchmarks/test_functions.py b/benchmarks/test_functions.py index d8ec00830..cf0efd6f1 100644 --- a/benchmarks/test_functions.py +++ b/benchmarks/test_functions.py @@ -203,3 +203,39 @@ def queries(): eng.execute_and_collect(f"SELECT ST_Perimeter(geom1) from {table}") benchmark(queries) + + @pytest.mark.parametrize( + "eng", [SedonaDBSingleThread, PostGISSingleThread, DuckDBSingleThread] + ) + @pytest.mark.parametrize( + "table", + [ + "collections_simple", + "segments_large", + ], + ) + def test_st_start_point(self, benchmark, eng, table): + eng = self._get_eng(eng) + + def queries(): + eng.execute_and_collect(f"SELECT ST_StartPoint(geom1) from {table}") + + benchmark(queries) + + @pytest.mark.parametrize( + "eng", [SedonaDBSingleThread, PostGISSingleThread, DuckDBSingleThread] + ) + @pytest.mark.parametrize( + "table", + [ + "collections_simple", + "segments_large", + ], + ) + def test_st_end_point(self, benchmark, eng, table): + eng = self._get_eng(eng) + + def queries(): + eng.execute_and_collect(f"SELECT ST_EndPoint(geom1) from {table}") + + benchmark(queries) diff --git a/python/sedonadb/tests/functions/test_functions.py b/python/sedonadb/tests/functions/test_functions.py index 041a0a8e4..da009ff9e 100644 --- a/python/sedonadb/tests/functions/test_functions.py +++ b/python/sedonadb/tests/functions/test_functions.py @@ -1016,6 +1016,55 @@ def test_st_pointm(eng, x, y, m, expected): ) +@pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) +@pytest.mark.parametrize( + ("geometry", "expected"), + [ + ("LINESTRING (1 2, 3 4, 5 6)", "POINT (1 2)"), + ("LINESTRING Z (1 2 3, 3 4 5, 5 6 7)", "POINT Z (1 2 3)"), + ("LINESTRING M (1 2 3, 3 4 5, 5 6 7)", "POINT M (1 2 3)"), + ("LINESTRING ZM (1 2 3 4, 3 4 5 6, 5 6 7 8)", "POINT ZM (1 2 3 4)"), + ("POINT (1 2)", "POINT (1 2)"), + ("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", "POINT (0 0)"), + ("MULTIPOINT (0 0, 10 0, 10 10, 0 10, 0 0)", "POINT (0 0)"), + ("MULTILINESTRING ((1 2, 3 4), (5 6, 7 8))", "POINT (1 2)"), + ("MULTIPOLYGON (((0 0, 10 0, 10 10, 0 10, 0 0)))", "POINT (0 0)"), + ("GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (3 4, 5 6))", "POINT (1 2)"), + ( + "GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (3 4, 5 6))))", + "POINT (1 2)", + ), + ], +) +def test_st_start_point(eng, geometry, expected): + eng = eng.create_or_skip() + eng.assert_query_result( + f"SELECT ST_StartPoint({geom_or_null(geometry)})", + expected, + ) + + +@pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) +@pytest.mark.parametrize( + ("geometry", "expected"), + [ + ("LINESTRING (1 2, 3 4, 5 6)", "POINT (5 6)"), + ("LINESTRING Z (1 2 3, 3 4 5, 5 6 7)", "POINT Z (5 6 7)"), + ("LINESTRING M (1 2 3, 3 4 5, 5 6 7)", "POINT M (5 6 7)"), + ("LINESTRING ZM (1 2 3 4, 3 4 5 6, 5 6 7 8)", "POINT ZM (5 6 7 8)"), + ("POINT (1 2)", None), + ("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", None), + ("MULTILINESTRING ((1 2, 3 4), (5 6, 7 8))", None), + ], +) +def test_st_end_point(eng, geometry, expected): + eng = eng.create_or_skip() + eng.assert_query_result( + f"SELECT ST_EndPoint({geom_or_null(geometry)})", + expected, + ) + + @pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) @pytest.mark.parametrize( ("x", "y", "z", "m", "expected"), diff --git a/rust/sedona-functions/benches/native-functions.rs b/rust/sedona-functions/benches/native-functions.rs index f7d4ed247..af89a4683 100644 --- a/rust/sedona-functions/benches/native-functions.rs +++ b/rust/sedona-functions/benches/native-functions.rs @@ -126,6 +126,22 @@ fn criterion_benchmark(c: &mut Criterion) { ), ); + benchmark::scalar( + c, + &f, + "native", + "st_startpoint", + BenchmarkArgs::Array(LineString(10)), + ); + + benchmark::scalar( + c, + &f, + "native", + "st_endpoint", + BenchmarkArgs::Array(LineString(10)), + ); + benchmark::scalar(c, &f, "native", "st_x", Point); benchmark::scalar(c, &f, "native", "st_y", Point); benchmark::scalar(c, &f, "native", "st_z", Point); diff --git a/rust/sedona-functions/src/lib.rs b/rust/sedona-functions/src/lib.rs index cf7608df4..42ee3dc06 100644 --- a/rust/sedona-functions/src/lib.rs +++ b/rust/sedona-functions/src/lib.rs @@ -49,6 +49,7 @@ mod st_point; mod st_pointzm; mod st_setsrid; mod st_srid; +mod st_start_point; mod st_transform; pub mod st_union_aggr; mod st_xyzm; diff --git a/rust/sedona-functions/src/register.rs b/rust/sedona-functions/src/register.rs index 0e0d87130..30a9007aa 100644 --- a/rust/sedona-functions/src/register.rs +++ b/rust/sedona-functions/src/register.rs @@ -92,6 +92,8 @@ pub fn default_function_set() -> FunctionSet { crate::st_setsrid::st_set_srid_udf, crate::st_srid::st_crs_udf, crate::st_srid::st_srid_udf, + crate::st_start_point::st_end_point_udf, + crate::st_start_point::st_start_point_udf, crate::st_xyzm::st_m_udf, crate::st_xyzm::st_x_udf, crate::st_xyzm::st_y_udf, diff --git a/rust/sedona-functions/src/st_start_point.rs b/rust/sedona-functions/src/st_start_point.rs new file mode 100644 index 000000000..a8726859b --- /dev/null +++ b/rust/sedona-functions/src/st_start_point.rs @@ -0,0 +1,276 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use arrow_array::builder::BinaryBuilder; +use datafusion_common::error::Result; +use datafusion_expr::{ + scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, +}; +use geo_traits::{ + CoordTrait, GeometryCollectionTrait, GeometryTrait, LineStringTrait, MultiLineStringTrait, + MultiPointTrait, MultiPolygonTrait, PointTrait, PolygonTrait, +}; +use sedona_common::sedona_internal_err; +use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; +use sedona_geometry::{ + error::SedonaGeometryError, + wkb_factory::{write_wkb_coord_trait, write_wkb_point_header, WKB_MIN_PROBABLE_BYTES}, +}; +use sedona_schema::{ + datatypes::{SedonaType, WKB_GEOMETRY}, + matchers::ArgMatcher, +}; +use std::{io::Write, sync::Arc}; + +use crate::executor::WkbExecutor; + +/// ST_StartPoint() scalar UDF +/// +/// Native implementation to get the start point of a geometry +pub fn st_start_point_udf() -> SedonaScalarUDF { + SedonaScalarUDF::new( + "st_startpoint", + vec![Arc::new(STStartOrEndPoint::new(true))], + Volatility::Immutable, + Some(st_start_point_doc()), + ) +} + +fn st_start_point_doc() -> Documentation { + Documentation::builder( + DOC_SECTION_OTHER, + "Returns the start point of a LINESTRING geometry. Returns NULL if the geometry is not a LINESTRING.", + "ST_StartPoint (geom: Geometry)", + ) + .with_argument("geom", "geometry: Input geometry") + .with_sql_example("SELECT ST_StartPoint(ST_GeomFromWKT('LINESTRING(0 1, 2 3, 4 5)'))") + .build() +} + +/// ST_EndPoint() scalar UDF +/// +/// Native implementation to get the end point of a geometry +pub fn st_end_point_udf() -> SedonaScalarUDF { + SedonaScalarUDF::new( + "st_endpoint", + vec![Arc::new(STStartOrEndPoint::new(false))], + Volatility::Immutable, + Some(st_end_point_doc()), + ) +} + +fn st_end_point_doc() -> Documentation { + Documentation::builder( + DOC_SECTION_OTHER, + "Returns the end point of a LINESTRING geometry. Returns NULL if the geometry is not a LINESTRING.", + "ST_EndPoint (geom: Geometry)", + ) + .with_argument("geom", "geometry: Input geometry") + .with_sql_example("SELECT ST_EndPoint(ST_GeomFromWKT('LINESTRING(0 1, 2 3, 4 5)'))") + .build() +} + +#[derive(Debug)] +struct STStartOrEndPoint { + from_start: bool, +} + +impl STStartOrEndPoint { + fn new(from_start: bool) -> Self { + STStartOrEndPoint { from_start } + } +} + +impl SedonaScalarKernel for STStartOrEndPoint { + fn return_type(&self, args: &[SedonaType]) -> Result> { + let matcher = ArgMatcher::new(vec![ArgMatcher::is_geometry()], WKB_GEOMETRY); + + matcher.match_args(args) + } + + fn invoke_batch( + &self, + arg_types: &[SedonaType], + args: &[ColumnarValue], + ) -> Result { + let executor = WkbExecutor::new(arg_types, args); + let mut builder = BinaryBuilder::with_capacity( + executor.num_iterations(), + WKB_MIN_PROBABLE_BYTES * executor.num_iterations(), + ); + + executor.execute_wkb_void(|maybe_wkb| { + if let Some(wkb) = maybe_wkb { + if let Some(coord) = extract_first_geometry(&wkb, self.from_start) { + if write_wkb_start_point(&mut builder, coord).is_err() { + return sedona_internal_err!("Failed to write WKB point header"); + }; + builder.append_value([]); + return Ok(()); + } + } + + builder.append_null(); + Ok(()) + })?; + + executor.finish(Arc::new(builder.finish())) + } +} + +fn write_wkb_start_point( + buf: &mut impl Write, + coord: impl CoordTrait, +) -> Result<(), SedonaGeometryError> { + write_wkb_point_header(buf, coord.dim())?; + write_wkb_coord_trait(buf, &coord) +} + +// - ST_StartPoint returns result for all types of geometries +// - ST_EndPoint returns result only for LINESTRING +fn extract_first_geometry<'a>( + wkb: &'a wkb::reader::Wkb<'a>, + from_start: bool, +) -> Option> { + match (wkb.as_type(), from_start) { + (geo_traits::GeometryType::Point(point), true) => point.coord(), + (geo_traits::GeometryType::LineString(line_string), true) => line_string.coord(0), + (geo_traits::GeometryType::LineString(line_string), false) => { + line_string.coord(line_string.num_coords() - 1) + } + (geo_traits::GeometryType::Polygon(polygon), true) => match polygon.exterior() { + Some(ring) => ring.coord(0), + None => None, + }, + (geo_traits::GeometryType::MultiPoint(multi_point), true) => match multi_point.point(0) { + Some(point) => point.coord(), + None => None, + }, + (geo_traits::GeometryType::MultiLineString(multi_line_string), true) => { + match multi_line_string.line_string(0) { + Some(line_string) => line_string.coord(0), + None => None, + } + } + (geo_traits::GeometryType::MultiPolygon(multi_polygon), true) => { + match multi_polygon.polygon(0) { + Some(polygon) => match polygon.exterior() { + Some(ring) => ring.coord(0), + None => None, + }, + None => None, + } + } + (geo_traits::GeometryType::GeometryCollection(geometry_collection), true) => { + match geometry_collection.geometry(0) { + Some(geometry) => extract_first_geometry(geometry, from_start), + None => None, + } + } + (geo_traits::GeometryType::Rect(_), true) => None, + (geo_traits::GeometryType::Triangle(_), true) => None, + (geo_traits::GeometryType::Line(_), true) => None, + _ => None, + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr::ScalarUDF; + use rstest::rstest; + use sedona_schema::datatypes::WKB_VIEW_GEOMETRY; + use sedona_testing::{ + compare::assert_array_equal, create::create_array, testers::ScalarUdfTester, + }; + + use super::*; + + #[test] + fn udf_metadata() { + let st_start_point_udf: ScalarUDF = st_start_point_udf().into(); + assert_eq!(st_start_point_udf.name(), "st_startpoint"); + assert!(st_start_point_udf.documentation().is_some()); + + let st_end_point_udf: ScalarUDF = st_end_point_udf().into(); + assert_eq!(st_end_point_udf.name(), "st_endpoint"); + assert!(st_end_point_udf.documentation().is_some()); + } + + #[rstest] + fn udf(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) { + let tester_start_point = + ScalarUdfTester::new(st_start_point_udf().into(), vec![sedona_type.clone()]); + let tester_end_point = + ScalarUdfTester::new(st_end_point_udf().into(), vec![sedona_type.clone()]); + + let input = create_array( + &[ + Some("LINESTRING (1 2, 3 4, 5 6)"), + Some("LINESTRING Z (1 2 3, 3 4 5, 5 6 7)"), + Some("LINESTRING M (1 2 3, 3 4 5, 5 6 7)"), + Some("LINESTRING ZM (1 2 3 4, 3 4 5 6, 5 6 7 8)"), + Some("POINT (1 2)"), + Some("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))"), + Some("MULTIPOINT (0 0, 10 0, 10 10, 0 10, 0 0)"), + Some("MULTILINESTRING ((1 2, 3 4), (5 6, 7 8))"), + Some("MULTIPOLYGON (((0 0, 10 0, 10 10, 0 10, 0 0)))"), + Some("GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (3 4, 5 6))"), + None, + ], + &sedona_type, + ); + + let expected_start_point = create_array( + &[ + Some("POINT (1 2)"), + Some("POINT Z (1 2 3)"), + Some("POINT M (1 2 3)"), + Some("POINT ZM (1 2 3 4)"), + Some("POINT (1 2)"), + Some("POINT (0 0)"), + Some("POINT (0 0)"), + Some("POINT (1 2)"), + Some("POINT (0 0)"), + Some("POINT (1 2)"), + None, + ], + &WKB_GEOMETRY, + ); + + let result_start_point = tester_start_point.invoke_array(input.clone()).unwrap(); + assert_array_equal(&result_start_point, &expected_start_point); + + let expected_end_point = create_array( + &[ + Some("POINT (5 6)"), + Some("POINT Z (5 6 7)"), + Some("POINT M (5 6 7)"), + Some("POINT ZM (5 6 7 8)"), + None, + None, + None, + None, + None, + None, + None, + ], + &WKB_GEOMETRY, + ); + + let result_end_point = tester_end_point.invoke_array(input).unwrap(); + assert_array_equal(&result_end_point, &expected_end_point); + } +} diff --git a/rust/sedona-geometry/src/wkb_factory.rs b/rust/sedona-geometry/src/wkb_factory.rs index 000788f0a..9db1d29d3 100644 --- a/rust/sedona-geometry/src/wkb_factory.rs +++ b/rust/sedona-geometry/src/wkb_factory.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. use crate::error::SedonaGeometryError; -use geo_traits::Dimensions; +use geo_traits::{CoordTrait, Dimensions}; use std::io::Write; pub const WKB_MIN_PROBABLE_BYTES: usize = 21; @@ -392,6 +392,37 @@ where Ok(()) } +/// Write a single coordinate of CoordTrait to WKB +/// This function always writes little endian coordinates. +pub fn write_wkb_coord_trait(buf: &mut impl Write, coord: &C) -> Result<(), SedonaGeometryError> +where + C: CoordTrait, +{ + match coord.dim().size() { + 2 => { + let coord_tuple = coord.x_y(); + write_wkb_coord(buf, coord_tuple) + } + 3 => { + let coord_tuple: (::T, _, _) = + (coord.x(), coord.y(), coord.nth_or_panic(2)); + write_wkb_coord(buf, coord_tuple) + } + 4 => { + let coord_tuple = ( + coord.x(), + coord.y(), + coord.nth_or_panic(2), + coord.nth_or_panic(3), + ); + write_wkb_coord(buf, coord_tuple) + } + _ => Err(SedonaGeometryError::Invalid( + "Unsupported number of dimensions".to_string(), + )), + } +} + /// Write multiple coordinates to WKB /// /// This function takes an iterator of coordinates and writes them to the provided buffer. @@ -537,6 +568,31 @@ mod test { check_bytes(&wkb, "POINT ZM(12 13 14 15)"); } + #[test] + fn test_write_wkb_coord_trait() { + let cases = [ + (None, None, "POINT(0 1)"), + (Some(2.0), None, "POINT Z(0 1 2)"), + (None, Some(3.0), "POINT M(0 1 3)"), + (Some(2.0), Some(3.0), "POINT ZM(0 1 2 3)"), + ]; + let mut wkb = vec![]; + + for (z, m, expected) in cases { + let coord = wkt::types::Coord { + x: 0.0, + y: 1.0, + z, + m, + }; + + wkb.clear(); + write_wkb_point_header(&mut wkb, coord.dim()).unwrap(); + write_wkb_coord_trait(&mut wkb, &coord).unwrap(); + check_bytes(&wkb, expected); + } + } + #[test] fn test_wkb_linestring() { let wkt: Wkt = Wkt::from_str("LINESTRING EMPTY").unwrap();