Skip to content

Commit bb62fe0

Browse files
chore(internal): add utils methods for parsing SSE (#90)
1 parent 09ec55d commit bb62fe0

File tree

11 files changed

+504
-105
lines changed

11 files changed

+504
-105
lines changed

lib/orb/base_client.rb

Lines changed: 33 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class << self
2828
# @raise [ArgumentError]
2929
#
3030
def validate!(req)
31-
keys = [:method, :path, :query, :headers, :body, :unwrap, :page, :model, :options]
31+
keys = [:method, :path, :query, :headers, :body, :unwrap, :page, :stream, :model, :options]
3232
case req
3333
in Hash
3434
req.each_key do |k|
@@ -201,6 +201,8 @@ def initialize(
201201
#
202202
# @option req [Class, nil] :page
203203
#
204+
# @option req [Class, nil] :stream
205+
#
204206
# @option req [Orb::Converter, Class, nil] :model
205207
#
206208
# @param opts [Hash{Symbol=>Object}] .
@@ -319,7 +321,7 @@ def initialize(
319321
# @param send_retry_header [Boolean]
320322
#
321323
# @raise [Orb::APIError]
322-
# @return [Array(Net::HTTPResponse, Enumerable)]
324+
# @return [Array(Integer, Net::HTTPResponse, Enumerable)]
323325
#
324326
private def send_request(request, redirect_count:, retry_count:, send_retry_header:)
325327
url, headers, max_retries, timeout = request.fetch_values(:url, :headers, :max_retries, :timeout)
@@ -342,7 +344,7 @@ def initialize(
342344

343345
case status
344346
in ..299
345-
[response, stream]
347+
[status, response, stream]
346348
in 300..399 if redirect_count >= self.class::MAX_REDIRECTS
347349
message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects."
348350

@@ -360,13 +362,15 @@ def initialize(
360362
)
361363
in Orb::APIConnectionError if retry_count >= max_retries
362364
raise status
363-
in (400..) if retry_count >= max_retries || (response && !self.class.should_retry?(
364-
status,
365-
headers: response
366-
))
365+
in (400..) if retry_count >= max_retries || !self.class.should_retry?(status, headers: response)
367366
decoded = Orb::Util.decode_content(response, stream: stream, suppress_error: true)
368367

369-
stream.each { srv_fault ? break : next }
368+
if srv_fault
369+
Orb::Util.close_fused!(stream)
370+
else
371+
stream.each { next }
372+
end
373+
370374
raise Orb::APIStatusError.for(
371375
url: url,
372376
status: status,
@@ -377,7 +381,11 @@ def initialize(
377381
in (400..) | Orb::APIConnectionError
378382
delay = retry_delay(response, retry_count: retry_count)
379383

380-
stream&.each { srv_fault ? break : next }
384+
if srv_fault
385+
Orb::Util.close_fused!(stream)
386+
else
387+
stream&.each { next }
388+
end
381389
sleep(delay)
382390

383391
send_request(
@@ -389,48 +397,6 @@ def initialize(
389397
end
390398
end
391399

392-
# @private
393-
#
394-
# @param req [Hash{Symbol=>Object}] .
395-
#
396-
# @option req [Symbol] :method
397-
#
398-
# @option req [String, Array<String>] :path
399-
#
400-
# @option req [Hash{String=>Array<String>, String, nil}, nil] :query
401-
#
402-
# @option req [Hash{String=>String, Integer, Array<String, Integer, nil>, nil}, nil] :headers
403-
#
404-
# @option req [Object, nil] :body
405-
#
406-
# @option req [Symbol, nil] :unwrap
407-
#
408-
# @option req [Class, nil] :page
409-
#
410-
# @option req [Orb::Converter, Class, nil] :model
411-
#
412-
# @option req [Orb::RequestOptions, Hash{Symbol=>Object}, nil] :options
413-
#
414-
# @param headers [Hash{String=>String}, Net::HTTPHeader]
415-
#
416-
# @param stream [Enumerable]
417-
#
418-
# @return [Object]
419-
#
420-
private def parse_response(req, headers:, stream:)
421-
decoded = Orb::Util.decode_content(headers, stream: stream)
422-
unwrapped = Orb::Util.dig(decoded, req[:unwrap])
423-
424-
case [req[:page], req.fetch(:model, Orb::Unknown)]
425-
in [Class => page, _]
426-
page.new(client: self, req: req, headers: headers, unwrapped: unwrapped)
427-
in [nil, Class | Orb::Converter => model]
428-
Orb::Converter.coerce(model, unwrapped)
429-
in [nil, nil]
430-
unwrapped
431-
end
432-
end
433-
434400
# Execute the request specified by `req`. This is the method that all resource
435401
# methods call into.
436402
#
@@ -450,6 +416,8 @@ def initialize(
450416
#
451417
# @option req [Class, nil] :page
452418
#
419+
# @option req [Class, nil] :stream
420+
#
453421
# @option req [Orb::Converter, Class, nil] :model
454422
#
455423
# @option req [Orb::RequestOptions, Hash{Symbol=>Object}, nil] :options
@@ -459,19 +427,31 @@ def initialize(
459427
#
460428
def request(req)
461429
self.class.validate!(req)
430+
model = req.fetch(:model) { Orb::Unknown }
462431
opts = req[:options].to_h
463432
Orb::RequestOptions.validate!(opts)
464433
request = build_request(req.except(:options), opts)
434+
url = request.fetch(:url)
465435

466436
# Don't send the current retry count in the headers if the caller modified the header defaults.
467437
send_retry_header = request.fetch(:headers)["x-stainless-retry-count"] == "0"
468-
response, stream = send_request(
438+
status, response, stream = send_request(
469439
request,
470440
redirect_count: 0,
471441
retry_count: 0,
472442
send_retry_header: send_retry_header
473443
)
474-
parse_response(req, headers: response, stream: stream)
444+
445+
decoded = Orb::Util.decode_content(response, stream: stream)
446+
case req
447+
in { stream: Class => st }
448+
st.new(model: model, url: url, status: status, response: response, messages: decoded)
449+
in { page: Class => page }
450+
page.new(client: self, req: req, headers: response, unwrapped: decoded)
451+
else
452+
unwrapped = Orb::Util.dig(decoded, req[:unwrap])
453+
Orb::Converter.coerce(model, unwrapped)
454+
end
475455
end
476456

477457
# @return [String]

lib/orb/errors.rb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,13 @@ class APIStatusError < Orb::APIError
9999
# @param body [Object, nil]
100100
# @param request [nil]
101101
# @param response [nil]
102+
# @param message [String, nil]
102103
#
103104
# @return [Orb::APIStatusError]
104105
#
105-
def self.for(url:, status:, body:, request:, response:)
106+
def self.for(url:, status:, body:, request:, response:, message: nil)
106107
key = Orb::Util.dig(body, :type)
107-
kwargs = {url: url, status: status, body: body, request: request, response: response}
108+
kwargs = {url: url, status: status, body: body, request: request, response: response, message: message}
108109

109110
case [status, key]
110111
in [400, Orb::ConstraintViolation::TYPE]

lib/orb/pooled_net_requester.rb

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,33 +131,38 @@ def execute(request)
131131
req = self.class.build_request(request)
132132

133133
eof = false
134+
finished = false
134135
enum = Enumerator.new do |y|
135136
with_pool(url) do |conn|
137+
next if finished
138+
136139
self.class.calibrate_socket_timeout(conn, deadline)
137140
conn.start unless conn.started?
138141

139142
self.class.calibrate_socket_timeout(conn, deadline)
140143
conn.request(req) do |rsp|
141144
y << [conn, rsp]
145+
break if finished
146+
142147
rsp.read_body do |bytes|
143148
y << bytes
149+
break if finished
150+
144151
self.class.calibrate_socket_timeout(conn, deadline)
145152
end
146153
eof = true
147154
end
148155
end
149156
end
150157

151-
# need to protect the `Enumerator` against `#.rewind`
152-
fused = false
153158
conn, response = enum.next
154-
body = Enumerator.new do |y|
155-
next if fused
156-
157-
fused = true
158-
loop { y << enum.next }
159-
ensure
160-
conn.finish if !eof && conn.started?
159+
body = Orb::Util.fused_enum(enum) do
160+
finished = true
161+
tap do
162+
enum.next
163+
rescue StopIteration
164+
end
165+
conn.finish if !eof && conn&.started?
161166
end
162167
[response, (response.body = body)]
163168
end

lib/orb/util.rb

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,9 @@ def encode_content(headers, body)
496496
#
497497
def decode_content(headers, stream:, suppress_error: false)
498498
case headers["content-type"]
499+
in %r{^text/event-stream}
500+
lines = enum_lines(stream)
501+
parse_sse(lines)
499502
in %r{^application/json}
500503
json = stream.to_a.join
501504
begin
@@ -512,6 +515,121 @@ def decode_content(headers, stream:, suppress_error: false)
512515
end
513516
end
514517
end
518+
519+
class << self
520+
# @private
521+
#
522+
# https://doc.rust-lang.org/std/iter/trait.FusedIterator.html
523+
#
524+
# @param enum [Enumerable]
525+
# @param close [Proc]
526+
#
527+
# @return [Enumerable]
528+
#
529+
def fused_enum(enum, &close)
530+
fused = false
531+
iter = Enumerator.new do |y|
532+
next if fused
533+
534+
fused = true
535+
loop { y << enum.next }
536+
ensure
537+
close&.call
538+
close = nil
539+
end
540+
541+
iter.define_singleton_method(:rewind) do
542+
fused = true
543+
self
544+
end
545+
iter
546+
end
547+
548+
# @private
549+
#
550+
# @param enum [Enumerable, nil]
551+
#
552+
def close_fused!(enum)
553+
return unless enum.is_a?(Enumerator)
554+
555+
# rubocop:disable Lint/UnreachableLoop
556+
enum.rewind.each { break }
557+
# rubocop:enable Lint/UnreachableLoop
558+
end
559+
560+
# @private
561+
#
562+
# @param enum [Enumerable, nil]
563+
# @param blk [Proc]
564+
#
565+
def chain_fused(enum, &blk)
566+
iter = Enumerator.new { blk.call(_1) }
567+
fused_enum(iter) { close_fused!(enum) }
568+
end
569+
end
570+
571+
class << self
572+
# @private
573+
#
574+
# @param enum [Enumerable]
575+
#
576+
# @return [Enumerable]
577+
#
578+
def enum_lines(enum)
579+
chain_fused(enum) do |y|
580+
buffer = String.new
581+
enum.each do |row|
582+
buffer << row
583+
while (idx = buffer.index("\n"))
584+
y << buffer.slice!(..idx)
585+
end
586+
end
587+
y << buffer unless buffer.empty?
588+
end
589+
end
590+
591+
# @private
592+
#
593+
# https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
594+
#
595+
# @param lines [Enumerable]
596+
#
597+
# @return [Hash{Symbol=>Object}]
598+
#
599+
def parse_sse(lines)
600+
chain_fused(lines) do |y|
601+
blank = {event: nil, data: nil, id: nil, retry: nil}
602+
current = {}
603+
604+
lines.each do |line|
605+
case line.strip
606+
in ""
607+
next if current.empty?
608+
y << {**blank, **current}
609+
current = {}
610+
in /^:/
611+
next
612+
in /^([^:]+):\s?(.*)$/
613+
_, field, value = Regexp.last_match.to_a
614+
case field
615+
in "event"
616+
current.merge!(event: value)
617+
in "data"
618+
(current[:data] ||= String.new) << value << "\n"
619+
in "id" unless value.include?("\0")
620+
current.merge!(id: value)
621+
in "retry" if /^\d+$/ =~ value
622+
current.merge!(retry: Integer(value))
623+
else
624+
end
625+
else
626+
end
627+
end
628+
629+
y << {**blank, **current} unless current.empty?
630+
end
631+
end
632+
end
515633
end
516634

517635
# rubocop:enable Metrics/ModuleLength

0 commit comments

Comments
 (0)