diff --git a/README.md b/README.md index d1926247..9f26a188 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ client.chat( model: "deepseek-chat", # Required. messages: [{ role: "user", content: "Hello!"}], # Required. temperature: 0.7, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| print chunk.dig("choices", 0, "delta", "content") end } @@ -285,7 +285,7 @@ client.chat( model: "llama3", # Required. messages: [{ role: "user", content: "Hello!"}], # Required. temperature: 0.7, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| print chunk.dig("choices", 0, "delta", "content") end } @@ -309,7 +309,7 @@ client.chat( model: "llama3-8b-8192", # Required. messages: [{ role: "user", content: "Hello!"}], # Required. temperature: 0.7, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| print chunk.dig("choices", 0, "delta", "content") end } @@ -371,7 +371,7 @@ client.chat( model: "gpt-4o", # Required. messages: [{ role: "user", content: "Describe a character called Anna!"}], # Required. temperature: 0.7, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| print chunk.dig("choices", 0, "delta", "content") end } @@ -457,7 +457,7 @@ You can stream it as well! model: "gpt-4o", messages: [{ role: "user", content: "Can I have some JSON please?"}], response_format: { type: "json_object" }, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| print chunk.dig("choices", 0, "delta", "content") end } @@ -542,7 +542,7 @@ client.responses.create( parameters: { model: "gpt-4o", # Required. input: "Hello!", # Required. - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| if chunk["type"] == "response.output_text.delta" print chunk["delta"] $stdout.flush # Ensure output is displayed immediately @@ -1163,7 +1163,7 @@ client.runs.create( assistant_id: assistant_id, max_prompt_tokens: 256, max_completion_tokens: 16, - stream: proc do |chunk, _bytesize| + stream: proc do |chunk, _event| if chunk["object"] == "thread.message.delta" print chunk.dig("delta", "content", 0, "text", "value") end diff --git a/lib/openai.rb b/lib/openai.rb index 978206f4..ae2557d7 100644 --- a/lib/openai.rb +++ b/lib/openai.rb @@ -12,6 +12,7 @@ require_relative "openai/messages" require_relative "openai/runs" require_relative "openai/run_steps" +require_relative "openai/stream" require_relative "openai/vector_stores" require_relative "openai/vector_store_files" require_relative "openai/vector_store_file_batches" diff --git a/lib/openai/http.rb b/lib/openai/http.rb index 644e692d..5459f64c 100644 --- a/lib/openai/http.rb +++ b/lib/openai/http.rb @@ -55,27 +55,6 @@ def parse_json(response) original_response end - # Given a proc, returns an outer proc that can be used to iterate over a JSON stream of chunks. - # For each chunk, the inner user_proc is called giving it the JSON object. The JSON object could - # be a data object or an error object as described in the OpenAI API documentation. - # - # @param user_proc [Proc] The inner proc to call for each JSON object in the chunk. - # @return [Proc] An outer proc that iterates over a raw stream, converting it to JSON. - def to_json_stream(user_proc:) - parser = EventStreamParser::Parser.new - - proc do |chunk, _bytes, env| - if env && env.status != 200 - raise_error = Faraday::Response::RaiseError.new - raise_error.on_complete(env.merge(body: try_parse_json(chunk))) - end - - parser.feed(chunk) do |_type, data| - user_proc.call(JSON.parse(data)) unless data == "[DONE]" - end - end - end - def conn(multipart: false) connection = Faraday.new do |f| f.options[:timeout] = @request_timeout @@ -120,7 +99,7 @@ def configure_json_post_request(req, parameters) req_parameters = parameters.dup if parameters[:stream].respond_to?(:call) - req.options.on_data = to_json_stream(user_proc: parameters[:stream]) + req.options.on_data = Stream.new(user_proc: parameters[:stream]).to_proc req_parameters[:stream] = true # Necessary to tell OpenAI to stream. elsif parameters[:stream] raise ArgumentError, "The stream parameter must be a Proc or have a #call method" @@ -129,11 +108,5 @@ def configure_json_post_request(req, parameters) req.headers = headers req.body = req_parameters.to_json end - - def try_parse_json(maybe_json) - JSON.parse(maybe_json) - rescue JSON::ParserError - maybe_json - end end end diff --git a/lib/openai/stream.rb b/lib/openai/stream.rb new file mode 100644 index 00000000..6bd70e6c --- /dev/null +++ b/lib/openai/stream.rb @@ -0,0 +1,50 @@ +module OpenAI + class Stream + DONE = "[DONE]".freeze + private_constant :DONE + + def initialize(user_proc:, parser: EventStreamParser::Parser.new) + @user_proc = user_proc + @parser = parser + + # To be backwards compatible, we need to check how many arguments the user_proc takes. + @user_proc_arity = + case user_proc + when Proc + user_proc.arity.abs + else + user_proc.method(:call).arity.abs + end + end + + def call(chunk, _bytes, env) + handle_http_error(chunk: chunk, env: env) if env && env.status != 200 + + parser.feed(chunk) do |event, data| + next if data == DONE + + args = [JSON.parse(data), event].first(user_proc_arity) + user_proc.call(*args) + end + end + + def to_proc + method(:call).to_proc + end + + private + + attr_reader :user_proc, :parser, :user_proc_arity + + def handle_http_error(chunk:, env:) + raise_error = Faraday::Response::RaiseError.new + raise_error.on_complete(env.merge(body: try_parse_json(chunk))) + end + + def try_parse_json(maybe_json) + JSON.parse(maybe_json) + rescue JSON::ParserError + maybe_json + end + end +end diff --git a/spec/openai/client/chat_spec.rb b/spec/openai/client/chat_spec.rb index bc8ce93e..9a3f7fda 100644 --- a/spec/openai/client/chat_spec.rb +++ b/spec/openai/client/chat_spec.rb @@ -82,7 +82,7 @@ describe "streaming" do let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end @@ -196,7 +196,7 @@ def call(chunk) end let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end @@ -224,7 +224,7 @@ def call(chunk) end let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end diff --git a/spec/openai/client/http_spec.rb b/spec/openai/client/http_spec.rb index 39261518..7995e2d2 100644 --- a/spec/openai/client/http_spec.rb +++ b/spec/openai/client/http_spec.rb @@ -55,7 +55,7 @@ context "streaming" do let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end @@ -120,75 +120,6 @@ end end - describe ".to_json_stream" do - context "with a proc" do - let(:user_proc) { proc { |x| x } } - let(:stream) { OpenAI::Client.new.send(:to_json_stream, user_proc: user_proc) } - - it "returns a proc" do - expect(stream).to be_a(Proc) - end - - context "when called with a string containing a single JSON object" do - it "calls the user proc with the data parsed as JSON" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) - stream.call(<<~CHUNK) - data: { "foo": "bar" } - - # - CHUNK - end - end - - context "when called with a string containing more than one JSON object" do - it "calls the user proc for each data parsed as JSON" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) - expect(user_proc).to receive(:call).with(JSON.parse('{"baz": "qud"}')) - - stream.call(<<~CHUNK) - data: { "foo": "bar" } - - data: { "baz": "qud" } - - data: [DONE] - - # - CHUNK - end - end - - context "when called with string containing invalid JSON" do - let(:chunk) do - <<~CHUNK - data: { "foo": "bar" } - - data: NOT JSON - - # - CHUNK - end - - it "raise an error" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) - - expect do - stream.call(chunk) - end.to raise_error(JSON::ParserError) - end - end - - context "when called with JSON split across chunks" do - it "calls the user proc with the data parsed as JSON" do - expect(user_proc).to receive(:call).with(JSON.parse('{ "foo": "bar" }')) - expect do - stream.call("data: { \"foo\":") - stream.call(" \"bar\" }\n\n") - end.not_to raise_error - end - end - end - end - describe ".parse_json" do context "with a jsonl string" do let(:body) { "{\"prompt\":\":)\"}\n{\"prompt\":\":(\"}\n" } diff --git a/spec/openai/client/responses_spec.rb b/spec/openai/client/responses_spec.rb index 50c2b561..80f57cf2 100644 --- a/spec/openai/client/responses_spec.rb +++ b/spec/openai/client/responses_spec.rb @@ -91,7 +91,7 @@ describe "streaming" do let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end @@ -112,13 +112,15 @@ let(:cassette) { "responses stream without proc" } let(:stream) do Class.new do - attr_reader :chunks + attr_reader :chunks, :events def initialize @chunks = [] + @events = [] end - def call(chunk) + def call(chunk, event) + @events << event @chunks << chunk end end.new @@ -132,6 +134,8 @@ def call(chunk) .map { |chunk| chunk["delta"] } .join expect(output_text).to include("?") + expect(stream.events.first).to eq("response.created") + expect(stream.events.last).to eq("response.completed") end end end diff --git a/spec/openai/client/runs_spec.rb b/spec/openai/client/runs_spec.rb index 1d5f7e0d..1469356d 100644 --- a/spec/openai/client/runs_spec.rb +++ b/spec/openai/client/runs_spec.rb @@ -85,7 +85,7 @@ describe "streaming" do let(:chunks) { [] } let(:stream) do - proc do |chunk, _bytesize| + proc do |chunk, _event| chunks << chunk end end diff --git a/spec/openai/client/stream_spec.rb b/spec/openai/client/stream_spec.rb new file mode 100644 index 00000000..df97972e --- /dev/null +++ b/spec/openai/client/stream_spec.rb @@ -0,0 +1,117 @@ +RSpec.describe OpenAI::Stream do + let(:user_proc) { proc { |data, event| [data, event] } } + let(:stream) { OpenAI::Stream.new(user_proc: user_proc) } + let(:bytes) { 0 } + let(:env) { Faraday::Env.new.tap { |env| env.status = 200 } } + + describe "#call" do + context "with a proc" do + context "when called with a string containing a single JSON object" do + it "calls the user proc with the data parsed as JSON" do + expect(user_proc).to receive(:call) + .with( + JSON.parse('{"foo": "bar"}'), + "event.test" + ) + + stream.call(<<~CHUNK, bytes, env) + event: event.test + data: { "foo": "bar" } + + # + CHUNK + end + end + + context "when called with a string containing more than one JSON object" do + it "calls the user proc for each data parsed as JSON" do + expect(user_proc).to receive(:call) + .with( + JSON.parse('{"foo": "bar"}'), + "event.test.first" + ) + expect(user_proc).to receive(:call) + .with( + JSON.parse('{"baz": "qud"}'), + "event.test.second" + ) + + stream.call(<<~CHUNK, bytes, env) + event: event.test.first + data: { "foo": "bar" } + + event: event.test.second + data: { "baz": "qud" } + + event: event.complete + data: [DONE] + + # + CHUNK + end + end + + context "when called with string containing invalid JSON" do + let(:chunk) do + <<~CHUNK + event: event.test + data: { "foo": "bar" } + + data: NOT JSON + + # + CHUNK + end + + it "raise an error" do + expect(user_proc).to receive(:call) + .with( + JSON.parse('{"foo": "bar"}'), + "event.test" + ) + + expect do + stream.call(chunk, bytes, env) + end.to raise_error(JSON::ParserError) + end + end + + context "when called with JSON split across chunks" do + it "calls the user proc with the data parsed as JSON" do + expect(user_proc).to receive(:call) + .with( + JSON.parse('{ "foo": "bar" }'), + "event.test" + ) + + expect do + stream.call("event: event.test\n", bytes, env) + stream.call("data: { \"foo\":", bytes, env) + stream.call(" \"bar\" }\n\n", bytes, env) + end.not_to raise_error + end + end + + context "with a call method that only takes one argument" do + let(:user_proc) { proc { |data| data } } + + it "succeeds" do + expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) + + stream.call(<<~CHUNK, bytes, env) + event: event.test + data: { "foo": "bar" } + + # + CHUNK + end + end + end + end + + describe "#to_proc" do + it "returns a proc" do + expect(stream.to_proc).to be_a(Proc) + end + end +end