diff --git a/.gitignore b/.gitignore
index 83cf97c..9e21ed6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -64,11 +64,3 @@
 spring_ai/create_user.sql
 spring_ai/drop.sql
 src/client/spring_ai/target/classes/*
 api_server_key
-src/client/mcp/rag/optimizer_settings.json
-src/client/mcp/rag/pyproject.toml
-src/client/mcp/rag/main.py
-src/client/mcp/rag/.python-version
-src/client/mcp/rag/uv.lock
-src/client/mcp/rag/node_modules/
-src/client/mcp/rag/package-lock.json
-src/client/mcp/rag/package.json
diff --git a/src/client/spring_ai/README.md b/src/client/spring_ai/README.md
index 506adab..191392f 100644
--- a/src/client/spring_ai/README.md
+++ b/src/client/spring_ai/README.md
@@ -20,12 +20,12 @@ Get Started with Java Development](https://docs.oracle.com/en/database/oracle/or
 Download one of them through the `Download SpringAI` button. Unzip the content and set the executable permission on the `start.sh` with `chmod 755 ./start.sh`.
-Edit `start.sh` to add only the DB_PASSWORD not exported, as in this example:
+Edit `start.sh` to set DB_PASSWORD and any other reference/credential that differs in your dev environment, as in this example:
 ```
 export SPRING_AI_OPENAI_API_KEY=$OPENAI_API_KEY
 export DB_DSN="jdbc:oracle:thin:@localhost:1521/FREEPDB1"
 export DB_USERNAME=
-export DB_PASSWORD=""
+export DB_PASSWORD=
 export DISTANCE_TYPE=COSINE
 export OPENAI_CHAT_MODEL=gpt-4o-mini
 export OPENAI_EMBEDDING_MODEL=text-embedding-3-small
@@ -54,13 +54,17 @@ Start with:
 This project contains a web service that will accept HTTP GET requests at
-* `http://localhost:9090/v1/chat/completions`: to use Vector Search via OpenAI REST API
+* `http://localhost:9090/v1/chat/completions`: to use RAG via the OpenAI REST API
+* `http://localhost:9090/v1/models`: returns the models behind the RAG via the OpenAI REST API
 * `http://localhost:9090/v1/service/llm` : to chat straight with the LLM used
 * `http://localhost:9090/v1/service/search/`: to search for document similar to the message provided
 * `http://localhost:9090/v1/service/store-chunks/`: to embedd and store a list of text chunks in the vectorstore
-Vector Search call example with `openai` build profile with no-stream:
+
+
+### Completions
+RAG call example with the `openai` build profile, without streaming:
 ```
 curl -N http://localhost:9090/v1/chat/completions \
@@ -119,23 +123,41 @@ curl -X POST http://localhost:9090/v1/service/store-chunks \
   -d '["First chunk of text.", "Second chunk.", "Another example."]'
 ```
-response:
+### Get model name
+Returns the name of the model in use. This is useful for integrating chat GUIs that require the model list before proceeding.
 ```
-[
-  [
-    -0.014500250108540058,
-    -0.03604526072740555,
-    0.035963304340839386,
-    0.010181647725403309,
-    -0.01610776223242283,
-    -0.021091962233185768,
-    0.03924199938774109,
-    ..
-  ]
-]
+curl http://localhost:9090/v1/models
+```
+
+## MCP RagTool
+The completion service is also available as an MCP server based on the SSE transport protocol.
+To test it:
+
+* Start the microservice as usual:
+```shell
+./start.sh
 ```
+* Start the MCP inspector:
+```shell
+export DANGEROUSLY_OMIT_AUTH=true
+npx @modelcontextprotocol/inspector
+```
+
+* Open http://127.0.0.1:6274 in a web browser
+
+* Configure:
+  * Transport Type: SSE
+  * URL: http://localhost:9090/sse
+  * Set Request Timeout to: 200000
+
+* Test a call to the `getRag` tool.
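+* If you only want to verify that the SSE endpoint is reachable before opening the inspector, a quick check from the shell is enough. This is a minimal sketch, assuming the `/sse` endpoint configured above; the connection stays open, so stop it with Ctrl+C:
+```shell
+# Expect an SSE stream that announces the message endpoint
+curl -N http://localhost:9090/sse
+```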
+ + + + + ## Oracle Backend for Microservices and AI * Add in `application-obaas.yml` the **OPENAI_API_KEY**, if the deployement is based on the OpenAI LLM services: ``` @@ -248,11 +270,6 @@ it should return: ``` - - - - - ## Prerequisites Before using the AI commands, make sure you have a developer token from OpenAI. @@ -269,3 +286,7 @@ export SPRING_AI_OPENAI_API_KEY= Setting the API key is all you need to run the application. However, you can find more information on setting started in the [Spring AI reference documentation section on OpenAI Chat](https://docs.spring.io/spring-ai/reference/api/clients/openai-chat.html). + + + + diff --git a/src/client/spring_ai/pom.xml b/src/client/spring_ai/pom.xml index d70799d..cc57b65 100644 --- a/src/client/spring_ai/pom.xml +++ b/src/client/spring_ai/pom.xml @@ -32,6 +32,11 @@ + + org.springframework.ai + spring-ai-starter-mcp-server-webmvc + + org.springframework.boot spring-boot-starter-web @@ -75,7 +80,7 @@ 23.5.0.24.07 - + org.springframework.boot spring-boot-starter-jdbc @@ -83,7 +88,7 @@ com.oracle.database.spring oracle-spring-boot-starter-wallet 23.4.0 - + diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/AIController.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/AIController.java index eeda366..1468c41 100644 --- a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/AIController.java +++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/AIController.java @@ -6,22 +6,14 @@ package org.springframework.ai.openai.samples.helloworld; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.ChatClient.ChatClientRequestSpec; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.Embedding; //import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; -import org.springframework.ai.reader.ExtractedTextFormatter; -import org.springframework.ai.reader.pdf.PagePdfDocumentReader; -import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig; -import org.springframework.ai.transformer.splitter.TokenTextSplitter; + import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; -import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Lazy; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -35,21 +27,18 @@ import jakarta.annotation.PostConstruct; -import org.springframework.core.io.Resource; import org.springframework.http.MediaType; import org.springframework.jdbc.core.JdbcTemplate; -import java.io.IOException; -import java.nio.charset.StandardCharsets; + import java.util.List; import java.util.ArrayList; import java.util.Map; import java.util.HashMap; -import java.security.SecureRandom; + import java.time.Instant; import java.util.stream.Collectors; -import java.util.Iterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,49 +47,60 @@ @RestController class 
AIController {
-    @Value("${spring.ai.openai.chat.options.model}")
-    private String modelOpenAI;
-
-    @Value("${spring.ai.ollama.chat.options.model}")
-    private String modelOllamaAI;
-
-    @Autowired
+    private final String modelOpenAI;
+    private final String modelOllamaAI;
+    private final ChatClient chatClient;
     private final OracleVectorStore vectorStore;
-
-    @Autowired
     private final EmbeddingModel embeddingModel;
-
-    @Autowired
-    private final ChatClient chatClient;
-
-    @Value("${aims.vectortable.name}")
-    private String legacyTable;
-
-    @Value("${aims.context_instr}")
-    private String contextInstr;
-
-    @Value("${aims.rag_params.search_type}")
-    private String searchType;
-
-    @Value("${aims.rag_params.top_k}")
-    private int TOPK;
-
-    @Autowired
+    private final String legacyTable;
+    private final String contextInstr;
+    private final String searchType;
+    private final int TOPK;
     private JdbcTemplate jdbcTemplate;
     private static final Logger logger = LoggerFactory.getLogger(AIController.class);
     private static final int SLEEP = 50; // Wait in streaming between chunks
     private static final int STREAM_SIZE = 5; // chars in each chunk
-    AIController(ChatClient chatClient, EmbeddingModel embeddingModel, OracleVectorStore vectorStore) {
+    @Autowired
+    private PromptBuilderService promptBuilderService;
+    @Autowired
+    private Helper helper;
+
+    AIController(
+            String modelOpenAI,
+            String modelOllamaAI,
+            @Lazy ChatClient chatClient,
+            EmbeddingModel embeddingModel,
+            OracleVectorStore vectorStore,
+            JdbcTemplate jdbcTemplate,
+            String legacyTable,
+            String contextInstr,
+            String searchType,
+            int TOPK) {
+
+        this.modelOpenAI = modelOpenAI;
+        this.modelOllamaAI = modelOllamaAI;
+        this.vectorStore = vectorStore;
         this.chatClient = chatClient;
         this.embeddingModel = embeddingModel;
-        this.vectorStore = vectorStore;
+        this.legacyTable = legacyTable;
+        this.contextInstr = contextInstr;
+        this.searchType = searchType;
+        this.TOPK = TOPK;
+        this.jdbcTemplate = jdbcTemplate;
     }
-    @GetMapping("/service/llm")
+
+    /**
+     * Chat completion endpoint to interact with the LLM, without RAG, memory, or system prompting.
+     * Not compliant with the OpenAI API.
+     *
+     * @param message: the message to be routed to the LLM
+     */
+    @GetMapping("/v1/service/llm")
     Map completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
         return Map.of(
@@ -111,6 +111,11 @@ Map completion(@RequestParam(value = "message", defaultValue = "
                 .content());
     }
+    /**
+     * Creates a new table in the Spring AI format from an existing vector store made by LangChain.
+     * If the table already exists, it will not be overwritten.
+     *
+     */
     @PostConstruct
     public void insertData() {
         String sqlUser = "SELECT USER FROM DUAL";
@@ -119,7 +124,7 @@ public void insertData() {
         String newTable = legacyTable + "_SPRINGAI";
         user = jdbcTemplate.queryForObject(sqlUser, String.class);
-        if (doesTableExist(legacyTable, user) != -1) {
+        if (helper.doesTableExist(legacyTable, user, this.jdbcTemplate) != -1) {
             // RUNNING LOCAL
             logger.info("Running local with user: " + user);
             sql = "INSERT INTO " + user + "." + newTable + " (ID, CONTENT, METADATA, EMBEDDING) " +
@@ -131,8 +136,8 @@ public void insertData() {
                     "SELECT ID, TEXT, METADATA, EMBEDDING FROM ADMIN."
+ legacyTable; } // Execute the insert - logger.info("doesExist" + user + ": " + doesTableExist(newTable, user)); - if (countRecordsInTable(newTable, user) == 0) { + logger.info("doesExist" + user + ": " + helper.doesTableExist(newTable, user,this.jdbcTemplate)); + if (helper.countRecordsInTable(newTable, user,this.jdbcTemplate) == 0) { // First microservice execution logger.info("Table " + user + "." + newTable + " doesn't exist: create from ADMIN/USER." + legacyTable); jdbcTemplate.update(sql); @@ -142,99 +147,16 @@ public void insertData() { } } - public int countRecordsInTable(String tableName, String schemaName) { - // Dynamically construct the SQL query with the table and schema names - String sql = String.format("SELECT COUNT(*) FROM %s.%s", schemaName.toUpperCase(), tableName.toUpperCase()); - logger.info("Checking if table is empty: " + tableName + " in schema: " + schemaName); - - try { - // Execute the query and get the count of records in the table - Integer count = jdbcTemplate.queryForObject(sql, Integer.class); - - // Return the count if it's not null, otherwise return -1 - return count != null ? count : -1; - } catch (Exception e) { - logger.error("Error checking table record count: " + e.getMessage()); - return -1; // Return -1 in case of an error - } - } - - public int doesTableExist(String tableName, String schemaName) { - String sql = "SELECT COUNT(*) FROM all_tables WHERE table_name = ? AND owner = ?"; - logger.info("Checking if table exists: " + tableName + " in schema: " + schemaName); - - try { - // Query the system catalog to check for the existence of the table in the given - // schema - Integer count = jdbcTemplate.queryForObject(sql, Integer.class, tableName.toUpperCase(), - schemaName.toUpperCase()); - - if (count != null && count > 0) { - return count; - } else { - return -1; - } - } catch (Exception e) { - logger.error("Error checking table existence: " + e.getMessage()); - return -1; - } - } - - public Prompt promptEngineering(String message, String contextInstr) { - - String template = """ - DOCUMENTS: - {documents} - - QUESTION: - {question} - INSTRUCTIONS:"""; - - String default_Instr = """ - Answer the users question using the DOCUMENTS text above. - Keep your answer ground in the facts of the DOCUMENTS. - If the DOCUMENTS doesn’t contain the facts to answer the QUESTION, return: - I'm sorry but I haven't enough information to answer. - """; - - // This template doesn't work with re-phrasing/grading pattern, but only via RAG - // The contextInstr coming from Oracle ai optimizer and toolkit can't be used - // here: default only - // Modifiy it to include re-phrasing/grading if you wish. - - template = template + "\n" + default_Instr; - - List similarDocuments = this.vectorStore.similaritySearch( - SearchRequest.builder().query(message).topK(TOPK).build()); - - StringBuilder context = createContext(similarDocuments); - - PromptTemplate promptTemplate = new PromptTemplate(template); - - Prompt prompt = promptTemplate.create(Map.of("documents", context, "question", message)); - - logger.info(prompt.toString()); - - return prompt; - - } - - StringBuilder createContext(List similarDocuments) { - String START = "\n
\n"; - String STOP = "\n
\n"; - - Iterator iterator = similarDocuments.iterator(); - StringBuilder context = new StringBuilder(); - while (iterator.hasNext()) { - Document document = iterator.next(); - context.append(document.getId() + "."); - context.append(START + document.getFormattedContent() + STOP); - } - return context; - } - - @PostMapping(value = "/chat/completions", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + /** + * Chat completion endpoint to interact with the LLM, with RAG support. + * Compliant with Open AI API + * It works also in stream + * + * @param message: the message to be routed to the LLM along the prompt/context + * @return the llm response in one shot or in streaming + */ + @PostMapping(value = "/v1/chat/completions", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public ResponseBodyEmitter streamCompletions(@RequestBody ChatRequest request) { ResponseBodyEmitter bodyEmitter = new ResponseBodyEmitter(); String userMessageContent; @@ -246,7 +168,7 @@ public ResponseBodyEmitter streamCompletions(@RequestBody ChatRequest request) { if (content != null && !content.trim().isEmpty()) { userMessageContent = content; logger.info("user message: " + userMessageContent); - Prompt prompt = promptEngineering(userMessageContent, contextInstr); + Prompt prompt = promptBuilderService.buildPrompt(userMessageContent, contextInstr, TOPK); logger.info("prompt message: " + prompt.getContents()); String contentResponse = chatClient.prompt(prompt).call().content(); logger.info("-------------------------------------------------------"); @@ -259,7 +181,7 @@ public ResponseBodyEmitter streamCompletions(@RequestBody ChatRequest request) { if (request.isStream()) { logger.info("Request is a Stream"); - List chunks = chunkString(contentResponse); + List chunks = helper.chunkString(contentResponse); for (String token : chunks) { ChatMessage messageAnswer = new ChatMessage("assistant", token); @@ -274,10 +196,10 @@ public ResponseBodyEmitter streamCompletions(@RequestBody ChatRequest request) { bodyEmitter.send("data: [DONE]\n\n"); } else { logger.info("Request isn't a Stream"); - String id = "chatcmpl-" + generateRandomToken(28); + String id = "chatcmpl-" + helper.generateRandomToken(28); String object = "chat.completion"; String created = String.valueOf(Instant.now().getEpochSecond()); - String model = getModel(); + String model = helper.getModel(this.modelOpenAI,this.modelOllamaAI); ChatMessage messageAnswer = new ChatMessage("assistant", contentResponse); List choices = List.of(new ChatChoice(messageAnswer)); bodyEmitter.send(new ChatResponse(id, object, created, model, choices)); @@ -298,7 +220,14 @@ public ResponseBodyEmitter streamCompletions(@RequestBody ChatRequest request) { return bodyEmitter; } - @GetMapping("/service/search") + /** + * Similarity search + * + * @param message: the message to be routed to the LLM along the prompt/context + * @param topK: the number of chunks to be included in the context + * @return the list of the nearest topK chunks + */ + @GetMapping("/v1/service/search") List> search(@RequestParam(value = "message", defaultValue = "Tell me a joke") String query, @RequestParam(value = "topk", defaultValue = "5") Integer topK) { @@ -318,12 +247,18 @@ List> search(@RequestParam(value = "message", defaultValue = return resultList; } - @PostMapping("/service/store-chunks") + /** + * Store new chunks, sent as a list of strings in the request body + * + * @param chunks: the list of chunks + * @return the list of vector embeddings created and stored along the chunks + */ + 
@PostMapping("/v1/service/store-chunks") List> store(@RequestBody List chunks) { List> allVectors = new ArrayList<>(); List documents = chunks.stream() .map(chunk -> { - double[] vector = floatToDouble(embeddingModel.embed(chunk)); + double[] vector = helper.floatToDouble(embeddingModel.embed(chunk)); Double[] sVector = java.util.Arrays.stream(vector) .mapToObj(Double::valueOf) .toArray(Double[]::new); @@ -339,8 +274,14 @@ List> store(@RequestBody List chunks) { return allVectors; } - - @GetMapping("/models") + + /** + * List of model + * + * @param requestBody: the message to be routed to the LLM along the prompt/context + * @return in this case it will be returned a list with only one model on which is based this microservice + */ + @GetMapping("/v1/models") Map models(@RequestBody(required = false) Map requestBody) { String modelId = "custom"; logger.info("models request"); @@ -374,48 +315,5 @@ Map models(@RequestBody(required = false) Map re } } - public List chunkString(String input) { - List chunks = new ArrayList<>(); - int chunkSize = STREAM_SIZE; - - for (int i = 0; i < input.length(); i += chunkSize) { - int end = Math.min(input.length(), i + chunkSize); - chunks.add(input.substring(i, end)); - } - - return chunks; - } - - public String generateRandomToken(int length) { - String CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; - SecureRandom random = new SecureRandom(); - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - int index = random.nextInt(CHARACTERS.length()); - sb.append(CHARACTERS.charAt(index)); - } - return sb.toString(); - } - public String getModel() { - String modelId = "custom"; - if (!"".equals(modelOpenAI)) { - modelId = modelOpenAI; - } else if (!"".equals(modelOllamaAI)) { - modelId = modelOllamaAI; - } - return modelId; - } - - public double[] floatToDouble(float[] floatArray) { - double[] doubleArray = new double[floatArray.length]; - - for (int i = 0; i < floatArray.length; i++) { - doubleArray[i] = floatArray[i]; // implicit widening cast per element - } - return doubleArray; - } } - - - diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Application.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Application.java index dc3c1f7..79db8fe 100644 --- a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Application.java +++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Application.java @@ -5,7 +5,10 @@ package org.springframework.ai.openai.samples.helloworld; -import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.method.MethodToolCallbackProvider; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.vectorstore.oracle.OracleVectorStore; import org.springframework.beans.factory.annotation.Value; @@ -13,6 +16,9 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Service; + + @SpringBootApplication public class Application { @@ -33,4 +39,9 @@ OracleVectorStore vectorStore(EmbeddingModel ec, JdbcTemplate t) { return ovs; } + @Bean + public 
ToolCallbackProvider ragTools(RagService ragService) {
+        return MethodToolCallbackProvider.builder().toolObjects(ragService).build();
+    }
+
 }
diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Config.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Config.java
index 487fab2..3a4b7c3 100644
--- a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Config.java
+++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Config.java
@@ -8,15 +8,47 @@
 import org.springframework.ai.chat.client.ChatClient;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
-
+import org.springframework.beans.factory.annotation.Value;
 @Configuration
-class Config {
+public class Config {
     @Bean
-    ChatClient chatClient(ChatClient.Builder builder) {
+    public ChatClient chatClient(ChatClient.Builder builder) {
         return builder.build();
     }
+
+    // Optional: Centralize property values if used in multiple places
+    @Bean
+    public String modelOpenAI(@Value("${spring.ai.openai.chat.options.model}") String modelOpenAI) {
+        return modelOpenAI;
+    }
+
+    @Bean
+    public String modelOllamaAI(@Value("${spring.ai.ollama.chat.options.model}") String modelOllamaAI) {
+        return modelOllamaAI;
+    }
+
+    @Bean
+    public String legacyTable(@Value("${aims.vectortable.name}") String table) {
+        return table;
+    }
+
+    @Bean
+    public String contextInstr(@Value("${aims.context_instr}") String instr) {
+        return instr;
+    }
+
+    @Bean
+    public String searchType(@Value("${aims.rag_params.search_type}") String searchType) {
+        return searchType;
+    }
+
+    @Bean
+    public Integer topK(@Value("${aims.rag_params.top_k}") int topK) {
+        return topK;
+    }
+
 }
diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Helper.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Helper.java
new file mode 100644
index 0000000..cbf1f87
--- /dev/null
+++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/Helper.java
@@ -0,0 +1,101 @@
+package org.springframework.ai.openai.samples.helloworld;
+
+import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.stereotype.Component;
+
+@Component
+public class Helper {
+    private static final Logger logger = LoggerFactory.getLogger(Helper.class);
+
+    private static final int SLEEP = 50; // Wait in streaming between chunks
+    private static final int STREAM_SIZE = 5; // chars in each chunk
+
+    public Helper() {
+    }
+
+    public int countRecordsInTable(String tableName, String schemaName, JdbcTemplate jdbcTemplate) {
+        // Dynamically construct the SQL query with the table and schema names
+        String sql = String.format("SELECT COUNT(*) FROM %s.%s", schemaName.toUpperCase(), tableName.toUpperCase());
+        logger.info("Checking if table is empty: " + tableName + " in schema: " + schemaName);
+
+        try {
+            // Execute the query and get the count of records in the table
+            Integer count = jdbcTemplate.queryForObject(sql, Integer.class);
+
+            // Return the count if it's not null, otherwise return -1
+            return count != null ?
count : -1; + } catch (Exception e) { + logger.error("Error checking table record count: " + e.getMessage()); + return -1; // Return -1 in case of an error + } + } + + public int doesTableExist(String tableName, String schemaName, JdbcTemplate jdbcTemplate ) { + String sql = "SELECT COUNT(*) FROM all_tables WHERE table_name = ? AND owner = ?"; + logger.info("Checking if table exists: " + tableName + " in schema: " + schemaName); + + try { + // Query the system catalog to check for the existence of the table in the given + // schema + Integer count = jdbcTemplate.queryForObject(sql, Integer.class, tableName.toUpperCase(), + schemaName.toUpperCase()); + + if (count != null && count > 0) { + return count; + } else { + return -1; + } + } catch (Exception e) { + logger.error("Error checking table existence: " + e.getMessage()); + return -1; + } + } + + public List chunkString(String input) { + List chunks = new ArrayList<>(); + int chunkSize = STREAM_SIZE; + + for (int i = 0; i < input.length(); i += chunkSize) { + int end = Math.min(input.length(), i + chunkSize); + chunks.add(input.substring(i, end)); + } + + return chunks; + } + + public String generateRandomToken(int length) { + String CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + SecureRandom random = new SecureRandom(); + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + int index = random.nextInt(CHARACTERS.length()); + sb.append(CHARACTERS.charAt(index)); + } + return sb.toString(); + } + + public String getModel(String modelOpenAI, String modelOllamaAI) { + String modelId = "custom"; + if (!"".equals(modelOpenAI)) { + modelId = modelOpenAI; + } else if (!"".equals(modelOllamaAI)) { + modelId = modelOllamaAI; + } + return modelId; + } + + public double[] floatToDouble(float[] floatArray) { + double[] doubleArray = new double[floatArray.length]; + + for (int i = 0; i < floatArray.length; i++) { + doubleArray[i] = floatArray[i]; // implicit widening cast per element + } + return doubleArray; + } +} diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/PromptBuilderService.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/PromptBuilderService.java new file mode 100644 index 0000000..89bd4f0 --- /dev/null +++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/PromptBuilderService.java @@ -0,0 +1,69 @@ +package org.springframework.ai.openai.samples.helloworld; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.oracle.OracleVectorStore; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Map; + +@Component +public class PromptBuilderService { + + private static final Logger logger = LoggerFactory.getLogger(PromptBuilderService.class); + + private final OracleVectorStore vectorStore; + + public PromptBuilderService(OracleVectorStore vectorStore) { + this.vectorStore = vectorStore; + } + + public Prompt buildPrompt(String message, String contextInstr, int topK) { + String template = """ + DOCUMENTS: + {documents} + + QUESTION: + {question} + + INSTRUCTIONS:"""; + + String defaultInstr = """ + Answer the users question using the DOCUMENTS text above. 
+                Keep your answer grounded in the facts of the DOCUMENTS.
+                If the DOCUMENTS don't contain the facts needed to answer the QUESTION, return:
+                I'm sorry but I don't have enough information to answer.
+                """;
+
+        template += "\n" + defaultInstr;
+
+        List<Document> similarDocuments = vectorStore.similaritySearch(
+                SearchRequest.builder().query(message).topK(topK).build());
+
+        StringBuilder context = createContext(similarDocuments);
+
+        PromptTemplate promptTemplate = new PromptTemplate(template);
+        Prompt prompt = promptTemplate.create(Map.of("documents", context, "question", message));
+
+        logger.info("Generated Prompt:\n{}", prompt.toString());
+
+        return prompt;
+    }
+
+    private StringBuilder createContext(List<Document> documents) {
+        String START = "\n
\n"; + String STOP = "\n
\n"; + + StringBuilder context = new StringBuilder(); + for (Document doc : documents) { + context.append(doc.getId()).append("."); + context.append(START).append(doc.getFormattedContent()).append(STOP); + } + return context; + } +} diff --git a/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/RagService.java b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/RagService.java new file mode 100644 index 0000000..0a71830 --- /dev/null +++ b/src/client/spring_ai/src/main/java/org/springframework/ai/openai/samples/helloworld/RagService.java @@ -0,0 +1,78 @@ +package org.springframework.ai.openai.samples.helloworld; + +import org.springframework.stereotype.Service; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.oracle.OracleVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Lazy; +import org.springframework.jdbc.core.JdbcTemplate; + +@Service +public class RagService { + + private final String modelOpenAI; + private final String modelOllamaAI; + private final ChatClient chatClient; + private final OracleVectorStore vectorStore; + private final EmbeddingModel embeddingModel; + private final String legacyTable; + private final String contextInstr; + private final String searchType; + private final int TOPK; + private JdbcTemplate jdbcTemplate; + + private static final Logger logger = LoggerFactory.getLogger(RagService.class); + private static final int SLEEP = 50; // Wait in streaming between chunks + private static final int STREAM_SIZE = 5; // chars in each chunk + + @Autowired + private PromptBuilderService promptBuilderService; + + RagService( + String modelOpenAI, + String modelOllamaAI, + @Lazy ChatClient chatClient, + EmbeddingModel embeddingModel, + OracleVectorStore vectorStore, + JdbcTemplate jdbcTemplate, + String legacyTable, + String contextInstr, + String searchType, + int TOPK) { + this.modelOpenAI = modelOpenAI; + this.modelOllamaAI = modelOllamaAI; + this.vectorStore = vectorStore; + this.chatClient = chatClient; + this.embeddingModel = embeddingModel; + this.legacyTable = legacyTable; + this.contextInstr = contextInstr; + this.searchType = searchType; + this.TOPK = TOPK; + this.jdbcTemplate = jdbcTemplate; + } + + @Tool(description = "Use this tool to answer any question that may benefit from up-to-date or domain-specific information.") + public String getRag(String question) { + + // Implementation + Prompt prompt = promptBuilderService.buildPrompt(question, contextInstr, TOPK); + logger.info("prompt message: " + prompt.getContents()); + String contentResponse = chatClient.prompt(prompt).call().content(); + + return (contentResponse); + } +} diff --git a/src/client/spring_ai/src/main/resources/application-dev.yml b/src/client/spring_ai/src/main/resources/application-dev.yml index f61f233..a2d3523 100644 --- a/src/client/spring_ai/src/main/resources/application-dev.yml +++ 
b/src/client/spring_ai/src/main/resources/application-dev.yml @@ -1,6 +1,4 @@ server: - servlet: - context-path: /v1 port: 9090 spring: datasource: @@ -8,6 +6,18 @@ spring: username: ${DB_USERNAME} password: ${DB_PASSWORD} ai: + mcp: + server: + name: webmvc-mcp-server + version: 1.0.0 + type: SYNC + request-timeout: 120 + instructions: "Use this tool to answer any question that may benefit from up-to-date or domain-specific information." + capabilities: + tool: true + resource: true + prompt: true + completion: true vectorstore: oracle: distance-type: ${DISTANCE_TYPE} @@ -48,4 +58,3 @@ aims: rag_params: search_type: Similarity top_k: ${TOP_K} - diff --git a/src/server/bootstrap/model_def.py b/src/server/bootstrap/model_def.py index d6982b4..b1afb60 100644 --- a/src/server/bootstrap/model_def.py +++ b/src/server/bootstrap/model_def.py @@ -161,6 +161,17 @@ def main() -> list[Model]: "openai_compat": True, "max_chunk_size": 512, }, + { + # This is intentionally last to line up with docos + "name": "all-minilm", + "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None, + "type": "embed", + "api": "OllamaEmbeddings", + "url": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), + "api_key": "", + "openai_compat": True, + "max_chunk_size": 256, + }, ] # Check for Duplicates
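The `all-minilm` entry added above is only enabled when `ON_PREM_OLLAMA_URL` is set. A minimal shell sketch of preparing a local environment for it (assumptions: a local Ollama install is available, and the model name matches the `all-minilm` entry in the diff):

```shell
# Pull the embedding model into the local Ollama instance (assumes Ollama is installed and running)
ollama pull all-minilm

# Point the server bootstrap at the local Ollama endpoint; if this is unset, the entry stays disabled
export ON_PREM_OLLAMA_URL=http://127.0.0.1:11434
```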