1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
|
# frozen_string_literal: true
module Dispatch
module Adapter
class Claude < Base
module RequestBuilder
# Applies Anthropic prompt-caching breakpoints to an already-assembled
# request-params hash, then enforces the 4-breakpoint cap and the
# cache-control TTL ordering rule.
#
# Placement order (matches oh-my-pi `applyPromptCaching`):
#   1. last tool definition
#   2. last system block
#   3. penultimate user message — last text block (or last block)
#   4. last user message — last text block (or last block)
#
# After placement:
#   - `enforceCacheControlLimit`: strip excess markers above 4.
#   - `normalizeCacheControlTtlOrdering`: once a non-"1h" block is
#     seen in (tools→system→messages) order, downgrade any subsequent
#     "1h" block to plain ephemeral.
#
# A block counts as carrying a marker only when its cache_control value is
# truthy, so serializers that emit `cache_control: nil` for uncached blocks
# neither suppress auto-placement nor consume breakpoint slots.
#
# Cache-control descriptors are never mutated in place: every placed block
# receives its own copy, and TTL downgrades replace the block's entry with
# a new hash. Shared or frozen caller-supplied descriptors are therefore safe.
#
# Call `CacheControl.apply(params, cache_retention:, base_url:)` after
# fully assembling params[:tools], params[:system], params[:messages].
module CacheControl
  MAX_BREAKPOINTS = 4

  module_function

  # Main entry point. Mutates params in-place.
  #
  # @param params [Hash] assembled request params (read via the symbol keys
  #   :tools, :system and :messages)
  # @param cache_retention [Symbol, nil] :long | :short | :none | nil
  #   (nil behaves like :short)
  # @param base_url [String] API base URL; a "1h" TTL is only requested when
  #   it contains "api.anthropic.com"
  def apply(params, cache_retention: nil, base_url: "https://api.anthropic.com")
    cc = resolve_cache_control(cache_retention, base_url)
    # Auto-place markers only when we have a cache_control descriptor
    # AND the caller hasn't already placed markers on message blocks.
    place_markers(params, cc) if cc && !caller_placed_markers?(params[:messages])
    # Always enforce the breakpoint cap and TTL ordering rule, even
    # when auto-placement was skipped — the caller may have pre-set
    # markers via ToolDefinition#cache_control or TextBlock#cache_control.
    enforce_limit(params)
    normalize_ttl_ordering(params)
  end

  # ── Cache-control resolution ─────────────────────────────────────────

  # Translate a cache_retention symbol to an Anthropic cache_control hash.
  # Returns nil when caching is disabled (:none). :long yields a "1h" TTL
  # only for api.anthropic.com base URLs; everything else gets "5m".
  def resolve_cache_control(cache_retention, base_url)
    retention = cache_retention || :short
    return nil if retention == :none

    long_ttl = retention == :long && base_url.to_s.include?("api.anthropic.com")
    { "type" => "ephemeral", "ttl" => (long_ttl ? "1h" : "5m") }
  end

  # Auto-place up to MAX_BREAKPOINTS markers in prefix order (see module
  # docs). Targets that are not Hashes are skipped and not counted.
  def place_markers(params, cc)
    breakpoints = 0

    # 1. Last tool
    breakpoints += 1 if mark_last_block?(params[:tools], cc)
    return if breakpoints >= MAX_BREAKPOINTS

    # 2. Last system block
    breakpoints += 1 if mark_last_block?(params[:system], cc)
    return if breakpoints >= MAX_BREAKPOINTS

    messages = Array(params[:messages])
    user_indexes = messages.each_index.select { |i| messages[i][:role] == "user" }

    # 3. Penultimate user message — last text block
    if user_indexes.length >= 2 && apply_to_last_text_block?(messages[user_indexes[-2]], cc)
      breakpoints += 1
    end
    return if breakpoints >= MAX_BREAKPOINTS

    # 4. Last user message — last text block
    return if user_indexes.empty?

    apply_to_last_text_block?(messages[user_indexes[-1]], cc)
  end

  # Set cc on the last element of a tools/system array.
  # Returns true if a marker was placed.
  def mark_last_block?(blocks, cc)
    return false unless blocks.is_a?(Array) && blocks.last.is_a?(Hash)

    set_cache_control!(blocks.last, cc)
    true
  end

  # Apply cache_control to the last text block (or last block as fallback)
  # of a single message. Converts String content to [{type:"text",text:…}].
  # Returns true only if a marker was actually placed.
  def apply_to_last_text_block?(msg, cc)
    content = msg[:content]
    case content
    when String
      # Convert string content to a single text block array
      block = { type: "text", text: content }
      set_cache_control!(block, cc)
      msg[:content] = [block]
      true
    when Array
      # Find last text block; fall back to absolute last block
      idx = content.rindex { |b| b.is_a?(Hash) && block_type(b) == "text" }
      target = idx ? content[idx] : content.last
      return false unless target.is_a?(Hash)

      set_cache_control!(target, cc)
      true
    else
      false
    end
  end

  # Returns true if any message's content array already has a
  # cache_control marker (meaning the caller is in charge of caching).
  def caller_placed_markers?(messages)
    return false unless messages.is_a?(Array)

    messages.any? do |msg|
      next false unless msg[:content].is_a?(Array)

      msg[:content].any? { |b| b.is_a?(Hash) && cache_control_present?(b) }
    end
  end

  # ── Enforce 4-breakpoint cap ─────────────────────────────────────────
  #
  # Strips markers until at most MAX_BREAKPOINTS remain. Removal order:
  # system markers except the last, tool markers except the last, message
  # markers (oldest first), then the remaining system and tool markers.
  def enforce_limit(params)
    total = count_breakpoints(params)
    return if total <= MAX_BREAKPOINTS

    # Boxed so the strip helpers can decrement one shared counter.
    excess = { value: total - MAX_BREAKPOINTS }
    system_blocks = Array(params[:system])
    tool_blocks = Array(params[:tools])
    messages = Array(params[:messages])
    last_system_idx = last_marked_index(system_blocks)
    last_tool_idx = last_marked_index(tool_blocks)

    # 1. Strip system blocks, but preserve the last marked one
    strip_except_index(system_blocks, last_system_idx, excess)
    return if excess[:value] <= 0

    # 2. Strip tool blocks, but preserve the last marked one
    strip_except_index(tool_blocks, last_tool_idx, excess)
    return if excess[:value] <= 0

    # 3. Strip message content blocks (in order)
    strip_message_markers(messages, excess)
    return if excess[:value] <= 0

    # 4. Strip all remaining system markers
    strip_all_marked(system_blocks, excess)
    return if excess[:value] <= 0

    # 5. Strip all remaining tool markers
    strip_all_marked(tool_blocks, excess)
  end

  # Total number of Hash blocks carrying a marker across tools, system
  # and message content arrays.
  def count_breakpoints(params)
    message_blocks = Array(params[:messages]).flat_map do |msg|
      msg[:content].is_a?(Array) ? msg[:content] : []
    end
    [Array(params[:tools]), Array(params[:system]), message_blocks].sum do |blocks|
      blocks.count { |b| b.is_a?(Hash) && cache_control_present?(b) }
    end
  end

  # Index of the last marked Hash block, or -1 when none is marked.
  def last_marked_index(blocks)
    blocks.rindex { |b| b.is_a?(Hash) && cache_control_present?(b) } || -1
  end

  # Remove markers front-to-back, skipping preserve_idx, until the
  # excess counter reaches zero.
  def strip_except_index(blocks, preserve_idx, excess)
    blocks.each_with_index do |b, idx|
      break if excess[:value] <= 0
      next if idx == preserve_idx
      next unless b.is_a?(Hash) && cache_control_present?(b)

      delete_cache_control!(b)
      excess[:value] -= 1
    end
  end

  # Remove markers front-to-back until the excess counter reaches zero.
  def strip_all_marked(blocks, excess)
    strip_except_index(blocks, nil, excess)
  end

  # Remove message-content markers, oldest message first, until the
  # excess counter reaches zero.
  def strip_message_markers(messages, excess)
    messages.each do |msg|
      break if excess[:value] <= 0
      next unless msg[:content].is_a?(Array)

      strip_all_marked(msg[:content], excess)
    end
  end

  # ── Normalize TTL ordering ───────────────────────────────────────────
  #
  # Walk tools → system → messages in order.
  # Once a block with a non-"1h" ttl (including plain ephemeral with no
  # ttl) is seen, all subsequent "1h" blocks are downgraded by removing
  # their ttl field (resulting in plain {type: "ephemeral"}).
  def normalize_ttl_ordering(params)
    seen_non_one_hour = { value: false }
    Array(params[:tools]).each { |b| normalize_block_ttl(b, seen_non_one_hour) }
    Array(params[:system]).each { |b| normalize_block_ttl(b, seen_non_one_hour) }
    Array(params[:messages]).each do |msg|
      next unless msg[:content].is_a?(Array)

      msg[:content].each { |b| normalize_block_ttl(b, seen_non_one_hour) }
    end
  end

  # Track whether a non-"1h" marker has been seen, and downgrade a "1h"
  # marker that appears after one.
  def normalize_block_ttl(block, seen_non_one_hour)
    return unless block.is_a?(Hash)

    cc = get_cache_control(block)
    return unless cc.is_a?(Hash)

    ttl = cc[:ttl] || cc["ttl"]
    if ttl != "1h"
      seen_non_one_hour[:value] = true
      return
    end
    # This block has ttl: "1h" — downgrade if a non-"1h" was seen before
    return unless seen_non_one_hour[:value]

    # Replace rather than mutate: the descriptor may be shared with other
    # blocks or frozen by the caller, and deleting its ttl in place would
    # also strip "1h" from earlier blocks that legitimately keep it.
    key = block[:cache_control] ? :cache_control : "cache_control"
    block[key] = cc.reject { |k, _| k.to_s == "ttl" }
  end

  # ── Low-level helpers ────────────────────────────────────────────────

  # Set cache_control on a block hash, storing a private copy of cc so
  # blocks never share (and later co-mutate) one descriptor. Any entry
  # under the other key style is removed so the block never serializes
  # two "cache_control" fields.
  def set_cache_control!(hash, cc)
    key = cache_control_key(hash)
    delete_cache_control!(hash)
    hash[key] = cc.dup
  end

  # Key to use for cache_control on this hash: an existing cache_control
  # key wins; otherwise symbol if the hash uses any Symbol key, else string.
  def cache_control_key(hash)
    return :cache_control if hash.key?(:cache_control)
    return "cache_control" if hash.key?("cache_control")

    hash.each_key.any? { |k| k.is_a?(Symbol) } ? :cache_control : "cache_control"
  end

  # Delete cache_control from a block hash (both key types).
  def delete_cache_control!(hash)
    hash.delete(:cache_control)
    hash.delete("cache_control")
  end

  # True if the block hash has a truthy cache_control entry (either key
  # type). A nil/false value means "no marker".
  def cache_control_present?(hash)
    [hash[:cache_control], hash["cache_control"]].any?
  end

  # Retrieve the cache_control value from a block hash.
  def get_cache_control(hash)
    hash[:cache_control] || hash["cache_control"]
  end

  # Return the "type" of a block hash regardless of key style.
  def block_type(hash)
    (hash[:type] || hash["type"]).to_s
  end
end
end
end
end
end
|