diff --git a/README.md b/README.md index 1cd5e8de..21678c3f 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,13 @@ client.chat( messages: [{ role: "user", content: "Describe a character called Anna!"}], # Required. temperature: 0.7, stream: proc do |chunk, _bytesize| - print chunk.dig("choices", 0, "delta", "content") + if chunk["result_type"] == "data" + print chunk.dig("choices", 0, "delta", "content") + elsif chunk["result_type"] == "error" + STDERR.puts "Error: #{chunk.inspect}" + else + STDERR.puts "Unknown chunk type: #{chunk.inspect}" + end end }) # => "Anna is a young woman in her mid-twenties, with wavy chestnut hair that falls to her shoulders..." diff --git a/lib/openai/http.rb b/lib/openai/http.rb index 837e3733..07ff41dc 100644 --- a/lib/openai/http.rb +++ b/lib/openai/http.rb @@ -54,10 +54,24 @@ def to_json(string) # @return [Proc] An outer proc that iterates over a raw stream, converting it to JSON. def to_json_stream(user_proc:) proc do |chunk, _| - chunk.scan(/(?:data|error): (\{.*\})/i).flatten.each do |data| - user_proc.call(JSON.parse(data)) - rescue JSON::ParserError - # Ignore invalid JSON. + results = chunk.scan(/^\s*(data|error): *(\{.+\})/i) + if results.length.positive? + results.each do |result_type, result_json| + result = JSON.parse(result_json) + result.merge!("result_type" => result_type) + user_proc.call(result) + rescue JSON::ParserError + # Ignore invalid JSON. + end + elsif !chunk.match(/^\s*(data|error):/i) + begin + result = JSON.parse(chunk) + result_type = result["error"] ? "error" : "unknown" + result.merge!("result_type" => result_type) + user_proc.call(result) + rescue JSON::ParserError + # Ignore invalid JSON. + end end end end diff --git a/spec/openai/client/http_spec.rb b/spec/openai/client/http_spec.rb index 1f0b451d..61202754 100644 --- a/spec/openai/client/http_spec.rb +++ b/spec/openai/client/http_spec.rb @@ -107,17 +107,17 @@ context "when called with a string containing a single JSON object" do it "calls the user proc with the data parsed as JSON" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) + expect(user_proc).to receive(:call).with({ "foo" => "bar", "result_type" => "data" }) stream.call('data: { "foo": "bar" }') end end context "when called with string containing more than one JSON object" do it "calls the user proc for each data parsed as JSON" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) - expect(user_proc).to receive(:call).with(JSON.parse('{"baz": "qud"}')) + expect(user_proc).to receive(:call).with({ "foo" => "bar", "result_type" => "data" }) + expect(user_proc).to receive(:call).with({ "baz" => "qud", "result_type" => "data" }) - stream.call(<<-CHUNK) + stream.call(<<~CHUNK) data: { "foo": "bar" } data: { "baz": "qud" } @@ -141,14 +141,14 @@ context "when called with a string containing that looks like a JSON object but is invalid" do let(:chunk) do - <<-CHUNK + <<~CHUNK data: { "foo": "bar" } data: { BAD ]:-> JSON } CHUNK end it "does not raise an error" do - expect(user_proc).to receive(:call).with(JSON.parse('{"foo": "bar"}')) + expect(user_proc).to receive(:call).with({ "foo" => "bar", "result_type" => "data" }) expect do stream.call(chunk) @@ -158,16 +158,16 @@ context "when called with a string containing an error" do let(:chunk) do - <<-CHUNK + <<~CHUNK data: { "foo": "bar" } error: { "message": "A bad thing has happened!" } CHUNK end it "does not raise an error" do - expect(user_proc).to receive(:call).with(JSON.parse('{ "foo": "bar" }')) + expect(user_proc).to receive(:call).with({ "foo" => "bar", "result_type" => "data" }) expect(user_proc).to receive(:call).with( - JSON.parse('{ "message": "A bad thing has happened!" }') + { "message" => "A bad thing has happened!", "result_type" => "error" } ) expect do @@ -175,6 +175,64 @@ end.not_to raise_error end end + + context "when called with a string that is a JSON object (with no 'data:' or 'error:' prefix)" do + context "when the JSON has a top level 'error' key" do + let(:chunk) do + <<~CHUNK + { + "error": { + "type": "invalid_request_error", + "code": "invalid_api_key" + } + } + CHUNK + end + + it "does not raise an error" do + expect(user_proc).to receive(:call).with( + { + "error" => { + "type" => "invalid_request_error", + "code" => "invalid_api_key" + }, + "result_type" => "error" + } + ) + + expect do + stream.call(chunk) + end.not_to raise_error + end + end + + context "when the JSON does not have a top level 'error' key" do + let(:chunk) do + <<~CHUNK + { + "warning": { + "message": "foobar" + } + } + CHUNK + end + + it "does not raise an error" do + expect(user_proc).to receive(:call).with( + { + "result_type" => "unknown", + "warning" => { + "message" => "foobar" + } + } + ) + + expect do + stream.call(chunk) + end.not_to raise_error + end + end + end end end