|
1 | 1 | use pyo3::exceptions::PyValueError; |
2 | 2 | use pyo3::prelude::*; |
3 | 3 | use pyo3::wrap_pyfunction; |
| 4 | +use pythonize::PythonizeError; |
| 5 | +use sqlparser::ast::Statement; |
4 | 6 |
|
| 7 | +use core::ops::ControlFlow; |
5 | 8 | use pythonize::pythonize; |
6 | | - |
| 9 | +use sqlparser::ast::visit_relations; |
7 | 10 | use sqlparser::dialect::*; |
8 | 11 | use sqlparser::parser::Parser; |
9 | 12 |
|
@@ -65,8 +68,53 @@ fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> { |
65 | 68 | Ok(output) |
66 | 69 | } |
67 | 70 |
|
| 71 | +/// |
| 72 | +/// Function to extract relations from a parsed query. |
| 73 | +/// Returns a nested list of relations, one list per query statement. |
| 74 | +/// |
| 75 | +/// Example: |
| 76 | +/// ```python |
| 77 | +/// from sqloxide import parse_sql, extract_relations |
| 78 | +/// |
| 79 | +/// sql = "SELECT * FROM table1 JOIN table2 ON table1.id = table2.id" |
| 80 | +/// parsed_query = parse_sql(sql, "generic") |
| 81 | +/// relations = extract_relations(parsed_query) |
| 82 | +/// print(relations) |
| 83 | +/// ``` |
| 84 | +/// |
| 85 | +#[pyfunction] |
| 86 | +#[pyo3(text_signature = "(parsed_query)")] |
| 87 | +fn extract_relations(py: Python, parsed_query: &PyAny) -> PyResult<PyObject> { |
| 88 | + let parse_result: Result<Vec<Statement>, PythonizeError> = pythonize::depythonize(parsed_query); |
| 89 | + |
| 90 | + let mut relations = Vec::new(); |
| 91 | + |
| 92 | + match parse_result { |
| 93 | + Ok(statements) => { |
| 94 | + for statement in statements { |
| 95 | + visit_relations(&statement, |relation| { |
| 96 | + relations.push(relation.clone()); |
| 97 | + ControlFlow::<()>::Continue(()) |
| 98 | + }); |
| 99 | + } |
| 100 | + } |
| 101 | + Err(_e) => { |
| 102 | + let msg = _e.to_string(); |
| 103 | + return Err(PyValueError::new_err(format!( |
| 104 | + "Query serialization failed.\n\t{}", |
| 105 | + msg |
| 106 | + ))); |
| 107 | + } |
| 108 | + }; |
| 109 | + |
| 110 | + let output = pythonize(py, &relations).expect("Internal python deserialization failed."); |
| 111 | + |
| 112 | + Ok(output) |
| 113 | +} |
| 114 | + |
68 | 115 | #[pymodule] |
69 | 116 | fn sqloxide(_py: Python, m: &PyModule) -> PyResult<()> { |
70 | 117 | m.add_function(wrap_pyfunction!(parse_sql, m)?)?; |
| 118 | + m.add_function(wrap_pyfunction!(extract_relations, m)?)?; |
71 | 119 | Ok(()) |
72 | 120 | } |
0 commit comments