diff options
Diffstat (limited to 'lib/dispatch/adapter')
| -rw-r--r-- | lib/dispatch/adapter/copilot.rb | 53 | ||||
| -rw-r--r-- | lib/dispatch/adapter/rate_limiter.rb | 171 | ||||
| -rw-r--r-- | lib/dispatch/adapter/response.rb | 8 | ||||
| -rw-r--r-- | lib/dispatch/adapter/version.rb | 2 |
4 files changed, 209 insertions, 25 deletions
diff --git a/lib/dispatch/adapter/copilot.rb b/lib/dispatch/adapter/copilot.rb index 200fcb9..728ce99 100644 --- a/lib/dispatch/adapter/copilot.rb +++ b/lib/dispatch/adapter/copilot.rb @@ -12,6 +12,7 @@ require_relative "response" require_relative "tool_definition" require_relative "model_info" require_relative "base" +require_relative "rate_limiter" require_relative "version" @@ -55,7 +56,8 @@ module Dispatch VALID_THINKING_LEVELS = %w[low medium high].freeze - def initialize(model: "gpt-4.1", github_token: nil, token_path: nil, max_tokens: 8192, thinking: nil) + def initialize(model: "gpt-4.1", github_token: nil, token_path: nil, max_tokens: 8192, thinking: nil, + min_request_interval: 3.0, rate_limit: nil) super() @model = model @github_token = github_token @@ -66,9 +68,16 @@ module Dispatch @copilot_token_expires_at = 0 @mutex = Mutex.new validate_thinking_level!(@default_thinking) + + rate_limit_path = File.join(File.dirname(@token_path), "copilot_rate_limit") + @rate_limiter = RateLimiter.new( + rate_limit_path: rate_limit_path, + min_request_interval: min_request_interval, + rate_limit: rate_limit + ) end - def chat(messages, system: nil, tools: [], stream: false, max_tokens: nil, thinking: :default, &block) + def chat(messages, system: nil, tools: [], stream: false, max_tokens: nil, thinking: :default, &) ensure_authenticated! wire_messages = build_wire_messages(messages, system) wire_tools = build_wire_tools(tools) @@ -86,7 +95,7 @@ module Dispatch body[:reasoning_effort] = effective_thinking if effective_thinking if stream - chat_streaming(body, &block) + chat_streaming(body, &) else chat_non_streaming(body) end @@ -106,6 +115,7 @@ module Dispatch def list_models ensure_authenticated! + @rate_limiter.wait! uri = URI("#{API_BASE}/v1/models") request = Net::HTTP::Get.new(uri) apply_headers!(request) @@ -133,7 +143,8 @@ module Dispatch return if VALID_THINKING_LEVELS.include?(level) - raise ArgumentError, "Invalid thinking level: #{level.inspect}. Must be one of: #{VALID_THINKING_LEVELS.join(", ")}, or nil" + raise ArgumentError, + "Invalid thinking level: #{level.inspect}. Must be one of: #{VALID_THINKING_LEVELS.join(", ")}, or nil" end def default_token_path @@ -184,10 +195,10 @@ module Dispatch verification_uri = data["verification_uri"] interval = (data["interval"] || 5).to_i - $stderr.puts "\n=== GitHub Device Authorization ===" - $stderr.puts "Open: #{verification_uri}" - $stderr.puts "Enter code: #{user_code}" - $stderr.puts "Waiting for authorization...\n\n" + warn "\n=== GitHub Device Authorization ===" + warn "Open: #{verification_uri}" + warn "Enter code: #{user_code}" + warn "Waiting for authorization...\n\n" poll_for_access_token(device_code, interval) end @@ -273,7 +284,7 @@ module Dispatch ) end - def execute_streaming_request(uri, request, &block) + def execute_streaming_request(uri, request) http = Net::HTTP.new(uri.host, uri.port) http.use_ssl = (uri.scheme == "https") http.open_timeout = 30 @@ -282,7 +293,7 @@ module Dispatch http.start do |h| h.request(request) do |response| handle_error_response!(response) unless response.is_a?(Net::HTTPSuccess) - block.call(response) + yield(response) end end rescue Errno::ECONNREFUSED, Errno::EHOSTUNREACH, Errno::ETIMEDOUT, @@ -456,6 +467,7 @@ module Dispatch # --- Chat (non-streaming) --- def chat_non_streaming(body) + @rate_limiter.wait! uri = URI("#{API_BASE}/chat/completions") request = Net::HTTP::Post.new(uri) apply_headers!(request) @@ -522,6 +534,7 @@ module Dispatch # --- Chat (streaming) --- def chat_streaming(body, &block) + @rate_limiter.wait! uri = URI("#{API_BASE}/chat/completions") request = Net::HTTP::Post.new(uri) apply_headers!(request) @@ -552,7 +565,7 @@ module Dispatch } end - def process_sse_buffer(buffer, collected, &block) + def process_sse_buffer(buffer, collected, &) while (line_end = buffer.index("\n")) line = buffer.slice!(0..line_end).strip next if line.empty? @@ -562,14 +575,14 @@ module Dispatch next if data_str == "[DONE]" data = JSON.parse(data_str) - process_stream_chunk(data, collected, &block) + process_stream_chunk(data, collected, &) end rescue JSON::ParserError # Incomplete JSON chunk, will be completed on next read nil end - def process_stream_chunk(data, collected, &block) + def process_stream_chunk(data, collected, &) collected[:model] = data["model"] if data["model"] choice = data.dig("choices", 0) @@ -578,20 +591,20 @@ module Dispatch collected[:finish_reason] = choice["finish_reason"] if choice["finish_reason"] delta = choice["delta"] || {} - process_text_delta(delta, collected, &block) - process_tool_call_deltas(delta, collected, &block) + process_text_delta(delta, collected, &) + process_tool_call_deltas(delta, collected, &) process_usage(data, collected) end - def process_text_delta(delta, collected, &block) + def process_text_delta(delta, collected) return unless delta["content"] collected[:content] << delta["content"] - block.call(StreamDelta.new(type: :text_delta, text: delta["content"])) + yield(StreamDelta.new(type: :text_delta, text: delta["content"])) end - def process_tool_call_deltas(delta, collected, &block) + def process_tool_call_deltas(delta, collected) return unless delta["tool_calls"] delta["tool_calls"].each do |tc_delta| @@ -601,7 +614,7 @@ module Dispatch if tc_delta["id"] tc[:id] = tc_delta["id"] tc[:name] = tc_delta.dig("function", "name") || "" - block.call(StreamDelta.new( + yield(StreamDelta.new( type: :tool_use_start, tool_call_id: tc[:id], tool_name: tc[:name] @@ -612,7 +625,7 @@ module Dispatch next if arg_frag.empty? tc[:arguments] << arg_frag - block.call(StreamDelta.new( + yield(StreamDelta.new( type: :tool_use_delta, tool_call_id: tc[:id], argument_delta: arg_frag diff --git a/lib/dispatch/adapter/rate_limiter.rb b/lib/dispatch/adapter/rate_limiter.rb new file mode 100644 index 0000000..6f10905 --- /dev/null +++ b/lib/dispatch/adapter/rate_limiter.rb @@ -0,0 +1,171 @@ +# frozen_string_literal: true + +require "json" +require "fileutils" + +module Dispatch + module Adapter + class RateLimiter + def initialize(rate_limit_path:, min_request_interval:, rate_limit:) + validate_min_request_interval!(min_request_interval) + validate_rate_limit!(rate_limit) + + @rate_limit_path = rate_limit_path + @min_request_interval = min_request_interval + @rate_limit = rate_limit + end + + def wait! + return if disabled? + + loop do + wait_time = 0.0 + + File.open(rate_limit_file, File::RDWR | File::CREAT) do |file| + file.flock(File::LOCK_EX) + state = read_state(file) + now = Time.now.to_f + wait_time = compute_wait(state, now) + + if wait_time <= 0 + record_request(state, now) + write_state(file, state) + return + end + end + + sleep(wait_time) + end + end + + private + + def disabled? + effective_min_interval.nil? && @rate_limit.nil? + end + + def effective_min_interval + return nil if @min_request_interval.nil? + return nil if @min_request_interval.zero? + + @min_request_interval + end + + def rate_limit_file + FileUtils.mkdir_p(File.dirname(@rate_limit_path)) + File.chmod(0o600, @rate_limit_path) if File.exist?(@rate_limit_path) + @rate_limit_path + end + + def read_state(file) + file.rewind + content = file.read + return default_state if content.nil? || content.strip.empty? + + parsed = JSON.parse(content) + { + "last_request_at" => parsed["last_request_at"]&.to_f, + "request_log" => Array(parsed["request_log"]).map(&:to_f) + } + rescue JSON::ParserError + default_state + end + + def default_state + { "last_request_at" => nil, "request_log" => [] } + end + + def write_state(file, state) + file.rewind + file.truncate(0) + file.write(JSON.generate(state)) + file.flush + + File.chmod(0o600, @rate_limit_path) + end + + def compute_wait(state, now) + cooldown_wait = compute_cooldown_wait(state, now) + window_wait = compute_window_wait(state, now) + [cooldown_wait, window_wait].max + end + + def compute_cooldown_wait(state, now) + interval = effective_min_interval + return 0.0 if interval.nil? + + last = state["last_request_at"] + return 0.0 if last.nil? + + elapsed = now - last + remaining = interval - elapsed + remaining > 0 ? remaining : 0.0 + end + + def compute_window_wait(state, now) + return 0.0 if @rate_limit.nil? + + max_requests = @rate_limit[:requests] + period = @rate_limit[:period] + window_start = now - period + + log = state["request_log"].select { |t| t > window_start } + + return 0.0 if log.size < max_requests + + oldest_in_window = log.min + wait = oldest_in_window + period - now + wait > 0 ? wait : 0.0 + end + + def record_request(state, now) + state["last_request_at"] = now + state["request_log"] << now + prune_log(state, now) + end + + def prune_log(state, now) + if @rate_limit + period = @rate_limit[:period] + cutoff = now - period + state["request_log"] = state["request_log"].select { |t| t > cutoff } + else + state["request_log"] = [] + end + end + + def validate_min_request_interval!(value) + return if value.nil? + + unless value.is_a?(Numeric) + raise ArgumentError, + "min_request_interval must be nil or a Numeric >= 0, got #{value.inspect}" + end + + return unless value.negative? + + raise ArgumentError, + "min_request_interval must be nil or a Numeric >= 0, got #{value.inspect}" + end + + def validate_rate_limit!(value) + return if value.nil? + + unless value.is_a?(Hash) + raise ArgumentError, + "rate_limit must be nil or a Hash with :requests and :period keys, got #{value.inspect}" + end + + unless value.key?(:requests) && value[:requests].is_a?(Integer) && value[:requests].positive? + raise ArgumentError, + "rate_limit[:requests] must be a positive Integer, got #{value[:requests].inspect}" + end + + return if value.key?(:period) && value[:period].is_a?(Numeric) && value[:period].positive? + + raise ArgumentError, + "rate_limit[:period] must be a positive Numeric, got #{value[:period].inspect}" + end + end + end +end diff --git a/lib/dispatch/adapter/response.rb b/lib/dispatch/adapter/response.rb index d3e4789..b4ba3eb 100644 --- a/lib/dispatch/adapter/response.rb +++ b/lib/dispatch/adapter/response.rb @@ -3,20 +3,20 @@ module Dispatch module Adapter Response = Struct.new(:content, :tool_calls, :model, :stop_reason, :usage, keyword_init: true) do - def initialize(content: nil, tool_calls: [], model:, stop_reason:, usage:) - super(content:, tool_calls:, model:, stop_reason:, usage:) + def initialize(model:, stop_reason:, usage:, content: nil, tool_calls: []) + super end end Usage = Struct.new(:input_tokens, :output_tokens, :cache_read_tokens, :cache_creation_tokens, keyword_init: true) do def initialize(input_tokens:, output_tokens:, cache_read_tokens: 0, cache_creation_tokens: 0) - super(input_tokens:, output_tokens:, cache_read_tokens:, cache_creation_tokens:) + super end end StreamDelta = Struct.new(:type, :text, :tool_call_id, :tool_name, :argument_delta, keyword_init: true) do def initialize(type:, text: nil, tool_call_id: nil, tool_name: nil, argument_delta: nil) - super(type:, text:, tool_call_id:, tool_name:, argument_delta:) + super end end end diff --git a/lib/dispatch/adapter/version.rb b/lib/dispatch/adapter/version.rb index 3df9f5f..6df2da6 100644 --- a/lib/dispatch/adapter/version.rb +++ b/lib/dispatch/adapter/version.rb @@ -3,7 +3,7 @@ module Dispatch module Adapter module CopilotVersion - VERSION = "0.1.0" + VERSION = "0.2.0" end end end |
