summary | refs | log | tree | commit | diff | homepage
path: root/lib/dispatch/adapter
diff options
context:
space:
mode:
Diffstat (limited to 'lib/dispatch/adapter')
-rw-r--r-- lib/dispatch/adapter/copilot.rb 53
-rw-r--r-- lib/dispatch/adapter/rate_limiter.rb 171
-rw-r--r-- lib/dispatch/adapter/response.rb 8
-rw-r--r-- lib/dispatch/adapter/version.rb 2
4 files changed, 209 insertions, 25 deletions
diff --git a/lib/dispatch/adapter/copilot.rb b/lib/dispatch/adapter/copilot.rb
index 200fcb9..728ce99 100644
--- a/lib/dispatch/adapter/copilot.rb
+++ b/lib/dispatch/adapter/copilot.rb
@@ -12,6 +12,7 @@ require_relative "response"
require_relative "tool_definition"
require_relative "model_info"
require_relative "base"
+require_relative "rate_limiter"
require_relative "version"
@@ -55,7 +56,8 @@ module Dispatch
VALID_THINKING_LEVELS = %w[low medium high].freeze
- def initialize(model: "gpt-4.1", github_token: nil, token_path: nil, max_tokens: 8192, thinking: nil)
+ def initialize(model: "gpt-4.1", github_token: nil, token_path: nil, max_tokens: 8192, thinking: nil,
+ min_request_interval: 3.0, rate_limit: nil)
super()
@model = model
@github_token = github_token
@@ -66,9 +68,16 @@ module Dispatch
@copilot_token_expires_at = 0
@mutex = Mutex.new
validate_thinking_level!(@default_thinking)
+
+ rate_limit_path = File.join(File.dirname(@token_path), "copilot_rate_limit")
+ @rate_limiter = RateLimiter.new(
+ rate_limit_path: rate_limit_path,
+ min_request_interval: min_request_interval,
+ rate_limit: rate_limit
+ )
end
- def chat(messages, system: nil, tools: [], stream: false, max_tokens: nil, thinking: :default, &block)
+ def chat(messages, system: nil, tools: [], stream: false, max_tokens: nil, thinking: :default, &)
ensure_authenticated!
wire_messages = build_wire_messages(messages, system)
wire_tools = build_wire_tools(tools)
@@ -86,7 +95,7 @@ module Dispatch
body[:reasoning_effort] = effective_thinking if effective_thinking
if stream
- chat_streaming(body, &block)
+ chat_streaming(body, &)
else
chat_non_streaming(body)
end
@@ -106,6 +115,7 @@ module Dispatch
def list_models
ensure_authenticated!
+ @rate_limiter.wait!
uri = URI("#{API_BASE}/v1/models")
request = Net::HTTP::Get.new(uri)
apply_headers!(request)
@@ -133,7 +143,8 @@ module Dispatch
return if VALID_THINKING_LEVELS.include?(level)
- raise ArgumentError, "Invalid thinking level: #{level.inspect}. Must be one of: #{VALID_THINKING_LEVELS.join(", ")}, or nil"
+ raise ArgumentError,
+ "Invalid thinking level: #{level.inspect}. Must be one of: #{VALID_THINKING_LEVELS.join(", ")}, or nil"
end
def default_token_path
@@ -184,10 +195,10 @@ module Dispatch
verification_uri = data["verification_uri"]
interval = (data["interval"] || 5).to_i
- $stderr.puts "\n=== GitHub Device Authorization ==="
- $stderr.puts "Open: #{verification_uri}"
- $stderr.puts "Enter code: #{user_code}"
- $stderr.puts "Waiting for authorization...\n\n"
+ warn "\n=== GitHub Device Authorization ==="
+ warn "Open: #{verification_uri}"
+ warn "Enter code: #{user_code}"
+ warn "Waiting for authorization...\n\n"
poll_for_access_token(device_code, interval)
end
@@ -273,7 +284,7 @@ module Dispatch
)
end
- def execute_streaming_request(uri, request, &block)
+ def execute_streaming_request(uri, request)
http = Net::HTTP.new(uri.host, uri.port)
http.use_ssl = (uri.scheme == "https")
http.open_timeout = 30
@@ -282,7 +293,7 @@ module Dispatch
http.start do |h|
h.request(request) do |response|
handle_error_response!(response) unless response.is_a?(Net::HTTPSuccess)
- block.call(response)
+ yield(response)
end
end
rescue Errno::ECONNREFUSED, Errno::EHOSTUNREACH, Errno::ETIMEDOUT,
@@ -456,6 +467,7 @@ module Dispatch
# --- Chat (non-streaming) ---
def chat_non_streaming(body)
+ @rate_limiter.wait!
uri = URI("#{API_BASE}/chat/completions")
request = Net::HTTP::Post.new(uri)
apply_headers!(request)
@@ -522,6 +534,7 @@ module Dispatch
# --- Chat (streaming) ---
def chat_streaming(body, &block)
+ @rate_limiter.wait!
uri = URI("#{API_BASE}/chat/completions")
request = Net::HTTP::Post.new(uri)
apply_headers!(request)
@@ -552,7 +565,7 @@ module Dispatch
}
end
- def process_sse_buffer(buffer, collected, &block)
+ def process_sse_buffer(buffer, collected, &)
while (line_end = buffer.index("\n"))
line = buffer.slice!(0..line_end).strip
next if line.empty?
@@ -562,14 +575,14 @@ module Dispatch
next if data_str == "[DONE]"
data = JSON.parse(data_str)
- process_stream_chunk(data, collected, &block)
+ process_stream_chunk(data, collected, &)
end
rescue JSON::ParserError
# Incomplete JSON chunk, will be completed on next read
nil
end
- def process_stream_chunk(data, collected, &block)
+ def process_stream_chunk(data, collected, &)
collected[:model] = data["model"] if data["model"]
choice = data.dig("choices", 0)
@@ -578,20 +591,20 @@ module Dispatch
collected[:finish_reason] = choice["finish_reason"] if choice["finish_reason"]
delta = choice["delta"] || {}
- process_text_delta(delta, collected, &block)
- process_tool_call_deltas(delta, collected, &block)
+ process_text_delta(delta, collected, &)
+ process_tool_call_deltas(delta, collected, &)
process_usage(data, collected)
end
- def process_text_delta(delta, collected, &block)
+ def process_text_delta(delta, collected)
return unless delta["content"]
collected[:content] << delta["content"]
- block.call(StreamDelta.new(type: :text_delta, text: delta["content"]))
+ yield(StreamDelta.new(type: :text_delta, text: delta["content"]))
end
- def process_tool_call_deltas(delta, collected, &block)
+ def process_tool_call_deltas(delta, collected)
return unless delta["tool_calls"]
delta["tool_calls"].each do |tc_delta|
@@ -601,7 +614,7 @@ module Dispatch
if tc_delta["id"]
tc[:id] = tc_delta["id"]
tc[:name] = tc_delta.dig("function", "name") || ""
- block.call(StreamDelta.new(
+ yield(StreamDelta.new(
type: :tool_use_start,
tool_call_id: tc[:id],
tool_name: tc[:name]
@@ -612,7 +625,7 @@ module Dispatch
next if arg_frag.empty?
tc[:arguments] << arg_frag
- block.call(StreamDelta.new(
+ yield(StreamDelta.new(
type: :tool_use_delta,
tool_call_id: tc[:id],
argument_delta: arg_frag
diff --git a/lib/dispatch/adapter/rate_limiter.rb b/lib/dispatch/adapter/rate_limiter.rb
new file mode 100644
index 0000000..6f10905
--- /dev/null
+++ b/lib/dispatch/adapter/rate_limiter.rb
@@ -0,0 +1,171 @@
+# frozen_string_literal: true
+
+require "json"
+require "fileutils"
+
+module Dispatch
+ module Adapter
+ class RateLimiter
+ def initialize(rate_limit_path:, min_request_interval:, rate_limit:)
+ validate_min_request_interval!(min_request_interval)
+ validate_rate_limit!(rate_limit)
+
+ @rate_limit_path = rate_limit_path
+ @min_request_interval = min_request_interval
+ @rate_limit = rate_limit
+ end
+
+ def wait!
+ return if disabled?
+
+ loop do
+ wait_time = 0.0
+
+ File.open(rate_limit_file, File::RDWR | File::CREAT) do |file|
+ file.flock(File::LOCK_EX)
+ state = read_state(file)
+ now = Time.now.to_f
+ wait_time = compute_wait(state, now)
+
+ if wait_time <= 0
+ record_request(state, now)
+ write_state(file, state)
+ return
+ end
+ end
+
+ sleep(wait_time)
+ end
+ end
+
+ private
+
+ def disabled?
+ effective_min_interval.nil? && @rate_limit.nil?
+ end
+
+ def effective_min_interval
+ return nil if @min_request_interval.nil?
+ return nil if @min_request_interval.zero?
+
+ @min_request_interval
+ end
+
+ def rate_limit_file
+ FileUtils.mkdir_p(File.dirname(@rate_limit_path))
+ File.chmod(0o600, @rate_limit_path) if File.exist?(@rate_limit_path)
+ @rate_limit_path
+ end
+
+ def read_state(file)
+ file.rewind
+ content = file.read
+ return default_state if content.nil? || content.strip.empty?
+
+ parsed = JSON.parse(content)
+ {
+ "last_request_at" => parsed["last_request_at"]&.to_f,
+ "request_log" => Array(parsed["request_log"]).map(&:to_f)
+ }
+ rescue JSON::ParserError
+ default_state
+ end
+
+ def default_state
+ { "last_request_at" => nil, "request_log" => [] }
+ end
+
+ def write_state(file, state)
+ file.rewind
+ file.truncate(0)
+ file.write(JSON.generate(state))
+ file.flush
+
+ File.chmod(0o600, @rate_limit_path)
+ end
+
+ def compute_wait(state, now)
+ cooldown_wait = compute_cooldown_wait(state, now)
+ window_wait = compute_window_wait(state, now)
+ [cooldown_wait, window_wait].max
+ end
+
+ def compute_cooldown_wait(state, now)
+ interval = effective_min_interval
+ return 0.0 if interval.nil?
+
+ last = state["last_request_at"]
+ return 0.0 if last.nil?
+
+ elapsed = now - last
+ remaining = interval - elapsed
+ remaining > 0 ? remaining : 0.0
+ end
+
+ def compute_window_wait(state, now)
+ return 0.0 if @rate_limit.nil?
+
+ max_requests = @rate_limit[:requests]
+ period = @rate_limit[:period]
+ window_start = now - period
+
+ log = state["request_log"].select { |t| t > window_start }
+
+ return 0.0 if log.size < max_requests
+
+ oldest_in_window = log.min
+ wait = oldest_in_window + period - now
+ wait > 0 ? wait : 0.0
+ end
+
+ def record_request(state, now)
+ state["last_request_at"] = now
+ state["request_log"] << now
+ prune_log(state, now)
+ end
+
+ def prune_log(state, now)
+ if @rate_limit
+ period = @rate_limit[:period]
+ cutoff = now - period
+ state["request_log"] = state["request_log"].select { |t| t > cutoff }
+ else
+ state["request_log"] = []
+ end
+ end
+
+ def validate_min_request_interval!(value)
+ return if value.nil?
+
+ unless value.is_a?(Numeric)
+ raise ArgumentError,
+ "min_request_interval must be nil or a Numeric >= 0, got #{value.inspect}"
+ end
+
+ return unless value.negative?
+
+ raise ArgumentError,
+ "min_request_interval must be nil or a Numeric >= 0, got #{value.inspect}"
+ end
+
+ def validate_rate_limit!(value)
+ return if value.nil?
+
+ unless value.is_a?(Hash)
+ raise ArgumentError,
+ "rate_limit must be nil or a Hash with :requests and :period keys, got #{value.inspect}"
+ end
+
+ unless value.key?(:requests) && value[:requests].is_a?(Integer) && value[:requests].positive?
+ raise ArgumentError,
+ "rate_limit[:requests] must be a positive Integer, got #{value[:requests].inspect}"
+ end
+
+ return if value.key?(:period) && value[:period].is_a?(Numeric) && value[:period].positive?
+
+ raise ArgumentError,
+ "rate_limit[:period] must be a positive Numeric, got #{value[:period].inspect}"
+ end
+ end
+ end
+end
diff --git a/lib/dispatch/adapter/response.rb b/lib/dispatch/adapter/response.rb
index d3e4789..b4ba3eb 100644
--- a/lib/dispatch/adapter/response.rb
+++ b/lib/dispatch/adapter/response.rb
@@ -3,20 +3,20 @@
module Dispatch
module Adapter
Response = Struct.new(:content, :tool_calls, :model, :stop_reason, :usage, keyword_init: true) do
- def initialize(content: nil, tool_calls: [], model:, stop_reason:, usage:)
- super(content:, tool_calls:, model:, stop_reason:, usage:)
+ def initialize(model:, stop_reason:, usage:, content: nil, tool_calls: [])
+ super
end
end
Usage = Struct.new(:input_tokens, :output_tokens, :cache_read_tokens, :cache_creation_tokens, keyword_init: true) do
def initialize(input_tokens:, output_tokens:, cache_read_tokens: 0, cache_creation_tokens: 0)
- super(input_tokens:, output_tokens:, cache_read_tokens:, cache_creation_tokens:)
+ super
end
end
StreamDelta = Struct.new(:type, :text, :tool_call_id, :tool_name, :argument_delta, keyword_init: true) do
def initialize(type:, text: nil, tool_call_id: nil, tool_name: nil, argument_delta: nil)
- super(type:, text:, tool_call_id:, tool_name:, argument_delta:)
+ super
end
end
end
diff --git a/lib/dispatch/adapter/version.rb b/lib/dispatch/adapter/version.rb
index 3df9f5f..6df2da6 100644
--- a/lib/dispatch/adapter/version.rb
+++ b/lib/dispatch/adapter/version.rb
@@ -3,7 +3,7 @@
module Dispatch
module Adapter
module CopilotVersion
- VERSION = "0.1.0"
+ VERSION = "0.2.0"
end
end
end