Skip to content

Dynamic Knowledge bases #1126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions packages/cdk/lambda/predictStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,25 @@ declare global {

export const handler = awslambda.streamifyResponse(
async (event, responseStream, context) => {
context.callbackWaitsForEmptyEventLoop = false;
const model = event.model || defaultModel;
for await (const token of api[model.type].invokeStream?.(
model,
event.messages,
event.id,
event.idToken
) ?? []) {
responseStream.write(token);
try {
context.callbackWaitsForEmptyEventLoop = false;
const model = event.model || defaultModel;
for await (const token of api[model.type].invokeStream?.(
model,
event.messages,
event.id,
event.idToken,
event.kbId
) ?? []) {
responseStream.write(token);
}
responseStream.end();
} catch (error) {
console.error('Error in stream processing:', error);
responseStream.write(
JSON.stringify({ error: 'Stream processing failed' })
);
responseStream.end();
}
responseStream.end();
}
);
5 changes: 3 additions & 2 deletions packages/cdk/lambda/retrieveKnowledgeBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ exports.handler = async (
): Promise<lambda.APIGatewayProxyResult> => {
const req = JSON.parse(event.body!) as RetrieveKnowledgeBaseRequest;
const query = req.query;
const knowledgeBaseId = req.knowledgeBaseId || KNOWLEDGE_BASE_ID;

if (!query) {
return {
Expand All @@ -25,12 +26,12 @@ exports.handler = async (

const client = await initBedrockAgentRuntimeClient({ region: MODEL_REGION });
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: KNOWLEDGE_BASE_ID,
knowledgeBaseId: knowledgeBaseId,
retrievalQuery: { text: query },
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: 10,
overrideSearchType: 'HYBRID',
//overrideSearchType: 'HYBRID',
},
},
});
Expand Down
5 changes: 3 additions & 2 deletions packages/cdk/lambda/utils/bedrockKbApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ const bedrockKbApi: ApiInterface = {
model: Model,
messages: UnrecordedMessage[],
id: string,
idToken?: string
idToken?: string,
kbid?: string
) {
try {
// Get explicit filters (async since it may require idToken verification)
Expand All @@ -148,7 +149,7 @@ const bedrockKbApi: ApiInterface = {
retrieveAndGenerateConfiguration: {
type: 'KNOWLEDGE_BASE',
knowledgeBaseConfiguration: {
knowledgeBaseId: process.env.KNOWLEDGE_BASE_ID,
knowledgeBaseId: kbid ?? process.env.KNOWLEDGE_BASE_ID,
modelArn: model.modelId,
retrievalConfiguration: {
vectorSearchConfiguration: {
Expand Down
2 changes: 2 additions & 0 deletions packages/types/src/protocol.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export type PredictRequest = {
idToken?: string;
messages: UnrecordedMessage[];
id: string;
kbId?: string;
};

export type PredictResponse = string;
Expand Down Expand Up @@ -127,6 +128,7 @@ export type RetrieveKendraResponse = RetrieveCommandOutput;

export type RetrieveKnowledgeBaseRequest = {
query: string;
knowledgeBaseId?: string;
};

export type RetrieveKnowledgeBaseResponse = RetrieveCommandOutputKnowledgeBase;
Expand Down
3 changes: 2 additions & 1 deletion packages/types/src/utils.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ export type InvokeStreamInterface = (
model: Model,
messages: UnrecordedMessage[],
id: string,
idToken?: string
idToken?: string,
kbId?: string
) => AsyncIterable<string>;

// Return Base64 encoded image
Expand Down
19 changes: 13 additions & 6 deletions packages/web/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ const useChatState = create<{
overrideModelType: Model['type'] | undefined,
setSessionId: (sessionId: string) => void,
base64Cache: Record<string, string> | undefined,
overrideModelParameters: AdditionalModelRequestFields | undefined
overrideModelParameters: AdditionalModelRequestFields | undefined,
knowledgeBaseId: string | undefined
) => void;
edit: (
id: string,
Expand Down Expand Up @@ -439,7 +440,8 @@ const useChatState = create<{
base64Cache: Record<string, string> | undefined = undefined,
overrideModelParameters:
| AdditionalModelRequestFields
| undefined = undefined
| undefined = undefined,
selectedKnowledgeBaseId: string | undefined = undefined
) => {
const modelId = get().modelIds[id];

Expand Down Expand Up @@ -538,6 +540,7 @@ const useChatState = create<{
model: model,
messages: formattedMessages,
id: id,
kbId: selectedKnowledgeBaseId,
});

// Update the assistant's message
Expand Down Expand Up @@ -763,7 +766,8 @@ const useChatState = create<{
base64Cache: Record<string, string> | undefined = undefined,
overrideModelParameters:
| AdditionalModelRequestFields
| undefined = undefined
| undefined = undefined,
selectedKnowledgeBaseId: string | undefined = undefined
) => {
const unrecordedUserMessage: UnrecordedMessage = {
role: 'user',
Expand Down Expand Up @@ -816,7 +820,8 @@ const useChatState = create<{
overrideModelType,
setSessionId,
base64Cache,
overrideModelParameters
overrideModelParameters,
selectedKnowledgeBaseId
);
},

Expand Down Expand Up @@ -1010,7 +1015,8 @@ const useChat = (id: string, chatId?: string) => {
base64Cache: Record<string, string> | undefined = undefined,
overrideModelParameters:
| AdditionalModelRequestFields
| undefined = undefined
| undefined = undefined,
selectedKnowledgeBaseId: string | undefined = undefined
) => {
post(
id,
Expand All @@ -1025,7 +1031,8 @@ const useChat = (id: string, chatId?: string) => {
overrideModelType,
setSessionId,
base64Cache,
overrideModelParameters
overrideModelParameters,
selectedKnowledgeBaseId
);
},
editChat: (
Expand Down
68 changes: 68 additions & 0 deletions packages/web/src/hooks/useKnowledgeBases.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import { useState, useEffect } from 'react';
import {
BedrockAgent,
KnowledgeBaseSummary,
} from '@aws-sdk/client-bedrock-agent';
import { fromCognitoIdentityPool } from '@aws-sdk/credential-provider-cognito-identity';
import { CognitoIdentityClient } from '@aws-sdk/client-cognito-identity';
import { fetchAuthSession } from 'aws-amplify/auth';

/**
 * React hook that lists the Bedrock knowledge bases visible to the signed-in
 * user.
 *
 * The list is fetched once when the component mounts. Credentials are obtained
 * from the Cognito identity pool by exchanging the current session's ID token.
 *
 * @returns `knowledgeBaseIds` — ids of every knowledge base returned by the
 *   service (summaries without an id are skipped); `loading` — `true` until
 *   the fetch has settled; `error` — the failure, or `null` when none occurred.
 */
export const useKnowledgeBases = () => {
  const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBaseSummary[]>(
    []
  );
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<Error | null>(null);

  useEffect(() => {
    // Flipped by the cleanup function so a fetch that resolves after unmount
    // does not try to update state.
    let cancelled = false;

    const region = import.meta.env.VITE_APP_REGION;
    const userPoolId = import.meta.env.VITE_APP_USER_POOL_ID;
    const idPoolId = import.meta.env.VITE_APP_IDENTITY_POOL_ID;
    const cognito = new CognitoIdentityClient({ region });
    const providerName = `cognito-idp.${region}.amazonaws.com/${userPoolId}`;

    const fetchKnowledgeBases = async () => {
      try {
        const token = (await fetchAuthSession()).tokens?.idToken?.toString();
        if (!token) {
          throw new Error('Not authenticated');
        }
        const client = new BedrockAgent({
          region: region,
          credentials: fromCognitoIdentityPool({
            client: cognito,
            identityPoolId: idPoolId,
            logins: {
              [providerName]: token,
            },
          }),
        });

        // Follow nextToken until the service reports no further pages, so
        // accounts with more than one page of knowledge bases are fully listed.
        const allKnowledgeBases: KnowledgeBaseSummary[] = [];
        let nextToken: string | undefined;
        do {
          const response = await client.listKnowledgeBases({
            maxResults: 10,
            nextToken,
          });
          allKnowledgeBases.push(...(response.knowledgeBaseSummaries ?? []));
          nextToken = response.nextToken;
        } while (nextToken && !cancelled);

        if (!cancelled) {
          setKnowledgeBases(allKnowledgeBases);
        }
      } catch (err) {
        console.error('Error fetching knowledge bases:', err);
        if (!cancelled) {
          setError(
            err instanceof Error
              ? err
              : new Error('Failed to fetch knowledge bases')
          );
        }
      } finally {
        if (!cancelled) {
          setLoading(false);
        }
      }
    };

    void fetchKnowledgeBases();

    return () => {
      cancelled = true;
    };
    // Run once on mount. Depending on `knowledgeBases` here would re-trigger
    // the fetch after every successful setKnowledgeBases (a new array
    // reference each time), producing an endless request loop.
  }, []);

  return {
    // Drop summaries that carry no id rather than exposing '' as a selectable
    // knowledge base id.
    knowledgeBaseIds: knowledgeBases.flatMap((kb) =>
      kb.knowledgeBaseId ? [kb.knowledgeBaseId] : []
    ),
    loading,
    error,
  };
};
4 changes: 3 additions & 1 deletion packages/web/src/hooks/useRagKnowledgeBaseApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import useHttp from './useHttp';

const useRagKnowledgeBaseApi = () => {
const http = useHttp();
console.log('calling /rag-knowledge-base/retrieve');
return {
retrieve: (query: string) => {
retrieve: (query: string, knowledgeBaseId?: string) => {
return http.post<
RetrieveKnowledgeBaseResponse,
RetrieveKnowledgeBaseRequest
>('/rag-knowledge-base/retrieve', {
query,
knowledgeBaseId,
});
},
};
Expand Down
Loading