summaryrefslogtreecommitdiffhomepage
path: root/packages/credential-store/src/registry.ts
blob: 89d8084e3c91f68de8f9a30032a142e90022c29a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import type { ModelInfo, ProviderContract } from "@dispatch/kernel";

/**
 * A named credential bound to a provider.
 *
 * `name` is the prefix used in `<name>/<model>` model strings handled by the
 * credential store; `providerId` is the id passed to `getProvider` to find the
 * provider that serves this credential.
 */
export interface Credential {
	readonly name: string;
	readonly providerId: string;
}

/**
 * Result of resolving a `<credentialName>/<model>` string: the provider id of the
 * matched credential and the model id (everything after the first "/").
 */
export interface ResolvedModel {
	readonly providerId: string;
	readonly model: string;
}

/**
 * Maps `<credentialName>/<model>` strings to providers and exposes the model
 * catalog across all configured credentials.
 */
export interface CredentialStore {
	/**
	 * Split a model name on the FIRST "/": name=before, model=after (model may contain "/").
	 * Look up the credential by name; return its providerId + the model id, or undefined if
	 * the name is unknown or there is no model segment.
	 *
	 * This is a pure name lookup: it does not ask the provider whether the model exists.
	 *
	 * @param modelName - a `<credentialName>/<model>` string.
	 * @returns the resolved provider/model pair, or undefined.
	 */
	resolve(modelName: string): ResolvedModel | undefined;

	/**
	 * The model catalog: for each credential, look up its provider and call listModels(),
	 * emitting `${credential.name}/${modelInfo.id}`. Skip credentials whose provider is
	 * missing or has no listModels.
	 *
	 * NOTE: the implementation in this file does not catch `listModels` errors, so a
	 * single failing provider rejects the whole catalog.
	 *
	 * @returns one `<credentialName>/<modelId>` entry per listed model per credential.
	 */
	listCatalog(): Promise<readonly string[]>;

	/**
	 * Returns the full `ModelInfo` for a `<credentialName>/<model>` string, or
	 * undefined if unknown. Caches the result of `listModels` per credential.
	 * Used to look up `contextWindow` for auto-compaction.
	 *
	 * @param modelName - a `<credentialName>/<model>` string, parsed like `resolve`.
	 * @returns the matching `ModelInfo`, or undefined if the credential is unknown,
	 *   its provider cannot list models, or no listed model has that id.
	 */
	getModelInfo(modelName: string): Promise<ModelInfo | undefined>;
}

/** Dependencies injected into the credential store. */
export interface CredentialStoreDeps {
	/** Credentials to expose; the catalog is emitted in this order. */
	readonly credentials: readonly Credential[];
	/** Look up a provider by id; undefined means the provider is missing. */
	readonly getProvider: (id: string) => ProviderContract | undefined;
}

/**
 * Build a {@link CredentialStore} over a fixed list of credentials.
 *
 * Model names have the form `<credentialName>/<model>` and are split on the FIRST "/",
 * so the model id itself may contain "/". If several credentials share a name, the
 * last one in `deps.credentials` wins for `resolve` and `getModelInfo`.
 *
 * `getModelInfo` caches each credential's `listModels` request (shared by concurrent
 * callers; failed requests are evicted so a later call retries). `listCatalog` always
 * queries the providers afresh.
 *
 * @param deps - the credentials to expose and a lookup from provider id to provider.
 * @returns the credential store.
 */
export function createCredentialStore(deps: CredentialStoreDeps): CredentialStore {
	// credential name -> provider id. Later duplicates overwrite earlier ones.
	const credentialMap = new Map<string, string>();
	for (const credential of deps.credentials) {
		credentialMap.set(credential.name, credential.providerId);
	}

	// credential name -> pending or settled `listModels` request used by getModelInfo.
	const modelCache = new Map<string, Promise<readonly ModelInfo[]>>();

	type ParsedModelName = {
		readonly credentialName: string;
		readonly providerId: string;
		readonly model: string;
	};

	/** Parse `<credentialName>/<model>` and resolve the credential; undefined if malformed or unknown. */
	function lookup(modelName: string): ParsedModelName | undefined {
		const slashIndex = modelName.indexOf("/");
		if (slashIndex === -1) return undefined;

		const credentialName = modelName.slice(0, slashIndex);
		const model = modelName.slice(slashIndex + 1);
		if (model === "") return undefined;

		const providerId = credentialMap.get(credentialName);
		if (providerId === undefined) return undefined;

		return { credentialName, providerId, model };
	}

	/** Cached model list for a credential, or undefined if its provider cannot list models. */
	function modelsFor(credentialName: string, providerId: string): Promise<readonly ModelInfo[]> | undefined {
		const cached = modelCache.get(credentialName);
		if (cached !== undefined) return cached;

		const provider = deps.getProvider(providerId);
		if (provider?.listModels === undefined) return undefined;

		const request = Promise.resolve(provider.listModels());
		modelCache.set(credentialName, request);
		// Never cache a failure: evict it (unless already replaced) so the next call retries.
		// The caller still observes the rejection through the returned promise.
		void request.catch(() => {
			if (modelCache.get(credentialName) === request) {
				modelCache.delete(credentialName);
			}
		});
		return request;
	}

	return {
		resolve(modelName: string): ResolvedModel | undefined {
			const hit = lookup(modelName);
			if (hit === undefined) return undefined;
			return { providerId: hit.providerId, model: hit.model };
		},

		async listCatalog(): Promise<readonly string[]> {
			// Query all providers concurrently; Promise.all keeps credential order.
			const perCredential = await Promise.all(
				deps.credentials.map(async (credential): Promise<readonly string[]> => {
					const provider = deps.getProvider(credential.providerId);
					if (provider?.listModels === undefined) return [];

					const models = await provider.listModels();
					return models.map((model) => `${credential.name}/${model.id}`);
				}),
			);
			return perCredential.flat();
		},

		async getModelInfo(modelName: string): Promise<ModelInfo | undefined> {
			const hit = lookup(modelName);
			if (hit === undefined) return undefined;

			const models = await modelsFor(hit.credentialName, hit.providerId);
			return models?.find((m) => m.id === hit.model);
		},
	};
}