diff --git a/qi-lib/flow/compiler.rkt b/qi-lib/flow/compiler.rkt index 823452eb4..733041501 100644 --- a/qi-lib/flow/compiler.rkt +++ b/qi-lib/flow/compiler.rkt @@ -91,7 +91,7 @@ ;;; Core routing elements [(~or* (~datum ⏚) (~datum ground)) - #'(qi0->racket (select))] + #'*->1] [((~or* (~datum ~>) (~datum thread)) onex:clause ...) #`(compose . #,(reverse (syntax->list @@ -101,15 +101,10 @@ #'(qi0->racket (~> ▽ reverse △))] [((~or* (~datum ==) (~datum relay)) onex:clause ...) #'(relay (qi0->racket onex) ...)] - [((~or* (~datum ==*) (~datum relay*)) onex:clause ... rest-onex:clause) - (with-syntax ([len #`#,(length (syntax->list #'(onex ...)))]) - #'(qi0->racket (group len (== onex ...) rest-onex) ))] + [((~or* (~datum ==*) (~datum relay*)) onex:clause ...) + #'(relay* (qi0->racket onex) ...)] [((~or* (~datum -<) (~datum tee)) onex:clause ...) - #'(λ args - (apply values - (append (values->list - (apply (qi0->racket onex) args)) - ...)))] + #'(tee (qi0->racket onex) ...)] [e:select-form (select-parser #'e)] [e:block-form (block-parser #'e)] [((~datum bundle) (n:number ...) @@ -308,6 +303,7 @@ the DSL. (define (select-parser stx) (syntax-parse stx + [(_) #'*->1] [(_ n:number ...) #'(qi0->racket (-< (esc (arg n)) ...))] [(_ arg ...) ; error handling catch-all (report-syntax-error 'select @@ -488,17 +484,14 @@ the DSL. (define (fanout-parser stx) (syntax-parse stx [_:id #'repeat-values] + [(_ 0) #'*->1] [(_ n:number) ;; a slightly more efficient compile-time implementation ;; for literally indicated N #`(λ args (apply values (append #,@(make-list (syntax->datum #'n) 'args))) )] - [(_ n:expr) - #'(lambda args - (apply values - (apply append - (make-list n args))))])) + [(_ e:expr) #`(let ([n e]) (#,fanout-parser n))])) (define (feedback-parser stx) (syntax-parse stx diff --git a/qi-lib/flow/impl.rkt b/qi-lib/flow/impl.rkt index 16df7327c..ce9c1f37c 100644 --- a/qi-lib/flow/impl.rkt +++ b/qi-lib/flow/impl.rkt @@ -10,7 +10,11 @@ map-values filter-values partition-values + 1->1 + *->1 relay + relay* + tee loom-compose parity-xor arg @@ -27,7 +31,8 @@ (require racket/match (only-in racket/function const - negate) + negate + arity-includes?) racket/bool racket/list racket/format @@ -198,6 +203,62 @@ (append (values->list (apply op vs)) (apply zip-with op (map rest seqs)))))) +(define split-input + (λ (n arity*) + (define report-arity-error + (λ () + (raise-arguments-error + 'split-input + (string-append + "arity mismatch;\n" + " the expected number of arguments does not match the given number") + "given" n))) + (define len (length arity*)) + (define-values (m a*) + (for/fold ([m n] [a* '()]) + ([arity (in-list arity*)] + [i (in-naturals)]) + (if (= 1 (- len i)) + (match arity + [(? exact-nonnegative-integer? n) + (values (- m n) a*)] + [(or (arity-at-least n) + (list* n _)) + (values (- m n) `([,i ,n ,arity] . ,a*))]) + (match arity + [(? exact-nonnegative-integer? n) + (values (- m n) a*)] + [(arity-at-least 0) + (values (- m 1) `([,i 1 ,arity] . ,a*))] + [(or (arity-at-least n) + (list* 0 (arity-at-least n)) + (list* 0 n _) + (list* n _)) + (values (- m n) `([,i ,n ,arity] . ,a*))])))) + (unless (>= m 0) + (report-arity-error)) + (apply list-set* + arity* + (for/fold ([m m] [pairs '()] #:result (if (zero? m) pairs (report-arity-error))) + ([a (in-list a*)]) + (define-values (i n arity) (apply values a)) + (cond + [(zero? m) + (values 0 (list* i n pairs))] + [(arity-includes? arity (+ n m)) + (values 0 (list* i (+ n m) pairs))] + [(arity-at-least? arity) + (report-arity-error)] + [(list? arity) + (match (last arity) + [(? arity-at-least?) + (report-arity-error)] + [(? exact-nonnegative-integer? j) + (values (- m j) (list* i (+ n j) pairs))])]))))) + +(define 1->1 (λ () (values))) +(define *->1 (λ _ (values))) + ;; from mischief/function - requiring it runs aground ;; of some "name is protected" error while building docs, not sure why; ;; so including the implementation directly here for now @@ -207,8 +268,40 @@ (keyword-apply f ks vs xs)))) (define (relay . fs) - (λ args - (apply values (zip-with call fs args)))) + (if (null? fs) + 1->1 + (λ args (apply values (zip-with call fs args))))) + +(define (relay* . fs) + (let ([fs (remq* (list 1->1) fs)]) + (if (null? fs) + 1->1 + (λ args + (define args* + (for/fold ([a '()] [a* args] #:result (reverse a)) + ([i (in-list (split-input (length args) (map procedure-arity fs)))]) + (define-values (v v*) (split-at a* i)) + (values (cons v a) v*))) + (apply values + (append* + (for/list ([f (in-list fs)] + [args (in-list args*)]) + (values->list + (match* ((procedure-arity f) args) + [(0 '()) (f)] + [(1 `(,v0)) (f v0)] + [(2 `(,v0 ,v1)) (f v0 v1)] + [(_ _) (apply f args)]))))))))) + +(define (tee . fs) + (let ([fs (remq* (list *->1) fs)]) + (if (null? fs) + *->1 + (λ args + (apply values + (append* + (for/list ([f (in-list fs)]) + (values->list (apply f args))))))))) (define (~all? . args) (match args diff --git a/qi-test/tests/flow.rkt b/qi-test/tests/flow.rkt index 7f41fe490..0fc4cd633 100644 --- a/qi-test/tests/flow.rkt +++ b/qi-test/tests/flow.rkt @@ -484,18 +484,34 @@ 5 7) (list 25 8) "named relay form")) - (test-suite - "==*" - (check-equal? ((☯ (~> (==* add1 sub1 +) ▽)) - 1 1 1 1 1) - (list 2 0 3)) - (check-equal? ((☯ (~> (==* add1 sub1 +) ▽)) - 1 1) - (list 2 0 0)) - (check-equal? ((☯ (~> (relay* add1 sub1 +) ▽)) - 1 1 1 1 1) - (list 2 0 3) - "named relay* form")) + (let ([add (procedure-reduce-arity + 2)] + [mul (procedure-reduce-arity * 2)] + [id (procedure-reduce-arity values 1)]) + (test-suite + "==*" + (check-equal? ((☯ (~> (==* add1 sub1 +) ▽)) + 1 1 1 1 1) + (list 2 0 3)) + (check-equal? ((☯ (~> (==* add1 + sub1) ▽)) + 1 1 1 1 1) + (list 2 3 0)) + (check-equal? ((☯ (~> (==* add1 + + sub1) ▽)) + 1 1 1 1 1) + (list 2 1 2 0)) + (check-equal? ((☯ (~> (==* add1 sub1 +) ▽)) + 1 1) + (list 2 0 0)) + (check-equal? ((☯ ; x y + (~> (-< 1> 1> 1> 2> 3) ; x x x y 3 + (==* mul mul id) ; x*x x*y 3 + (==* id mul) ; x*x x*y*3 + add)) ; x*x+x*y*3 + 3 4) + 45) + (check-equal? ((☯ (~> (relay* add1 sub1 +) ▽)) + 1 1 1 1 1) + (list 2 0 3) + "named relay* form"))) (test-suite "ground" (check-equal? ((☯ (-< ⏚ add1))