diff --git a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx index 3e809406b..8c068c2bc 100644 --- a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx +++ b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationTool.tsx @@ -363,14 +363,17 @@ export const createAllRelationShapeTools = ( ); const relation = discourseContext.relations[name].find( - (r) => r.source === target?.type, + (r) => r.source === target?.type || r.destination === target?.type, ); if (relation) { this.shapeType = relation.id; } else { - const acceptableTypes = discourseContext.relations[name].map( - (r) => discourseContext.nodes[r.source].text, - ); + const acceptableTypes = discourseContext.relations[name] + .flatMap((r) => [ + discourseContext.nodes[r.source]?.text, + discourseContext.nodes[r.destination]?.text, + ]) + .filter(Boolean); const uniqueTypes = [...new Set(acceptableTypes)]; this.cancelAndWarn( `Starting node must be one of ${uniqueTypes.join(", ")}`, diff --git a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx index 198e75916..23d1dbf8a 100644 --- a/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx +++ b/apps/roam/src/components/canvas/DiscourseRelationShape/DiscourseRelationUtil.tsx @@ -537,14 +537,30 @@ export const createAllRelationShapeUtils = ( const relations = Object.values(discourseContext.relations).flat(); const relation = relations.find((r) => r.id === arrow.type); if (!relation) return; - const possibleTargets = discourseContext.relations[relation.label] - .filter((r) => r.source === relation.source) - .map((r) => r.destination); - if (!possibleTargets.includes(target.type)) { - const uniqueTargets = [...new Set(possibleTargets)]; - const uniqueTargetTexts = uniqueTargets.map( - (t) => discourseContext.nodes[t].text, + const sourceNodeType = source.type; + const targetNodeType = target.type; + + const { isDirect, isReverse } = this.checkConnectionType( + relation, + sourceNodeType, + targetNodeType, + ); + + if (!isDirect && !isReverse) { + const possibleTargets = discourseContext.relations[relation.label] + .filter((r) => r.source === relation.source) + .map((r) => r.destination); + const possibleReverseTargets = discourseContext.relations[ + relation.label + ] + .filter((r) => r.destination === relation.source) + .map((r) => r.source); + const allPossibleTargets = [ + ...new Set([...possibleTargets, ...possibleReverseTargets]), + ]; + const uniqueTargetTexts = allPossibleTargets.map( + (t) => discourseContext.nodes[t]?.text || t, ); return deleteAndWarn( `Target node must be of type ${uniqueTargetTexts.join(", ")}`, @@ -553,6 +569,7 @@ export const createAllRelationShapeUtils = ( if (arrow.type !== target.type) { editor.updateShapes([{ id: arrow.id, type: target.type }]); } + arrow = editor.getShape(arrow.id) as DiscourseRelationShape; if (getSetting("use-reified-relations")) { const sourceAsDNS = asDiscourseNodeShape(source, editor); const targetAsDNS = asDiscourseNodeShape(target, editor); @@ -572,8 +589,9 @@ export const createAllRelationShapeUtils = ( }).catch(() => undefined); } } else { - const { triples, label: relationLabel } = relation; - const isOriginal = arrow.props.text === relationLabel; + const { triples } = relation; + const isOriginal = isDirect && !isReverse; + const newTriples = triples .map((t) => { if (/is a/i.test(t[1])) { @@ -756,6 +774,52 @@ export const createAllRelationShapeUtils = ( return update; } + // Validate target node type compatibility before creating binding + if ( + target.type !== "arrow" && + otherBinding && + target.id !== otherBinding.toId && + (!currentBinding || target.id !== currentBinding.toId) + ) { + const sourceNodeId = otherBinding.toId; + const sourceNode = this.editor.getShape(sourceNodeId); + const targetNodeType = target.type; + const sourceNodeType = sourceNode?.type; + + if (sourceNodeType && targetNodeType && shape.type) { + const isValidConnection = this.isValidNodeConnection( + sourceNodeType, + targetNodeType, + shape.type, + ); + + if (!isValidConnection) { + const sourceNodeTypeText = + discourseContext.nodes[sourceNodeType]?.text || sourceNodeType; + const targetNodeTypeText = + discourseContext.nodes[targetNodeType]?.text || targetNodeType; + const relations = Object.values( + discourseContext.relations, + ).flat(); + const relation = relations.find((r) => r.id === shape.type); + const relationLabel = relation?.label || shape.type; + + const errorMessage = `Cannot connect "${sourceNodeTypeText}" to "${targetNodeTypeText}" with "${relationLabel}" relation`; + dispatchToastEvent({ + id: `tldraw-invalid-connection-${shape.id}`, + title: "Invalid Connection", + description: errorMessage, + severity: "error", + }); + + removeArrowBinding(this.editor, shape, handleId); + update.props![handleId] = { x: handle.x, y: handle.y }; + this.editor.deleteShapes([shape.id]); + return update; + } + } + } + // we've got a target! the handle is being dragged over a shape, bind to it const targetGeometry = this.editor.getShapeGeometry(target); @@ -832,6 +896,37 @@ export const createAllRelationShapeUtils = ( this.editor.setHintingShapes([target.id]); const newBindings = getArrowBindings(this.editor, shape); + + // Check if both ends are bound and determine the correct text based on direction + if (newBindings.start && newBindings.end) { + const relations = Object.values(discourseContext.relations).flat(); + const relation = relations.find((r) => r.id === shape.type); + + if (relation) { + const startNode = this.editor.getShape(newBindings.start.toId); + const endNode = this.editor.getShape(newBindings.end.toId); + + if (startNode && endNode) { + const startNodeType = startNode.type; + const endNodeType = endNode.type; + + const { isDirect, isReverse } = this.checkConnectionType( + relation, + startNodeType, + endNodeType, + ); + + const newText = + isReverse && !isDirect ? relation.complement : relation.label; + + if (shape.props.text !== newText) { + update.props = update.props || {}; + update.props.text = newText; + } + } + } + } + if ( newBindings.start && newBindings.end && @@ -1454,6 +1549,40 @@ export class BaseDiscourseRelationUtil extends ShapeUtil ]; } + checkConnectionType( + relation: { source: string; destination: string }, + sourceNodeType: string, + targetNodeType: string, + ): { isDirect: boolean; isReverse: boolean } { + const isDirect = + sourceNodeType === relation.source && + targetNodeType === relation.destination; + + const isReverse = + sourceNodeType === relation.destination && + targetNodeType === relation.source; + + return { isDirect, isReverse }; + } + + isValidNodeConnection( + sourceNodeType: string, + targetNodeType: string, + relationId: string, + ): boolean { + const relations = Object.values(discourseContext.relations).flat(); + const relation = relations.find((r) => r.id === relationId); + if (!relation) return false; + + const { isDirect, isReverse } = this.checkConnectionType( + relation, + sourceNodeType, + targetNodeType, + ); + + return isDirect || isReverse; + } + component(shape: DiscourseRelationShape) { // eslint-disable-next-line react-hooks/rules-of-hooks // const theme = useDefaultColorTheme();