Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions fusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
edition = "2024"
rust-version = "1.85"
name = "stardust-xr-fusion"
version.workspace = true
authors = ["Nova King <technobaboo@proton.me>"]
description = "High level client library for the Stardust XR display server"
version.workspace = true
license.workspace = true
repository.workspace = true
homepage.workspace = true
Expand Down Expand Up @@ -37,14 +37,13 @@ zbus = { version = "5.11.0", features = [
"tokio",
], default-features = false }
stardust-xr-wire = { version = "0.50.0", path = "../wire" }
stardust-xr-gluon = { version = "2.0.0", path = "../gluon" }
stardust-xr-gluon = { version = "0.50.0", path = "../gluon" }

[build-dependencies]
convert_case = "0.8.0"
quote = "1.0.33"
stardust-xr-protocol = { version = "2.0.0", path = "../protocol" }
stardust-xr-protocol = { version = "0.50.0", path = "../protocol" }
proc-macro2 = "1.0.71"
clap = { version = "4.4", features = ["derive"] }
color-eyre = "0.6"
syn = "2.0.106"
prettyplease = "0.2.37"
Expand Down
149 changes: 114 additions & 35 deletions fusion/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ use convert_case::{Case, Casing};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{ToTokens, quote};
use stardust_xr_protocol::*;
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::sync::LazyLock;
use std::sync::Mutex;

fn main() {
// Watch for changes to KDL schema files
let schema_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
Expand Down Expand Up @@ -106,6 +110,48 @@ pub fn get_all_protocols() -> Vec<ProtocolInfo> {
]
}

static CLONE_PARTIAL_EQ_CACHE: LazyLock<Mutex<HashMap<String, bool>>> =
LazyLock::new(Default::default);
fn can_impl_clone_and_partial_eq(arg: &ArgumentType) -> bool {
match &arg {
ArgumentType::Union(name) => *CLONE_PARTIAL_EQ_CACHE.lock().unwrap().get(name).unwrap(),
ArgumentType::Struct(name) => *CLONE_PARTIAL_EQ_CACHE.lock().unwrap().get(name).unwrap(),
ArgumentType::Fd => false,
_ => true,
}
}
fn setup_impl_clone_and_partial_eq(protocol_definitions: &[&Protocol], arg: &ArgumentType) -> bool {
match &arg {
ArgumentType::Union(name) | ArgumentType::Struct(name) => {
if let Some(v) = CLONE_PARTIAL_EQ_CACHE.lock().unwrap().get(name) {
return *v;
};
let v = protocol_definitions
.iter()
.flat_map(|v| v.custom_unions.iter())
.filter(|v| &v.name == name)
.flat_map(|v| v.options.iter())
.map(|v| &v._type)
.chain(
protocol_definitions
.iter()
.flat_map(|v| v.custom_structs.iter())
.filter(|v| &v.name == name)
.flat_map(|v| v.fields.iter())
.map(|v| &v._type),
)
.all(|v| setup_impl_clone_and_partial_eq(protocol_definitions, v));
CLONE_PARTIAL_EQ_CACHE
.lock()
.unwrap()
.insert(name.clone(), v);
v
}
ArgumentType::Fd => false,
_ => true,
}
}

pub fn generate_protocol_file(
protocols: Vec<ProtocolInfo>,
file_path: &Path,
Expand All @@ -118,6 +164,25 @@ pub fn generate_protocol_file(

let mut protocol_definitions = protocols.iter_mut().map(|(p, _)| p).collect::<Vec<_>>();
stardust_xr_protocol::resolve_inherits(&mut protocol_definitions).unwrap();
let protocol_definitions = protocols.iter().map(|(p, _)| p).collect::<Vec<_>>();
protocol_definitions
.iter()
.flat_map(|v| v.custom_structs.iter())
.for_each(|v| {
setup_impl_clone_and_partial_eq(
&protocol_definitions,
&ArgumentType::Struct(v.name.clone()),
);
});
protocol_definitions
.iter()
.flat_map(|v| v.custom_unions.iter())
.for_each(|v| {
setup_impl_clone_and_partial_eq(
&protocol_definitions,
&ArgumentType::Union(v.name.clone()),
);
});

// panic!("{protocol_definitions:# ?}");

Expand Down Expand Up @@ -258,10 +323,19 @@ impl Tokenize for CustomUnion {
.iter()
.map(|e| e.tokenize(_generate_node, partial_eq));

let derive = if partial_eq {
quote!( #[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)] )
} else {
quote!( #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] )
let derive = match (
partial_eq,
can_impl_clone_and_partial_eq(&ArgumentType::Enum(self.name.clone())),
) {
(true, true) => {
quote!( #[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)] )
}
(false, true) => {
quote!( #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] )
}
(_, false) => {
quote!( #[derive(Debug, serde::Deserialize, serde::Serialize)] )
}
};
quote! {
#[doc = #description]
Expand Down Expand Up @@ -322,16 +396,21 @@ impl Tokenize for CustomStruct {
let name = Ident::new(&self.name.to_case(Case::Pascal), Span::call_site());
let description = &self.description;

let argument_decls = self
.fields
.iter()
.map(|a| generate_argument_decl(a, true))
.map(|d| quote!(pub #d));
let argument_decls = self.fields.iter().map(|a| generate_pub_field_decl(a, true));

let derive = if partial_eq {
quote!( #[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)] )
} else {
quote!( #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] )
let derive = match (
partial_eq,
can_impl_clone_and_partial_eq(&ArgumentType::Struct(self.name.clone())),
) {
(true, true) => {
quote!( #[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)] )
}
(false, true) => {
quote!( #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] )
}
(_, false) => {
quote!( #[derive(Debug, serde::Deserialize, serde::Serialize)] )
}
};
quote! {
#[doc = #description]
Expand Down Expand Up @@ -643,7 +722,7 @@ fn generate_event_sender_impl(aspect: &Aspect) -> TokenStream {
member._type,
quote! {
#opcode => {
let (#(#field_names),*): (#(#deserialize_types),*) = stardust_xr_wire::flex::deserialize(_data)?;
let (#(#field_names),*): (#(#deserialize_types),*) = stardust_xr_wire::flex::deserialize(_data, _fds)?;

#debug
Ok(#event_name::#variant_name { #(#field_uses,)* #response_sender })
Expand Down Expand Up @@ -732,19 +811,17 @@ fn generate_server_member(
MemberType::Signal => {
let mut body = if let Some(interface_node_id) = &interface_node_id {
quote! {
let mut _fds = Vec::new();
let data = (#(#argument_uses),*);
let serialized_data = stardust_xr_wire::flex::serialize(&data)?;
_client.message_sender_handle.signal(#interface_node_id, #aspect_id, #opcode, &serialized_data, _fds)?;
let (serialized_data, fds) = stardust_xr_wire::flex::serialize(&data)?;
_client.message_sender_handle.signal(#interface_node_id, #aspect_id, #opcode, &serialized_data, fds)?;

let (#(#arguments),*) = data;
tracing::trace!(#arguments_debug "Sent signal to server, {}::{}", #aspect_name, #name_str);
}
} else {
quote! {
let mut _fds = Vec::new();
let data = (#(#argument_uses),*);
self.node().send_signal(#aspect_id, #opcode, &data, _fds)?;
self.node().send_signal(#aspect_id, #opcode, &data)?;

let (#(#arguments),*) = data;
tracing::trace!(#arguments_debug "Sent signal to server, {}::{}", #aspect_name, #name_str);
Expand Down Expand Up @@ -795,28 +872,26 @@ fn generate_server_member(
let deserialize = generate_argument_deserialize("result", &argument_type, false);
let body = if let Some(interface_node_id) = &interface_node_id {
quote! {
let mut _fds = Vec::new();
let data = (#(#argument_uses),*);
{
let (#(#arguments),*) = &data;
tracing::trace!(#arguments_debug "Called method on server, {}::{}", #aspect_name, #name_str);
}
let serialized_data = stardust_xr_wire::flex::serialize(&data)?;
let message = _client.message_sender_handle.method(#interface_node_id, #aspect_id, #opcode, &serialized_data, _fds).await?.map_err(|e| crate::node::NodeError::ReturnedError { e })?.into_message();
let result: #deserializeable_type = stardust_xr_wire::flex::deserialize(&message)?;
let (serialized_data, fds) = stardust_xr_wire::flex::serialize(&data)?;
let (message, message_fds) = _client.message_sender_handle.method(#interface_node_id, #aspect_id, #opcode, &serialized_data, fds).await?.map_err(|e| crate::node::NodeError::ReturnedError { e })?.into_components();
let result: #deserializeable_type = stardust_xr_wire::flex::deserialize(&message, message_fds)?;
let deserialized = #deserialize;
tracing::trace!("return" = ?deserialized, "Method return from server, {}::{}", #aspect_name, #name_str);
Ok(deserialized)
}
} else {
quote! {{
let mut _fds = Vec::new();
let data = (#(#argument_uses),*);
{
let (#(#arguments),*) = &data;
tracing::trace!(#arguments_debug "Called method on server, {}::{}", #aspect_name, #name_str);
}
let result: #deserializeable_type = self.node().call_method(#aspect_id, #opcode, &data, _fds).await?;
let result: #deserializeable_type = self.node().call_method(#aspect_id, #opcode, &data).await?;
let deserialized = #deserialize;
tracing::trace!("return" = ?deserialized, "Method return from server, {}::{}", #aspect_name, #name_str);
Ok(deserialized)
Expand Down Expand Up @@ -869,9 +944,6 @@ fn generate_argument_deserialize(
let mapping = generate_argument_deserialize("a", v, false);
quote!(#name.into_iter().map(|(k, a)| Ok((k, #mapping))).collect::<Result<stardust_xr_wire::values::Map<String, _>, crate::node::NodeError>>()?)
}
ArgumentType::Fd => {
quote!(_fds.remove(0))
}
_ => quote!(#name),
}
}
Expand Down Expand Up @@ -928,12 +1000,6 @@ fn generate_argument_serialize(
let mapping = generate_argument_serialize("a", v, false);
quote!(#name.iter().map(|(k, a)| Ok((k, #mapping))).collect::<crate::node::NodeResult<rustc_hash::FxHashMap<String, _>>>()?)
}
ArgumentType::Fd => {
quote!({
_fds.push(#name);
(_fds.len() - 1) as u32
})
}
_ => quote!(#name),
}
}
Expand All @@ -945,6 +1011,19 @@ fn generate_argument_decl(argument: &Argument, returned: bool) -> TokenStream {
}
quote!(#name: #_type)
}
fn generate_pub_field_decl(argument: &Argument, returned: bool) -> TokenStream {
let name = Ident::new(&argument.name.to_case(Case::Snake), Span::call_site());
let mut _type = generate_argument_type(&argument._type, returned);
if argument.optional {
_type = quote!(Option<#_type>);
}
let description = argument
.description
.as_ref()
.map(|d| quote!(#[doc = #d]))
.unwrap_or_default();
quote!(#description pub #name: #_type)
}
fn generate_argument_type(argument_type: &ArgumentType, owned: bool) -> TokenStream {
match argument_type {
ArgumentType::Empty => quote!(()),
Expand Down Expand Up @@ -1057,7 +1136,7 @@ fn generate_argument_type(argument_type: &ArgumentType, owned: bool) -> TokenStr
}
}
ArgumentType::Fd => {
quote!(std::os::unix::io::OwnedFd)
quote!(stardust_xr_wire::fd::ProtocolFd)
}
}
}
15 changes: 6 additions & 9 deletions fusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use serde::Serialize;
use stardust_xr_wire::{flex::serialize, messenger::MethodResponse, scenegraph::ScenegraphError};
use std::{error::Error, fmt::Debug, marker::PhantomData, os::fd::OwnedFd};
use std::{error::Error, fmt::Debug, marker::PhantomData};

pub use client::*;
pub use stardust_xr_gluon::*;
Expand Down Expand Up @@ -44,23 +44,20 @@ impl<T: Serialize> TypedMethodResponse<T> {
return;
}
};
let Ok(serialized) = stardust_xr_wire::flex::serialize(data) else {
let Ok((serialized, fds)) = stardust_xr_wire::flex::serialize(data) else {
self.0.send(Err(ScenegraphError::MemberError {
error: "Internal: Failed to serialize".to_string(),
}));
return;
};
self.0.send(Ok((&serialized, Vec::<OwnedFd>::new())));
self.0.send(Ok((&serialized, fds)));
}
pub fn wrap<E: Error, F: FnOnce() -> Result<T, E>>(self, f: F) {
self.send(f())
}
pub fn wrap_async<E: Error>(
self,
f: impl Future<Output = Result<(T, Vec<OwnedFd>), E>> + Send + 'static,
) {
pub fn wrap_async<E: Error>(self, f: impl Future<Output = Result<T, E>> + Send + 'static) {
tokio::task::spawn(async move {
let (value, fds) = match f.await {
let value = match f.await {
Ok(d) => d,
Err(e) => {
self.0.send(Err(ScenegraphError::MemberError {
Expand All @@ -69,7 +66,7 @@ impl<T: Serialize> TypedMethodResponse<T> {
return;
}
};
let Ok(serialized) = serialize(value) else {
let Ok((serialized, fds)) = serialize(value) else {
self.0.send(Err(ScenegraphError::MemberError {
error: "Internal: Failed to serialize".to_string(),
}));
Expand Down
11 changes: 5 additions & 6 deletions fusion/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use stardust_xr_wire::{
messenger::MessengerError,
scenegraph::ScenegraphError,
};
use std::{fmt::Debug, os::fd::OwnedFd, sync::Arc, vec::Vec};
use std::{fmt::Debug, sync::Arc};
use thiserror::Error;

pub use crate::protocol::node::*;
Expand Down Expand Up @@ -106,9 +106,8 @@ impl NodeCore {
aspect: u64,
signal: u64,
data: &S,
fds: Vec<OwnedFd>,
) -> Result<(), NodeError> {
let serialized = serialize(data).map_err(|e| NodeError::Serialization { e })?;
let (serialized, fds) = serialize(data).map_err(|e| NodeError::Serialization { e })?;
self.client
.message_sender_handle
.signal(self.id, aspect, signal, &serialized, fds)
Expand All @@ -124,9 +123,8 @@ impl NodeCore {
aspect: u64,
method: u64,
data: &S,
fds: Vec<OwnedFd>,
) -> Result<D, NodeError> {
let serialized = serialize(data).map_err(|e| NodeError::Serialization { e })?;
let (serialized, fds) = serialize(data).map_err(|e| NodeError::Serialization { e })?;

let response = self
.client
Expand All @@ -139,7 +137,8 @@ impl NodeCore {
})?
.map_err(|e| NodeError::ReturnedError { e })?;

deserialize(&response.into_message()).map_err(|e| NodeError::Deserialization { e })
let (response, fds) = response.into_components();
deserialize(&response, fds).map_err(|e| NodeError::Deserialization { e })
}
}
impl NodeType for NodeCore {
Expand Down
Loading