From ea9223c7a09e79890a0728149e8e5a056403efba Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 23 May 2025 18:32:52 +0200 Subject: [PATCH 01/20] [WIP] Proof of Concept integration with LangChain4j --- core/pom.xml | 12 + .../com/google/adk/models/LangChain4j.java | 288 ++++++++++++++++++ .../google/adk/models/LangChain4jTest.java | 154 ++++++++++ 3 files changed, 454 insertions(+) create mode 100644 core/src/main/java/com/google/adk/models/LangChain4j.java create mode 100644 core/src/test/java/com/google/adk/models/LangChain4jTest.java diff --git a/core/pom.xml b/core/pom.xml index 231263e..93f10b3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -63,6 +63,7 @@ 1.6.0 2.19.0 4.12.0 + 1.0.1 @@ -169,6 +170,17 @@ protobuf-javalite ${protobuf.version} + + dev.langchain4j + langchain4j-core + ${langchain4j.version} + + + dev.langchain4j + langchain4j-anthropic + 1.0.1-beta6 + test + org.java-websocket Java-WebSocket diff --git a/core/src/main/java/com/google/adk/models/LangChain4j.java b/core/src/main/java/com/google/adk/models/LangChain4j.java new file mode 100644 index 0000000..c654805 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/LangChain4j.java @@ -0,0 +1,288 @@ +package com.google.adk.models; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.tools.FunctionTool; +import com.google.genai.types.*; +import com.google.genai.types.Content; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.*; +import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.json.*; +import dev.langchain4j.model.chat.response.ChatResponse; +import io.reactivex.rxjava3.core.Flowable; + +import java.util.*; + +import static com.google.genai.types.Type.Known.OBJECT; + +public class LangChain4j extends BaseLlm { + + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() { + }; + + private final ChatModel chatModel; + private final StreamingChatModel streamingChatModel; + private final ObjectMapper objectMapper; + + public LangChain4j(ChatModel chatModel) { // TODO + super(chatModel.defaultRequestParameters().modelName()); + this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); + this.streamingChatModel = null; + this.objectMapper = new ObjectMapper(); + } + + public LangChain4j(ChatModel chatModel, String modelName) { // TODO + super(modelName); + this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); + this.streamingChatModel = null; + this.objectMapper = new ObjectMapper(); + } + + public LangChain4j(StreamingChatModel streamingChatModel) { // TODO + super(streamingChatModel.defaultRequestParameters().modelName()); + this.chatModel = null; + this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); + this.objectMapper = new ObjectMapper(); + } + + public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { // TODO + super(modelName); + this.chatModel = null; + this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); + this.objectMapper = new ObjectMapper(); + } + + @Override + public Flowable generateContent(LlmRequest llmRequest, boolean stream) { + if (stream) { + if (this.streamingChatModel == null) { + return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); + } + + // TODO + throw new UnsupportedOperationException("Streaming is not supported for LangChain4j models yet."); + } else { + if (this.chatModel == null) { + return Flowable.error(new IllegalStateException("ChatModel is not configured")); + } + + ChatRequest chatRequest = toChatRequest(llmRequest); + ChatResponse chatResponse = chatModel.chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse); + return Flowable.just(llmResponse); + } + } + + private ChatRequest toChatRequest(LlmRequest llmRequest) { + // TODO llmRequest.model() ? + return ChatRequest.builder() + .messages(toMessages(llmRequest)) + .toolSpecifications(toToolSpecifications(llmRequest)) + // TODO + .build(); + } + + private List toMessages(LlmRequest llmRequest) { + List messages = new ArrayList<>(); + messages.addAll(llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList()); + messages.addAll(llmRequest.contents().stream().map(this::toChatMessage).toList()); + return messages; + } + + private ChatMessage toChatMessage(Content content) { + String role = content.role().orElseThrow().toLowerCase(); // TODO + return switch (role) { + case "user" -> toUserOrToolResultMessage(content); + case "model", "assistant" -> toAiMessage(content); + default -> throw new IllegalStateException("Unexpected role: " + role); + }; + } + + private ChatMessage toUserOrToolResultMessage(Content content) { + + List texts = new ArrayList<>(); + ToolExecutionResultMessage toolExecutionResultMessage = null; + + for (Part part : content.parts().orElse(List.of())) { + if (part.text().isPresent()) { + texts.add(part.text().get()); + } else if (part.functionResponse().isPresent()) { + // TODO multiple tool calls? + FunctionResponse functionResponse = part.functionResponse().get(); + toolExecutionResultMessage = ToolExecutionResultMessage.from( + functionResponse.id().orElseThrow(), + functionResponse.name().orElseThrow(), + toJson(functionResponse.response().orElseThrow()) + ); + } else { + throw new IllegalStateException("Either text or functionCall is expected, but was: " + part); + } + } + + if (toolExecutionResultMessage != null) { + return toolExecutionResultMessage; + } else { + return UserMessage.from(String.join("\n", texts)); + } + } + + private AiMessage toAiMessage(Content content) { + + List texts = new ArrayList<>(); + List toolExecutionRequests = new ArrayList<>(); + + content.parts().orElse(List.of()).forEach(part -> { + if (part.text().isPresent()) { + texts.add(part.text().get()); + } else if (part.functionCall().isPresent()) { + FunctionCall functionCall = part.functionCall().get(); + ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id(functionCall.id().orElseThrow()) + .name(functionCall.name().orElseThrow()) + .arguments(toJson(functionCall.args().orElseThrow())) + .build(); + toolExecutionRequests.add(toolExecutionRequest); + } else { + throw new IllegalStateException("Either text or functionCall is expected, but was: " + part); + } + }); + + return AiMessage.builder() + .text(String.join("\n", texts)) + .toolExecutionRequests(toolExecutionRequests) + .build(); + } + + private String toJson(Object object) { + try { + return objectMapper.writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private List toToolSpecifications(LlmRequest llmRequest) { + + List toolSpecifications = new ArrayList<>(); + + llmRequest.tools().values() + .forEach(baseTool -> { + if (baseTool instanceof FunctionTool functionTool) { // TODO MCP tool, LongRunningFunctionTool, etc + if (functionTool.declaration().isPresent()) { + FunctionDeclaration functionDeclaration = functionTool.declaration().get(); + if (functionDeclaration.parameters().isPresent()) { + Schema schema = functionDeclaration.parameters().get(); + ToolSpecification toolSpecification = ToolSpecification.builder() + .name(functionTool.name()) + .description(functionTool.description()) + .parameters(toParameters(schema)) // TODO + .build(); + toolSpecifications.add(toolSpecification); + } + } + } else { + throw new UnsupportedOperationException("LangChain4jLlm does not support tool of type: " + baseTool.getClass()); + } + }); + + return toolSpecifications; // TODO + } + + private JsonObjectSchema toParameters(Schema schema) { + if (schema.type().isPresent() && schema.type().get().knownEnum().equals(OBJECT)) { + + return JsonObjectSchema.builder() + .addProperties(toProperties(schema)) + .required(schema.required().orElse(List.of())) + .build(); // TODO + } else { + throw new UnsupportedOperationException("LangChain4jLlm does not support schema of type: " + schema.type()); + } + } + + private Map toProperties(Schema schema) { + Map properties = schema.properties().orElse(Map.of()); + Map result = new HashMap<>(); + properties.forEach((k, v) -> result.put(k, toJsonSchemaElement(v))); + return result; + } + + private JsonSchemaElement toJsonSchemaElement(Schema schema) { + Type type = schema.type().get(); // TODO + return switch (type.knownEnum()) { + case STRING -> JsonStringSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case NUMBER -> JsonNumberSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case INTEGER -> JsonIntegerSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case BOOLEAN -> JsonBooleanSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case ARRAY -> JsonArraySchema.builder() + .description(schema.description().orElse(null)) + .items(toJsonSchemaElement(schema.items().orElseThrow())) + .build(); + case OBJECT -> toParameters(schema); + case TYPE_UNSPECIFIED -> + throw new UnsupportedFeatureException("LangChain4jLlm does not support schema of type: " + type); + }; + } + + private LlmResponse toLlmResponse(ChatResponse chatResponse) { + + Content content = Content.builder() + .role("model") + .parts(toParts(chatResponse.aiMessage())) + .build(); + + return LlmResponse.builder() + .content(content) + .build(); + } + + private List toParts(AiMessage aiMessage) { + if (aiMessage.hasToolExecutionRequests()) { + List parts = new ArrayList<>(); + aiMessage.toolExecutionRequests().forEach(toolExecutionRequest -> { + FunctionCall functionCall = FunctionCall.builder() + .id(toolExecutionRequest.id() != null ? toolExecutionRequest.id() : UUID.randomUUID().toString()) + .name(toolExecutionRequest.name()) + .args(toArgs(toolExecutionRequest)) + .build(); + Part part = Part.builder() + .functionCall(functionCall) + .build(); + parts.add(part); + }); + return parts; + } else { + Part part = Part.builder() + .text(aiMessage.text()) + .build(); + return List.of(part); + } + } + + private Map toArgs(ToolExecutionRequest toolExecutionRequest) { + try { + return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public BaseLlmConnection connect(LlmRequest llmRequest) { + throw new UnsupportedOperationException("Live connection is not supported for LangChain4j models."); + } +} \ No newline at end of file diff --git a/core/src/test/java/com/google/adk/models/LangChain4jTest.java b/core/src/test/java/com/google/adk/models/LangChain4jTest.java new file mode 100644 index 0000000..def3af1 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/LangChain4jTest.java @@ -0,0 +1,154 @@ +package com.google.adk.models; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.InMemoryRunner; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.adk.tools.Annotations.Schema; +import com.google.adk.tools.FunctionTool; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import dev.langchain4j.model.anthropic.AnthropicChatModel; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_7_SONNET_20250219; + +public class LangChain4jTest { + @Test + void testSimpleAgent() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LlmAgent agent = LlmAgent.builder() + .name("science-app") + .description("Science teacher agent") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219.toString())) + .instruction(""" + You are a helpful science teacher that explains science concepts + to kids and teenagers. + """) + .build(); + + // when + List events = askAgent(agent, "What is a qubit?"); + + // then + assertEquals(1, events.size()); + + Event firstEvent = events.get(0); + assertTrue(firstEvent.content().isPresent()); + + Content content = firstEvent.content().get(); + System.out.println("Answer: " + content.text()); + assertTrue(content.text().contains("quantum")); + } + + @Test + void testSingleAgentWithTools() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219.toString())) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `getWeather` function. + """) + .tools(FunctionTool.create(LangChain4jTest.class, "getWeather")) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + Event eventOne = events.get(0); + Event eventTwo = events.get(1); + Event eventThree = events.get(2); + + // assert the first event is a function call + Content contentOne = eventOne.content().get(); + assertTrue(contentOne.parts().isPresent()); + List partsOne = contentOne.parts().get(); + assertEquals(1, partsOne.size()); + Optional functionCall = partsOne.get(0).functionCall(); + assertTrue(functionCall.isPresent()); + assertEquals("getWeather", functionCall.get().name().get()); + assertTrue(functionCall.get().args().get().containsKey("city")); + + // assert the second event is a function response + Content contentTwo = eventTwo.content().get(); + assertTrue(contentTwo.parts().isPresent()); + List partsTwo = contentTwo.parts().get(); + assertEquals(1, partsTwo.size()); + Optional functionResponseTwo = partsTwo.get(0).functionResponse(); + assertTrue(functionResponseTwo.isPresent()); + assertTrue(functionResponseTwo.get().response().get().get("city").toString().contains("Paris")); + assertTrue(functionResponseTwo.get().response().get().get("forecast").toString().contains("sunny")); + assertTrue(functionResponseTwo.get().response().get().get("temperature").toString().contains("24")); + + // assert the third event is the final text response + assertTrue(eventThree.finalResponse()); + Content contentThree = eventThree.content().get(); + assertTrue(contentThree.parts().isPresent()); + List partsThree = contentThree.parts().get(); + assertEquals(1, partsThree.size()); + assertTrue(partsThree.get(0).text().get().contains("beautiful")); + } + + private static List askAgent(BaseAgent agent, String... messages) { + ArrayList allEvents = new ArrayList<>(); + + Runner runner = new InMemoryRunner(agent, agent.name()); + Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); + + for (String message : messages) { + Content messageContent = Content.fromParts(Part.fromText(message)); + allEvents.addAll( + runner.runAsync(session, messageContent, RunConfig.builder().build()) + .blockingStream().toList() + ); + } + + return allEvents; + } + + @Schema(description = "Function to get the weather forecast for a given city") + public static Map getWeather( + @Schema(name = "city", description = "The city to get the weather forecast for") + String city) { + return Map.of( + "city", city, + "forecast", "a beautiful and sunny weather", + "temperature", "from 10°C in the morning up to 24°C in the afternoon" + ); + } +} From d705754878b279bd2ec6fa09450743ac5df4d1a5 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 26 May 2025 16:39:01 +0200 Subject: [PATCH 02/20] [WIP] Proof of Concept integration with LangChain4j Moved this code into the new contrib directory --- contrib/pom.xml | 360 ++++++++++++++++++ .../adk/models/langchain4j}/LangChain4j.java | 38 +- .../models/langchain4j}/LangChain4jTest.java | 11 +- core/pom.xml | 11 - 4 files changed, 396 insertions(+), 24 deletions(-) create mode 100644 contrib/pom.xml rename {core/src/main/java/com/google/adk/models => contrib/src/test/java/com/google/adk/models/langchain4j}/LangChain4j.java (89%) rename {core/src/test/java/com/google/adk/models => contrib/src/test/java/com/google/adk/models/langchain4j}/LangChain4jTest.java (97%) diff --git a/contrib/pom.xml b/contrib/pom.xml new file mode 100644 index 0000000..0f774e6 --- /dev/null +++ b/contrib/pom.xml @@ -0,0 +1,360 @@ + + + + + 4.0.0 + com.google.adk + google-adk-contrib + 0.1.0 + jar + Agent Development Kit - Contributions + https://github.com/google/adk-java + + + The Apache License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0 + + + + scm:git:git@github.com/google:adk-java.git + + scm:git:git@github.com/google:adk-java.git + + git@github.com/google:adk-java.git + + + + Google Inc. + http://www.google.com + + + + Third-party contributions, integrations, and plugins for Agent Development Kit. + + + UTF-8 + 17 + ${java.version} + 0.10.0 + 2.38.0 + 1.33.1 + 2.28.0 + 1.0.0 + 1.11.0 + 4.31.0-RC1 + 5.11.4 + 5.17.0 + 1.6.0 + 2.19.0 + 4.12.0 + 1.0.1 + 1.0.1 + 1.0.1-beta6 + 1.0.1-beta6 + + + + + dev.langchain4j + langchain4j-core + ${langchain4j.version} + + + com.google.adk + google-adk + 0.1.0 + + + com.google.genai + google-genai + ${google.genai.version} + + + io.modelcontextprotocol.sdk + mcp + ${mcp-schema.version} + + + + + dev.langchain4j + langchain4j-anthropic + ${langchain4j.anthropic.version} + test + + + dev.langchain4j + langchain4j-open-ai + ${langchain4j.openai.version} + test + + + dev.langchain4j + langchain4j-google-ai-gemini + ${langchain4j.gemini.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test + + + org.junit.vintage + junit-vintage-engine + ${junit.version} + test + + + com.google.truth + truth + 1.4.4 + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + ossrh + Central Repository OSSRH + https://google.oss.sonatype.org/service/local/staging/deploy/maven2/ + + + ossrh + Central Repository OSSRH for snapshots + https://google.oss.sonatype.org/content/repositories/snapshots + + + + + + com.google.cloud.artifactregistry + artifactregistry-maven-wagon + 2.2.0 + + + + + + maven-clean-plugin + 3.1.0 + + + maven-resources-plugin + 3.0.2 + + + maven-compiler-plugin + 3.13.0 + + ${java.version} + ${java.version} + ${maven.compiler.release} + true + + + com.google.auto.value + auto-value + ${auto-value.version} + + + + + + maven-surefire-plugin + 3.5.2 + + + me.fabriciorby + maven-surefire-junit5-tree-reporter + 0.1.0 + + + + plain + + + **/*Test.java + + + + + maven-jar-plugin + 3.0.2 + + + maven-install-plugin + 2.5.2 + + + maven-deploy-plugin + 3.1.1 + + false + + + + maven-site-plugin + 3.7.1 + + + maven-project-info-reports-plugin + 3.0.0 + + + org.apache.maven.plugins + maven-gpg-plugin + 3.2.7 + + + sign-artifacts + verify + + sign + + + + + + org.apache.maven.plugins + maven-source-plugin + 3.3.1 + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.6.3 + + all,-missing + true + ${project.build.directory}/javadoc + Agent Development Kit + ${maven.compiler.release} + UTF-8 + + + + attach-javadocs + + jar + + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.7.0 + true + + ossrh + https://google.oss.sonatype.org/ + false + + + + + + + org.jacoco + jacoco-maven-plugin + 0.8.12 + + + + prepare-agent + + + + *MockitoMock* + *$$EnhancerByMockitoWithCGLIB$$* + *$$FastClassByMockitoWithCGLIB$$* + com/sun/tools/attach/* + sun/util/resources/cldr/provider/* + + + + + report + test + + report + + + + HTML + + + + + + + + + + release + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + attach-javadocs + + jar + + + + + + + + + \ No newline at end of file diff --git a/core/src/main/java/com/google/adk/models/LangChain4j.java b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4j.java similarity index 89% rename from core/src/main/java/com/google/adk/models/LangChain4j.java rename to contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4j.java index c654805..8f3d503 100644 --- a/core/src/main/java/com/google/adk/models/LangChain4j.java +++ b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -1,25 +1,47 @@ -package com.google.adk.models; +package com.google.adk.models.langchain4j; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.models.BaseLlm; +import com.google.adk.models.BaseLlmConnection; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; -import com.google.genai.types.*; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.data.message.*; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.exception.UnsupportedFeatureException; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; -import dev.langchain4j.model.chat.request.json.*; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import io.reactivex.rxjava3.core.Flowable; -import java.util.*; - -import static com.google.genai.types.Type.Known.OBJECT; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; public class LangChain4j extends BaseLlm { @@ -195,7 +217,7 @@ private List toToolSpecifications(LlmRequest llmRequest) { } private JsonObjectSchema toParameters(Schema schema) { - if (schema.type().isPresent() && schema.type().get().knownEnum().equals(OBJECT)) { + if (schema.type().isPresent() && schema.type().get().knownEnum().equals(Type.Known.OBJECT)) { return JsonObjectSchema.builder() .addProperties(toProperties(schema)) diff --git a/core/src/test/java/com/google/adk/models/LangChain4jTest.java b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java similarity index 97% rename from core/src/test/java/com/google/adk/models/LangChain4jTest.java rename to contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index def3af1..29a5104 100644 --- a/core/src/test/java/com/google/adk/models/LangChain4jTest.java +++ b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -1,4 +1,4 @@ -package com.google.adk.models; +package com.google.adk.models.langchain4j; import static org.junit.jupiter.api.Assertions.*; @@ -23,9 +23,10 @@ import java.util.Map; import java.util.Optional; -import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_7_SONNET_20250219; - public class LangChain4jTest { + + static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + @Test void testSimpleAgent() { // given @@ -37,7 +38,7 @@ void testSimpleAgent() { LlmAgent agent = LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219.toString())) + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) .instruction(""" You are a helpful science teacher that explains science concepts to kids and teenagers. @@ -69,7 +70,7 @@ void testSingleAgentWithTools() { BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219.toString())) + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) .instruction(""" You are a friendly assistant. diff --git a/core/pom.xml b/core/pom.xml index 93f10b3..d4afdd4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -170,17 +170,6 @@ protobuf-javalite ${protobuf.version} - - dev.langchain4j - langchain4j-core - ${langchain4j.version} - - - dev.langchain4j - langchain4j-anthropic - 1.0.1-beta6 - test - org.java-websocket Java-WebSocket From 0f265a50bf7561a26c49f33e0b506225dbe7dc4b Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 26 May 2025 16:52:51 +0200 Subject: [PATCH 03/20] [WIP] Proof of Concept integration with LangChain4j Accidently moved the model class in the test directory --- .../java/com/google/adk/models/langchain4j/LangChain4j.java | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename contrib/src/{test => main}/java/com/google/adk/models/langchain4j/LangChain4j.java (100%) diff --git a/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java similarity index 100% rename from contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4j.java rename to contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java From 2b5707e99d029331c1d1fb3c8d71c72cc93b27a0 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 31 May 2025 11:42:04 +0200 Subject: [PATCH 04/20] [WIP] Cover other types of tools --- .../adk/models/langchain4j/LangChain4j.java | 42 +++++++---- .../models/langchain4j/LangChain4jTest.java | 70 +++++++++++++++++-- 2 files changed, 94 insertions(+), 18 deletions(-) diff --git a/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 8f3d503..2be29ed 100644 --- a/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -1,3 +1,18 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.models.langchain4j; import com.fasterxml.jackson.core.JsonProcessingException; @@ -7,7 +22,6 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.tools.FunctionTool; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionDeclaration; @@ -195,21 +209,21 @@ private List toToolSpecifications(LlmRequest llmRequest) { llmRequest.tools().values() .forEach(baseTool -> { - if (baseTool instanceof FunctionTool functionTool) { // TODO MCP tool, LongRunningFunctionTool, etc - if (functionTool.declaration().isPresent()) { - FunctionDeclaration functionDeclaration = functionTool.declaration().get(); - if (functionDeclaration.parameters().isPresent()) { - Schema schema = functionDeclaration.parameters().get(); - ToolSpecification toolSpecification = ToolSpecification.builder() - .name(functionTool.name()) - .description(functionTool.description()) - .parameters(toParameters(schema)) // TODO - .build(); - toolSpecifications.add(toolSpecification); - } + if (baseTool.declaration().isPresent()) { + FunctionDeclaration functionDeclaration = baseTool.declaration().get(); + if (functionDeclaration.parameters().isPresent()) { + Schema schema = functionDeclaration.parameters().get(); + ToolSpecification toolSpecification = ToolSpecification.builder() + .name(baseTool.name()) + .description(baseTool.description()) + .parameters(toParameters(schema)) // TODO + .build(); + toolSpecifications.add(toolSpecification); + } else { + throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { - throw new UnsupportedOperationException("LangChain4jLlm does not support tool of type: " + baseTool.getClass()); + throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); diff --git a/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 29a5104..d90b297 100644 --- a/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -9,6 +9,7 @@ import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; import com.google.adk.tools.Annotations.Schema; import com.google.adk.tools.FunctionTool; import com.google.genai.types.Content; @@ -16,16 +17,27 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import dev.langchain4j.model.anthropic.AnthropicChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; public class LangChain4jTest { - static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; + public static final String GPT_4_O_MINI = "gpt-4o-mini"; + + @BeforeAll + public static void setUp() { + assertNotNull(System.getenv("ANTHROPIC_API_KEY")); + assertNotNull(System.getenv("GOOGLE_API_KEY")); + } @Test void testSimpleAgent() { @@ -112,9 +124,6 @@ void testSingleAgentWithTools() { assertEquals(1, partsTwo.size()); Optional functionResponseTwo = partsTwo.get(0).functionResponse(); assertTrue(functionResponseTwo.isPresent()); - assertTrue(functionResponseTwo.get().response().get().get("city").toString().contains("Paris")); - assertTrue(functionResponseTwo.get().response().get().get("forecast").toString().contains("sunny")); - assertTrue(functionResponseTwo.get().response().get().get("temperature").toString().contains("24")); // assert the third event is the final text response assertTrue(eventThree.finalResponse()); @@ -125,6 +134,59 @@ void testSingleAgentWithTools() { assertTrue(partsThree.get(0).text().get().contains("beautiful")); } + @Test + void testAgentTool() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent weatherAgent = LlmAgent.builder() + .name("weather-agent") + .description("Weather agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to always answer that the weather is sunny and 20°C. + """) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(gptModel, "gpt-3.5-turbo")) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `weather-agent` function. + """) + .tools(AgentTool.create(weatherAgent)) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + assertEquals(1, events.get(0).functionCalls().size()); + assertEquals("weather-agent", events.get(0).functionCalls().get(0).name().get()); + + assertEquals(1, events.get(1).functionResponses().size()); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().toLowerCase().contains("sunny")); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().contains("20")); + + assertTrue(events.get(2).finalResponse()); + assertTrue(events.get(2).content().get().text().contains("sunny")); + assertTrue(events.get(2).content().get().text().contains("20")); + } + private static List askAgent(BaseAgent agent, String... messages) { ArrayList allEvents = new ArrayList<>(); From ca168695c0e5eb89db85f5df5f7cddaf6179ba36 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 31 May 2025 21:03:11 +0200 Subject: [PATCH 05/20] [WIP] Move the LangChain4j integration in its own subdirectory. --- contrib/{ => langchain4j}/pom.xml | 0 .../adk/models/langchain4j/LangChain4j.java | 60 ++++++++++++++----- .../models/langchain4j/LangChain4jTest.java | 0 3 files changed, 45 insertions(+), 15 deletions(-) rename contrib/{ => langchain4j}/pom.xml (100%) rename contrib/{ => langchain4j}/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java (85%) rename contrib/{ => langchain4j}/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java (100%) diff --git a/contrib/pom.xml b/contrib/langchain4j/pom.xml similarity index 100% rename from contrib/pom.xml rename to contrib/langchain4j/pom.xml diff --git a/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java similarity index 85% rename from contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java rename to contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 2be29ed..ce38fcd 100644 --- a/contrib/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -22,13 +22,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.FunctionDeclaration; -import com.google.genai.types.FunctionResponse; -import com.google.genai.types.Part; -import com.google.genai.types.Schema; -import com.google.genai.types.Type; +import com.google.genai.types.*; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; @@ -40,6 +34,7 @@ import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ToolChoice; import dev.langchain4j.model.chat.request.json.JsonArraySchema; import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; @@ -50,12 +45,7 @@ import dev.langchain4j.model.chat.response.ChatResponse; import io.reactivex.rxjava3.core.Flowable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.UUID; +import java.util.*; public class LangChain4j extends BaseLlm { @@ -117,10 +107,50 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre private ChatRequest toChatRequest(LlmRequest llmRequest) { // TODO llmRequest.model() ? - return ChatRequest.builder() + ChatRequest.Builder requestBuilder = ChatRequest.builder(); + if (llmRequest.config().isPresent()) { + GenerateContentConfig generateContentConfig = llmRequest.config().get(); + + generateContentConfig.temperature().ifPresent(temp -> + requestBuilder.temperature(temp.doubleValue())); + generateContentConfig.topP().ifPresent(topP -> + requestBuilder.topP(topP.doubleValue())); + generateContentConfig.topK().ifPresent(topK -> + requestBuilder.topK(topK.intValue())); + generateContentConfig.maxOutputTokens().ifPresent(requestBuilder::maxOutputTokens); + generateContentConfig.stopSequences().ifPresent(requestBuilder::stopSequences); + generateContentConfig.frequencyPenalty().ifPresent(freqPenalty -> + requestBuilder.frequencyPenalty(freqPenalty.doubleValue())); + generateContentConfig.presencePenalty().ifPresent(presPenalty -> + requestBuilder.presencePenalty(presPenalty.doubleValue())); + + if (generateContentConfig.toolConfig().isPresent()) { + ToolConfig toolConfig = generateContentConfig.toolConfig().get(); + toolConfig.functionCallingConfig().ifPresent(functionCallingConfig -> { + functionCallingConfig.mode().ifPresent(functionMode -> { + // TODO + if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.AUTO)) { + requestBuilder.toolChoice(ToolChoice.AUTO); + } else if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.ANY)) { + + } + }); + functionCallingConfig.allowedFunctionNames().ifPresent(allowedFunctionName -> { + // TODO + + }); + }); + toolConfig.retrievalConfig().ifPresent(retrievalConfig -> { + // TODO? It exposes Latitude / Longitude, what to do with this? + + }); + } + } + + return requestBuilder .messages(toMessages(llmRequest)) .toolSpecifications(toToolSpecifications(llmRequest)) - // TODO + // TODO? .build(); } diff --git a/contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java similarity index 100% rename from contrib/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java rename to contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java From de547a0b13b27ba90a86218a2a4d9916b9ea93ea Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 31 May 2025 23:49:57 +0200 Subject: [PATCH 06/20] [WIP] Further hyperparameter handling and function calling mode --- .../adk/models/langchain4j/LangChain4j.java | 49 +++++++++++++------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index ce38fcd..d62c592 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -22,7 +22,16 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.genai.types.*; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionCallingConfigMode; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import com.google.genai.types.ToolConfig; +import com.google.genai.types.Type; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; @@ -45,7 +54,12 @@ import dev.langchain4j.model.chat.response.ChatResponse; import io.reactivex.rxjava3.core.Flowable; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; public class LangChain4j extends BaseLlm { @@ -91,7 +105,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } - // TODO + // TODO implement streaming throw new UnsupportedOperationException("Streaming is not supported for LangChain4j models yet."); } else { if (this.chatModel == null) { @@ -108,6 +122,10 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre private ChatRequest toChatRequest(LlmRequest llmRequest) { // TODO llmRequest.model() ? ChatRequest.Builder requestBuilder = ChatRequest.builder(); + + List toolSpecifications = toToolSpecifications(llmRequest); + requestBuilder.toolSpecifications(toolSpecifications); + if (llmRequest.config().isPresent()) { GenerateContentConfig generateContentConfig = llmRequest.config().get(); @@ -132,24 +150,28 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) { if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.AUTO)) { requestBuilder.toolChoice(ToolChoice.AUTO); } else if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.ANY)) { - + // TODO check if it's the correct mapping + requestBuilder.toolChoice(ToolChoice.REQUIRED); + functionCallingConfig.allowedFunctionNames().ifPresent(allowedFunctionNames -> { + requestBuilder.toolSpecifications( + toolSpecifications.stream() + .filter(toolSpecification -> + allowedFunctionNames.contains(toolSpecification.name())) + .toList()); + }); + } else if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.NONE)) { + requestBuilder.toolSpecifications(List.of()); } }); - functionCallingConfig.allowedFunctionNames().ifPresent(allowedFunctionName -> { - // TODO - - }); }); toolConfig.retrievalConfig().ifPresent(retrievalConfig -> { // TODO? It exposes Latitude / Longitude, what to do with this? - }); } } return requestBuilder .messages(toMessages(llmRequest)) - .toolSpecifications(toToolSpecifications(llmRequest)) // TODO? .build(); } @@ -171,7 +193,6 @@ private ChatMessage toChatMessage(Content content) { } private ChatMessage toUserOrToolResultMessage(Content content) { - List texts = new ArrayList<>(); ToolExecutionResultMessage toolExecutionResultMessage = null; @@ -199,7 +220,6 @@ private ChatMessage toUserOrToolResultMessage(Content content) { } private AiMessage toAiMessage(Content content) { - List texts = new ArrayList<>(); List toolExecutionRequests = new ArrayList<>(); @@ -234,7 +254,6 @@ private String toJson(Object object) { } private List toToolSpecifications(LlmRequest llmRequest) { - List toolSpecifications = new ArrayList<>(); llmRequest.tools().values() @@ -250,9 +269,11 @@ private List toToolSpecifications(LlmRequest llmRequest) { .build(); toolSpecifications.add(toolSpecification); } else { + // TODO exception or something else? throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { + // TODO exception or something else? throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); @@ -262,7 +283,6 @@ private List toToolSpecifications(LlmRequest llmRequest) { private JsonObjectSchema toParameters(Schema schema) { if (schema.type().isPresent() && schema.type().get().knownEnum().equals(Type.Known.OBJECT)) { - return JsonObjectSchema.builder() .addProperties(toProperties(schema)) .required(schema.required().orElse(List.of())) @@ -305,7 +325,6 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } private LlmResponse toLlmResponse(ChatResponse chatResponse) { - Content content = Content.builder() .role("model") .parts(toParts(chatResponse.aiMessage())) From 390102f3704ff1835423142ba2405720ef5dd0a2 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 2 Jun 2025 14:34:52 +0200 Subject: [PATCH 07/20] [WIP] Not all LangChain4j models set their model name as default values (Anthropic doesn't set it, but OpenAI seems to set it) --- contrib/langchain4j/pom.xml | 2 ++ .../adk/models/langchain4j/LangChain4j.java | 12 +++++++---- .../models/langchain4j/LangChain4jTest.java | 20 +++++++++++++++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 0f774e6..a0adddc 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -49,6 +49,8 @@ UTF-8 17 + 17 + 17 ${java.version} 0.10.0 2.38.0 diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index d62c592..9767bda 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -71,28 +71,32 @@ public class LangChain4j extends BaseLlm { private final ObjectMapper objectMapper; public LangChain4j(ChatModel chatModel) { // TODO - super(chatModel.defaultRequestParameters().modelName()); + super(Objects.requireNonNull(chatModel.defaultRequestParameters().modelName(), + "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); this.streamingChatModel = null; this.objectMapper = new ObjectMapper(); } public LangChain4j(ChatModel chatModel, String modelName) { // TODO - super(modelName); + super(Objects.requireNonNull(modelName, + "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); this.streamingChatModel = null; this.objectMapper = new ObjectMapper(); } public LangChain4j(StreamingChatModel streamingChatModel) { // TODO - super(streamingChatModel.defaultRequestParameters().modelName()); + super(Objects.requireNonNull(streamingChatModel.defaultRequestParameters().modelName(), + "streaming chat model name cannot be null")); this.chatModel = null; this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); this.objectMapper = new ObjectMapper(); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { // TODO - super(modelName); + super(Objects.requireNonNull(modelName, + "streaming chat model name cannot be null")); this.chatModel = null; this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); this.objectMapper = new ObjectMapper(); diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index d90b297..8e17db7 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; @@ -12,6 +13,7 @@ import com.google.adk.tools.AgentTool; import com.google.adk.tools.Annotations.Schema; import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.ToolContext; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -155,7 +157,7 @@ void testAgentTool() { BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel, "gpt-3.5-turbo")) + .model(new LangChain4j(gptModel)) .instruction(""" You are a friendly assistant. @@ -207,7 +209,21 @@ private static List askAgent(BaseAgent agent, String... messages) { @Schema(description = "Function to get the weather forecast for a given city") public static Map getWeather( @Schema(name = "city", description = "The city to get the weather forecast for") - String city) { + String city, + ToolContext toolContext) { + + System.out.format(""" + Tool context + - function call ID: %s + - invocation ID: %s + - agent name: %s + - state: %s + """, + toolContext.functionCallId(), + toolContext.invocationId(), + toolContext.agentName(), + toolContext.state().entrySet()); + return Map.of( "city", city, "forecast", "a beautiful and sunny weather", From 6dbe1757bc41c7a548e54ed042f7a1d3de55f9b1 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 2 Jun 2025 16:05:35 +0200 Subject: [PATCH 08/20] [WIP] Adding a test case for sub-agents --- .../models/langchain4j/LangChain4jTest.java | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 8e17db7..245b155 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -189,6 +189,106 @@ void testAgentTool() { assertTrue(events.get(2).content().get().text().contains("20")); } + @Test + void testSubAgent() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent greeterAgent = LlmAgent.builder() + .name("greeterAgent") + .description("Friendly agent that greets users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that greets users. + """) + .build(); + + LlmAgent farewellAgent = LlmAgent.builder() + .name("farewellAgent") + .description("Friendly agent that says goodbye to users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that says goodbye to users. + """) + .build(); + + LlmAgent coordinatorAgent = LlmAgent.builder() + .name("coordinator-agent") + .description("Coordinator agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to coordinate 2 agents: + - `greeterAgent`: should reply to messages saying hello, hi, etc. + - `farewellAgent`: should reply to messages saying bye, goodbye, etc. + """) + .subAgents(greeterAgent, farewellAgent) + .build(); + + // when + List hiEvents = askAgent(coordinatorAgent, "Hi"); + List byeEvents = askAgent(coordinatorAgent, "Goodbye"); + + // then + hiEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + byeEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + + // Assertions for hiEvents + assertEquals(3, hiEvents.size()); + + Event hiEvent1 = hiEvents.get(0); + assertTrue(hiEvent1.content().isPresent()); + assertFalse(hiEvent1.functionCalls().isEmpty()); + assertEquals(1, hiEvent1.functionCalls().size()); + FunctionCall hiFunctionCall = hiEvent1.functionCalls().get(0); + assertTrue(hiFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "greeterAgent")), hiFunctionCall.args()); + + Event hiEvent2 = hiEvents.get(1); + assertTrue(hiEvent2.content().isPresent()); + assertFalse(hiEvent2.functionResponses().isEmpty()); + assertEquals(1, hiEvent2.functionResponses().size()); + FunctionResponse hiFunctionResponse = hiEvent2.functionResponses().get(0); + assertTrue(hiFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), hiFunctionResponse.response()); // Empty map for response + + Event hiEvent3 = hiEvents.get(2); + assertTrue(hiEvent3.content().isPresent()); + assertTrue(hiEvent3.content().get().text().toLowerCase().contains("hello")); + assertTrue(hiEvent3.finalResponse()); + + // Assertions for byeEvents + assertEquals(3, byeEvents.size()); + + Event byeEvent1 = byeEvents.get(0); + assertTrue(byeEvent1.content().isPresent()); + assertFalse(byeEvent1.functionCalls().isEmpty()); + assertEquals(1, byeEvent1.functionCalls().size()); + FunctionCall byeFunctionCall = byeEvent1.functionCalls().get(0); + assertTrue(byeFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "farewellAgent")), byeFunctionCall.args()); + + Event byeEvent2 = byeEvents.get(1); + assertTrue(byeEvent2.content().isPresent()); + assertFalse(byeEvent2.functionResponses().isEmpty()); + assertEquals(1, byeEvent2.functionResponses().size()); + FunctionResponse byeFunctionResponse = byeEvent2.functionResponses().get(0); + assertTrue(byeFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), byeFunctionResponse.response()); // Empty map for response + + Event byeEvent3 = byeEvents.get(2); + assertTrue(byeEvent3.content().isPresent()); + assertTrue(byeEvent3.content().get().text().toLowerCase().contains("goodbye")); + assertTrue(byeEvent3.finalResponse()); + } + private static List askAgent(BaseAgent agent, String... messages) { ArrayList allEvents = new ArrayList<>(); From b1b32506dc5003ec8efa355113e93276811f1877 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 2 Jun 2025 23:08:06 +0200 Subject: [PATCH 09/20] [WIP] Streaming model support --- .../adk/models/langchain4j/LangChain4j.java | 28 +++++++++++++++-- .../models/langchain4j/LangChain4jTest.java | 30 ++++++++++++++++++- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 9767bda..fb2a1fa 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.events.Event; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; @@ -52,6 +53,8 @@ import dev.langchain4j.model.chat.request.json.JsonSchemaElement; import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -109,8 +112,29 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } - // TODO implement streaming - throw new UnsupportedOperationException("Streaming is not supported for LangChain4j models yet."); + ChatRequest chatRequest = toChatRequest(llmRequest); + + // TODO is streaming properly implemented? What happens for function calls? + return Flowable.create(emitter -> { + streamingChatModel.chat(chatRequest, new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext(LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); + }, BackpressureStrategy.BUFFER); } else { if (this.chatModel == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 245b155..68d4b01 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -3,10 +3,11 @@ import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; import com.google.adk.sessions.Session; @@ -19,7 +20,9 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import dev.langchain4j.model.anthropic.AnthropicChatModel; +import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; import dev.langchain4j.model.openai.OpenAiChatModel; +import io.reactivex.rxjava3.core.Flowable; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -289,6 +292,31 @@ void testSubAgent() { assertTrue(byeEvent3.finalResponse()); } + @Test + void testSimpleStreamingResponse() { + // given + AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + + // when + Flowable responses = lc4jClaude.generateContent(LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Why is the sky blue?")))) + .build(), true); + + String fullResponse = String.join("", responses.blockingStream() + .map(llmResponse -> llmResponse.content().get().text()) + .toList()); + + // then + assertTrue(fullResponse.contains("blue")); + assertTrue(fullResponse.contains("Rayleigh")); + assertTrue(fullResponse.contains("scatter")); + } + private static List askAgent(BaseAgent agent, String... messages) { ArrayList allEvents = new ArrayList<>(); From 6b100c733fd5e302db666b91a62375ec1660a72c Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 4 Jun 2025 17:04:04 +0200 Subject: [PATCH 10/20] [WIP] Making function calling work in streaming mode --- contrib/langchain4j/pom.xml | 5 + .../adk/models/langchain4j/LangChain4j.java | 45 +++++- .../models/langchain4j/LangChain4jTest.java | 143 ++++++++++++++++-- 3 files changed, 171 insertions(+), 22 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index a0adddc..55d61d3 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -81,6 +81,11 @@ google-adk 0.1.0 + + com.google.adk + google-adk-dev + 0.1.0 + com.google.genai google-genai diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index fb2a1fa..9609f16 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,7 +18,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.adk.events.Event; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; @@ -98,13 +97,19 @@ public LangChain4j(StreamingChatModel streamingChatModel) { // TODO } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { // TODO - super(Objects.requireNonNull(modelName, - "streaming chat model name cannot be null")); + super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); this.chatModel = null; this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); this.objectMapper = new ObjectMapper(); } + public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { + super(Objects.requireNonNull(modelName, "model name cannot be null")); + this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); + this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); + this.objectMapper = new ObjectMapper(); + } + @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { @@ -126,6 +131,20 @@ public void onPartialResponse(String s) { @Override public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach(functionCall -> { + functionCall.ifPresent(function -> { + emitter.onNext(LlmResponse.builder() + .content(Content.fromParts(Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } emitter.onComplete(); } @@ -182,10 +201,10 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) { requestBuilder.toolChoice(ToolChoice.REQUIRED); functionCallingConfig.allowedFunctionNames().ifPresent(allowedFunctionNames -> { requestBuilder.toolSpecifications( - toolSpecifications.stream() - .filter(toolSpecification -> - allowedFunctionNames.contains(toolSpecification.name())) - .toList()); + toolSpecifications.stream() + .filter(toolSpecification -> + allowedFunctionNames.contains(toolSpecification.name())) + .toList()); }); } else if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.NONE)) { requestBuilder.toolSpecifications(List.of()); @@ -223,18 +242,26 @@ private ChatMessage toChatMessage(Content content) { private ChatMessage toUserOrToolResultMessage(Content content) { List texts = new ArrayList<>(); ToolExecutionResultMessage toolExecutionResultMessage = null; + ToolExecutionRequest toolExecutionRequest = null; for (Part part : content.parts().orElse(List.of())) { if (part.text().isPresent()) { texts.add(part.text().get()); } else if (part.functionResponse().isPresent()) { - // TODO multiple tool calls? + // TODO multiple tool calls? should be 1 per part? FunctionResponse functionResponse = part.functionResponse().get(); toolExecutionResultMessage = ToolExecutionResultMessage.from( functionResponse.id().orElseThrow(), functionResponse.name().orElseThrow(), toJson(functionResponse.response().orElseThrow()) ); + } else if (part.functionCall().isPresent()) { + FunctionCall functionCall = part.functionCall().get(); + toolExecutionRequest = ToolExecutionRequest.builder() + .id(functionCall.id().orElseThrow()) + .name(functionCall.name().orElseThrow()) + .arguments(toJson(functionCall.args().orElse(Map.of()))) + .build(); } else { throw new IllegalStateException("Either text or functionCall is expected, but was: " + part); } @@ -242,6 +269,8 @@ private ChatMessage toUserOrToolResultMessage(Content content) { if (toolExecutionResultMessage != null) { return toolExecutionResultMessage; + } else if (toolExecutionRequest != null){ + return AiMessage.aiMessage(toolExecutionRequest); } else { return UserMessage.from(String.join("\n", texts)); } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 68d4b01..efd0304 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -21,7 +21,9 @@ import com.google.genai.types.Part; import dev.langchain4j.model.anthropic.AnthropicChatModel; import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; +import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel; import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import io.reactivex.rxjava3.core.Flowable; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -317,7 +319,128 @@ void testSimpleStreamingResponse() { assertTrue(fullResponse.contains("scatter")); } + @Test + void testStreamingRunConfig() { + // given + OpenAiStreamingChatModel streamingModel = OpenAiStreamingChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + +// AnthropicStreamingChatModel streamingModel = AnthropicStreamingChatModel.builder() +// .apiKey(System.getenv("ANTHROPIC_API_KEY")) +// .modelName(CLAUDE_3_7_SONNET_20250219) +// .build(); + +// GoogleAiGeminiStreamingChatModel streamingModel = GoogleAiGeminiStreamingChatModel.builder() +// .apiKey(System.getenv("GOOGLE_API_KEY")) +// .modelName("gemini-2.0-flash") +// .build(); + + LlmAgent agent = LlmAgent.builder() + .name("streaming-agent") + .description("Friendly science teacher agent") + .instruction(""" + You're a friendly science teacher. + You give concise answers about science topics. + + When someone greets you, respond with "Hello". + If someone asks about the weather, call the `getWeather` function. + """) + .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) +// .model(new LangChain4j(streamingModel, CLAUDE_3_7_SONNET_20250219)) + .tools(FunctionTool.create(LangChain4jTest.class, "getWeather")) + .build(); + + // when + List eventsHi = askAgentStreaming(agent, "Hi"); + String responseToHi = String.join("", eventsHi.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsQubit = askAgentStreaming(agent, "Tell me about qubits"); + String responseToQubit = String.join("", eventsQubit.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsWeather = askAgentStreaming(agent, "What's the weather in Paris?"); + String responseToWeather = String.join("", eventsWeather.stream() + .map(Event::stringifyContent) + .toList()); + + // then + + // Assertions for "Hi" + assertFalse(eventsHi.isEmpty(), "eventsHi should not be empty"); + // Depending on the model and streaming behavior, the number of events can vary. + // If a single "Hello" is expected in one event: + // assertEquals(1, eventsHi.size(), "Expected 1 event for 'Hi'"); + // assertEquals("Hello", responseToHi, "Response to 'Hi' should be 'Hello'"); + // If "Hello" can be streamed in multiple parts: + assertTrue(eventsHi.size() >= 1, "Expected at least 1 event for 'Hi'"); + assertTrue(responseToHi.trim().contains("Hello"), "Response to 'Hi' should be 'Hello'"); + + + // Assertions for "Tell me about qubits" + assertTrue(eventsQubit.size() > 1, "Expected multiple streaming events for 'qubit' question"); + assertTrue(responseToQubit.toLowerCase().contains("qubit"), "Response to 'qubit' should contain 'qubit'"); + assertTrue(responseToQubit.toLowerCase().contains("quantum"), "Response to 'qubit' should contain 'quantum'"); + assertTrue(responseToQubit.toLowerCase().contains("superposition"), "Response to 'qubit' should contain 'superposition'"); + + // Assertions for "What's the weather in Paris?" + assertTrue(eventsWeather.size() > 2, "Expected multiple events for weather question (function call, response, text)"); + + // Check for function call + Optional functionCallEvent = eventsWeather.stream() + .filter(e -> !e.functionCalls().isEmpty()) + .findFirst(); + assertTrue(functionCallEvent.isPresent(), "Should contain a function call event for weather"); + FunctionCall fc = functionCallEvent.get().functionCalls().get(0); + assertEquals(Optional.of("getWeather"), fc.name(), "Function call name should be 'getWeather'"); + assertTrue(fc.args().isPresent() && "Paris".equals(fc.args().get().get("city")), "Function call should be for 'Paris'"); + + // Check for function response + Optional functionResponseEvent = eventsWeather.stream() + .filter(e -> !e.functionResponses().isEmpty()) + .findFirst(); + assertTrue(functionResponseEvent.isPresent(), "Should contain a function response event for weather"); + FunctionResponse fr = functionResponseEvent.get().functionResponses().get(0); + assertEquals(Optional.of("getWeather"), fr.name(), "Function response name should be 'getWeather'"); + assertTrue(fr.response().isPresent()); + Map weatherResponseMap = (Map) fr.response().get(); + assertEquals("Paris", weatherResponseMap.get("city")); + assertTrue(weatherResponseMap.get("forecast").toString().contains("beautiful and sunny")); + + // Check the final aggregated text response + // Consolidate text parts from events that are not function calls or responses + String finalWeatherTextResponse = eventsWeather.stream() + .filter(event -> event.functionCalls().isEmpty() && event.functionResponses().isEmpty() && event.content().isPresent() && event.content().get().text() != null) + .map(event -> event.content().get().text()) + .collect(java.util.stream.Collectors.joining()) + .trim(); + + assertTrue(finalWeatherTextResponse.contains("Paris"), "Final weather response should mention Paris"); + assertTrue(finalWeatherTextResponse.toLowerCase().contains("beautiful and sunny"), "Final weather response should mention 'beautiful and sunny'"); + assertTrue(finalWeatherTextResponse.contains("10"), "Final weather response should mention '10'"); + assertTrue(finalWeatherTextResponse.contains("24"), "Final weather response should mention '24'"); + + // You can also assert on the concatenated `responseToWeather` if it's meant to capture the full interaction text + assertTrue(responseToWeather.contains("Function Call") && responseToWeather.contains("getWeather") && responseToWeather.contains("Paris")); + assertTrue(responseToWeather.contains("Function Response") && responseToWeather.contains("beautiful and sunny weather")); + assertTrue(responseToWeather.contains("sunny")); + assertTrue(responseToWeather.contains("24")); + } + private static List askAgent(BaseAgent agent, String... messages) { + return runLoop(agent, false, messages); + } + + private static List askAgentStreaming(BaseAgent agent, String... messages) { + return runLoop(agent, true, messages); + } + + private static List runLoop(BaseAgent agent, boolean streaming, String... messages) { ArrayList allEvents = new ArrayList<>(); Runner runner = new InMemoryRunner(agent, agent.name()); @@ -326,8 +449,12 @@ private static List askAgent(BaseAgent agent, String... messages) { for (String message : messages) { Content messageContent = Content.fromParts(Part.fromText(message)); allEvents.addAll( - runner.runAsync(session, messageContent, RunConfig.builder().build()) - .blockingStream().toList() + runner.runAsync(session, messageContent, + RunConfig.builder() + .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) + .build()) + .blockingStream() + .toList() ); } @@ -340,18 +467,6 @@ public static Map getWeather( String city, ToolContext toolContext) { - System.out.format(""" - Tool context - - function call ID: %s - - invocation ID: %s - - agent name: %s - - state: %s - """, - toolContext.functionCallId(), - toolContext.invocationId(), - toolContext.agentName(), - toolContext.state().entrySet()); - return Map.of( "city", city, "forecast", "a beautiful and sunny weather", From 47e8795d61a0a1ced078d5a8a82f6a6a560f2da1 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 6 Jun 2025 19:46:42 +0200 Subject: [PATCH 11/20] [WIP] Add ASL header in test class Add Ollama in test scope to test local models running with Ollama --- contrib/langchain4j/pom.xml | 7 +++++++ .../adk/models/langchain4j/LangChain4jTest.java | 15 +++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 55d61d3..5d55ace 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -68,6 +68,7 @@ 1.0.1 1.0.1-beta6 1.0.1-beta6 + 1.0.1-beta6 @@ -116,6 +117,12 @@ ${langchain4j.gemini.version} test + + dev.langchain4j + langchain4j-ollama + ${langchain4j.ollama.version} + test + org.junit.jupiter junit-jupiter-api diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index efd0304..79085ed 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.models.langchain4j; import static org.junit.jupiter.api.Assertions.*; From defb4de2d158ad673fdd1feef94c3fe54f15feca Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 6 Jun 2025 19:48:32 +0200 Subject: [PATCH 12/20] [WIP] Add LangChain4j's @Experimental annotation --- .../java/com/google/adk/models/langchain4j/LangChain4j.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 9609f16..2942dd5 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -32,6 +32,7 @@ import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; import com.google.genai.types.Type; +import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; @@ -63,6 +64,7 @@ import java.util.Objects; import java.util.UUID; +@Experimental public class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() { From ab7ecf4972099c3ffdc5637dea2df7c6963f76cb Mon Sep 17 00:00:00 2001 From: kpavlov <1517853+kpavlov@users.noreply.github.com> Date: Mon, 9 Jun 2025 21:32:32 +0300 Subject: [PATCH 13/20] Cleanup pom --- contrib/langchain4j/pom.xml | 42 +++++++++---------- .../models/langchain4j/LangChain4jTest.java | 1 - 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 5d55ace..1fc2f67 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -49,8 +49,6 @@ UTF-8 17 - 17 - 17 ${java.version} 0.10.0 2.38.0 @@ -65,17 +63,30 @@ 2.19.0 4.12.0 1.0.1 - 1.0.1 - 1.0.1-beta6 - 1.0.1-beta6 - 1.0.1-beta6 + + + + dev.langchain4j + langchain4j-bom + ${langchain4j.version} + pom + import + + + org.junit + junit-bom + ${junit.version} + pom + import + + + dev.langchain4j langchain4j-core - ${langchain4j.version} com.google.adk @@ -102,49 +113,36 @@ dev.langchain4j langchain4j-anthropic - ${langchain4j.anthropic.version} test dev.langchain4j langchain4j-open-ai - ${langchain4j.openai.version} test dev.langchain4j langchain4j-google-ai-gemini - ${langchain4j.gemini.version} test dev.langchain4j langchain4j-ollama - ${langchain4j.ollama.version} test org.junit.jupiter junit-jupiter-api - ${junit.version} test org.junit.jupiter junit-jupiter-params - ${junit.version} test org.junit.jupiter junit-jupiter-engine - ${junit.version} - test - - - org.junit.vintage - junit-vintage-engine - ${junit.version} test @@ -209,7 +207,7 @@ maven-surefire-plugin - 3.5.2 + 3.5.3 me.fabriciorby @@ -371,4 +369,4 @@ - \ No newline at end of file + diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 79085ed..c588a57 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -36,7 +36,6 @@ import com.google.genai.types.Part; import dev.langchain4j.model.anthropic.AnthropicChatModel; import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; -import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import io.reactivex.rxjava3.core.Flowable; From 798610784ec6b5dd45f9e2f459585dcead44f5a0 Mon Sep 17 00:00:00 2001 From: kpavlov <1517853+kpavlov@users.noreply.github.com> Date: Mon, 9 Jun 2025 22:14:21 +0300 Subject: [PATCH 14/20] Add unit test --- contrib/langchain4j/pom.xml | 6 + .../LangChain4jIntegrationTest.java | 494 +++++++++ .../models/langchain4j/LangChain4jTest.java | 979 ++++++++++-------- 3 files changed, 1032 insertions(+), 447 deletions(-) create mode 100644 contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 1fc2f67..c9ca06c 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -151,6 +151,12 @@ 1.4.4 test + + org.assertj + assertj-core + 3.27.3 + test + org.mockito mockito-core diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java new file mode 100644 index 0000000..1406959 --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -0,0 +1,494 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.models.langchain4j; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.runner.InMemoryRunner; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.Annotations.Schema; +import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.ToolContext; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import dev.langchain4j.model.anthropic.AnthropicChatModel; +import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import io.reactivex.rxjava3.core.Flowable; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +class LangChain4jIntegrationTest { + + public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; + public static final String GPT_4_O_MINI = "gpt-4o-mini"; + + @BeforeAll + public static void setUp() { + assertNotNull(System.getenv("ANTHROPIC_API_KEY")); + assertNotNull(System.getenv("GOOGLE_API_KEY")); + } + + @Test + void testSimpleAgent() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LlmAgent agent = LlmAgent.builder() + .name("science-app") + .description("Science teacher agent") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .instruction(""" + You are a helpful science teacher that explains science concepts + to kids and teenagers. + """) + .build(); + + // when + List events = askAgent(agent, "What is a qubit?"); + + // then + assertEquals(1, events.size()); + + Event firstEvent = events.get(0); + assertTrue(firstEvent.content().isPresent()); + + Content content = firstEvent.content().get(); + System.out.println("Answer: " + content.text()); + assertTrue(content.text().contains("quantum")); + } + + @Test + void testSingleAgentWithTools() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `getWeather` function. + """) + .tools(FunctionTool.create(LangChain4jIntegrationTest.class, "getWeather")) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + Event eventOne = events.get(0); + Event eventTwo = events.get(1); + Event eventThree = events.get(2); + + // assert the first event is a function call + Content contentOne = eventOne.content().get(); + assertTrue(contentOne.parts().isPresent()); + List partsOne = contentOne.parts().get(); + assertEquals(1, partsOne.size()); + Optional functionCall = partsOne.get(0).functionCall(); + assertTrue(functionCall.isPresent()); + assertEquals("getWeather", functionCall.get().name().get()); + assertTrue(functionCall.get().args().get().containsKey("city")); + + // assert the second event is a function response + Content contentTwo = eventTwo.content().get(); + assertTrue(contentTwo.parts().isPresent()); + List partsTwo = contentTwo.parts().get(); + assertEquals(1, partsTwo.size()); + Optional functionResponseTwo = partsTwo.get(0).functionResponse(); + assertTrue(functionResponseTwo.isPresent()); + + // assert the third event is the final text response + assertTrue(eventThree.finalResponse()); + Content contentThree = eventThree.content().get(); + assertTrue(contentThree.parts().isPresent()); + List partsThree = contentThree.parts().get(); + assertEquals(1, partsThree.size()); + assertTrue(partsThree.get(0).text().get().contains("beautiful")); + } + + @Test + void testAgentTool() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent weatherAgent = LlmAgent.builder() + .name("weather-agent") + .description("Weather agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to always answer that the weather is sunny and 20°C. + """) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `weather-agent` function. + """) + .tools(AgentTool.create(weatherAgent)) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + assertEquals(1, events.get(0).functionCalls().size()); + assertEquals("weather-agent", events.get(0).functionCalls().get(0).name().get()); + + assertEquals(1, events.get(1).functionResponses().size()); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().toLowerCase().contains("sunny")); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().contains("20")); + + { + final var finalEvent = events.get(2); + assertTrue(finalEvent.finalResponse()); + final var text = finalEvent.content().orElseThrow().text(); + assertTrue(text.contains("sunny")); + assertTrue(text.contains("20")); + } + } + + @Test + void testSubAgent() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent greeterAgent = LlmAgent.builder() + .name("greeterAgent") + .description("Friendly agent that greets users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that greets users. + """) + .build(); + + LlmAgent farewellAgent = LlmAgent.builder() + .name("farewellAgent") + .description("Friendly agent that says goodbye to users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that says goodbye to users. + """) + .build(); + + LlmAgent coordinatorAgent = LlmAgent.builder() + .name("coordinator-agent") + .description("Coordinator agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to coordinate 2 agents: + - `greeterAgent`: should reply to messages saying hello, hi, etc. + - `farewellAgent`: should reply to messages saying bye, goodbye, etc. + """) + .subAgents(greeterAgent, farewellAgent) + .build(); + + // when + List hiEvents = askAgent(coordinatorAgent, "Hi"); + List byeEvents = askAgent(coordinatorAgent, "Goodbye"); + + // then + hiEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + byeEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + + // Assertions for hiEvents + assertEquals(3, hiEvents.size()); + + Event hiEvent1 = hiEvents.get(0); + assertTrue(hiEvent1.content().isPresent()); + assertFalse(hiEvent1.functionCalls().isEmpty()); + assertEquals(1, hiEvent1.functionCalls().size()); + FunctionCall hiFunctionCall = hiEvent1.functionCalls().get(0); + assertTrue(hiFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "greeterAgent")), hiFunctionCall.args()); + + Event hiEvent2 = hiEvents.get(1); + assertTrue(hiEvent2.content().isPresent()); + assertFalse(hiEvent2.functionResponses().isEmpty()); + assertEquals(1, hiEvent2.functionResponses().size()); + FunctionResponse hiFunctionResponse = hiEvent2.functionResponses().get(0); + assertTrue(hiFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), hiFunctionResponse.response()); // Empty map for response + + Event hiEvent3 = hiEvents.get(2); + assertTrue(hiEvent3.content().isPresent()); + assertTrue(hiEvent3.content().get().text().toLowerCase().contains("hello")); + assertTrue(hiEvent3.finalResponse()); + + // Assertions for byeEvents + assertEquals(3, byeEvents.size()); + + Event byeEvent1 = byeEvents.get(0); + assertTrue(byeEvent1.content().isPresent()); + assertFalse(byeEvent1.functionCalls().isEmpty()); + assertEquals(1, byeEvent1.functionCalls().size()); + FunctionCall byeFunctionCall = byeEvent1.functionCalls().get(0); + assertTrue(byeFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "farewellAgent")), byeFunctionCall.args()); + + Event byeEvent2 = byeEvents.get(1); + assertTrue(byeEvent2.content().isPresent()); + assertFalse(byeEvent2.functionResponses().isEmpty()); + assertEquals(1, byeEvent2.functionResponses().size()); + FunctionResponse byeFunctionResponse = byeEvent2.functionResponses().get(0); + assertTrue(byeFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), byeFunctionResponse.response()); // Empty map for response + + Event byeEvent3 = byeEvents.get(2); + assertTrue(byeEvent3.content().isPresent()); + assertTrue(byeEvent3.content().get().text().toLowerCase().contains("goodbye")); + assertTrue(byeEvent3.finalResponse()); + } + + @Test + void testSimpleStreamingResponse() { + // given + AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + + // when + Flowable responses = lc4jClaude.generateContent(LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Why is the sky blue?")))) + .build(), true); + + String fullResponse = String.join("", responses.blockingStream() + .map(llmResponse -> llmResponse.content().get().text()) + .toList()); + + // then + assertTrue(fullResponse.contains("blue")); + assertTrue(fullResponse.contains("Rayleigh")); + assertTrue(fullResponse.contains("scatter")); + } + + @Test + void testStreamingRunConfig() { + // given + OpenAiStreamingChatModel streamingModel = OpenAiStreamingChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + +// AnthropicStreamingChatModel streamingModel = AnthropicStreamingChatModel.builder() +// .apiKey(System.getenv("ANTHROPIC_API_KEY")) +// .modelName(CLAUDE_3_7_SONNET_20250219) +// .build(); + +// GoogleAiGeminiStreamingChatModel streamingModel = GoogleAiGeminiStreamingChatModel.builder() +// .apiKey(System.getenv("GOOGLE_API_KEY")) +// .modelName("gemini-2.0-flash") +// .build(); + + LlmAgent agent = LlmAgent.builder() + .name("streaming-agent") + .description("Friendly science teacher agent") + .instruction(""" + You're a friendly science teacher. + You give concise answers about science topics. + + When someone greets you, respond with "Hello". + If someone asks about the weather, call the `getWeather` function. + """) + .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) +// .model(new LangChain4j(streamingModel, CLAUDE_3_7_SONNET_20250219)) + .tools(FunctionTool.create(LangChain4jIntegrationTest.class, "getWeather")) + .build(); + + // when + List eventsHi = askAgentStreaming(agent, "Hi"); + String responseToHi = String.join("", eventsHi.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsQubit = askAgentStreaming(agent, "Tell me about qubits"); + String responseToQubit = String.join("", eventsQubit.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsWeather = askAgentStreaming(agent, "What's the weather in Paris?"); + String responseToWeather = String.join("", eventsWeather.stream() + .map(Event::stringifyContent) + .toList()); + + // then + + // Assertions for "Hi" + assertFalse(eventsHi.isEmpty(), "eventsHi should not be empty"); + // Depending on the model and streaming behavior, the number of events can vary. + // If a single "Hello" is expected in one event: + // assertEquals(1, eventsHi.size(), "Expected 1 event for 'Hi'"); + // assertEquals("Hello", responseToHi, "Response to 'Hi' should be 'Hello'"); + // If "Hello" can be streamed in multiple parts: + assertTrue(eventsHi.size() >= 1, "Expected at least 1 event for 'Hi'"); + assertTrue(responseToHi.trim().contains("Hello"), "Response to 'Hi' should be 'Hello'"); + + + // Assertions for "Tell me about qubits" + assertTrue(eventsQubit.size() > 1, "Expected multiple streaming events for 'qubit' question"); + assertTrue(responseToQubit.toLowerCase().contains("qubit"), "Response to 'qubit' should contain 'qubit'"); + assertTrue(responseToQubit.toLowerCase().contains("quantum"), "Response to 'qubit' should contain 'quantum'"); + assertTrue(responseToQubit.toLowerCase().contains("superposition"), "Response to 'qubit' should contain 'superposition'"); + + // Assertions for "What's the weather in Paris?" + assertTrue(eventsWeather.size() > 2, "Expected multiple events for weather question (function call, response, text)"); + + // Check for function call + Optional functionCallEvent = eventsWeather.stream() + .filter(e -> !e.functionCalls().isEmpty()) + .findFirst(); + assertTrue(functionCallEvent.isPresent(), "Should contain a function call event for weather"); + FunctionCall fc = functionCallEvent.get().functionCalls().get(0); + assertEquals(Optional.of("getWeather"), fc.name(), "Function call name should be 'getWeather'"); + assertTrue(fc.args().isPresent() && "Paris".equals(fc.args().get().get("city")), "Function call should be for 'Paris'"); + + // Check for function response + Optional functionResponseEvent = eventsWeather.stream() + .filter(e -> !e.functionResponses().isEmpty()) + .findFirst(); + assertTrue(functionResponseEvent.isPresent(), "Should contain a function response event for weather"); + FunctionResponse fr = functionResponseEvent.get().functionResponses().get(0); + assertEquals(Optional.of("getWeather"), fr.name(), "Function response name should be 'getWeather'"); + assertTrue(fr.response().isPresent()); + Map weatherResponseMap = (Map) fr.response().get(); + assertEquals("Paris", weatherResponseMap.get("city")); + assertTrue(weatherResponseMap.get("forecast").toString().contains("beautiful and sunny")); + + // Check the final aggregated text response + // Consolidate text parts from events that are not function calls or responses + String finalWeatherTextResponse = eventsWeather.stream() + .filter(event -> event.functionCalls().isEmpty() && event.functionResponses().isEmpty() && event.content().isPresent() && event.content().get().text() != null) + .map(event -> event.content().get().text()) + .collect(java.util.stream.Collectors.joining()) + .trim(); + + assertTrue(finalWeatherTextResponse.contains("Paris"), "Final weather response should mention Paris"); + assertTrue(finalWeatherTextResponse.toLowerCase().contains("beautiful and sunny"), "Final weather response should mention 'beautiful and sunny'"); + assertTrue(finalWeatherTextResponse.contains("10"), "Final weather response should mention '10'"); + assertTrue(finalWeatherTextResponse.contains("24"), "Final weather response should mention '24'"); + + // You can also assert on the concatenated `responseToWeather` if it's meant to capture the full interaction text + assertTrue(responseToWeather.contains("Function Call") && responseToWeather.contains("getWeather") && responseToWeather.contains("Paris")); + assertTrue(responseToWeather.contains("Function Response") && responseToWeather.contains("beautiful and sunny weather")); + assertTrue(responseToWeather.contains("sunny")); + assertTrue(responseToWeather.contains("24")); + } + + private static List askAgent(BaseAgent agent, String... messages) { + return runLoop(agent, false, messages); + } + + private static List askAgentStreaming(BaseAgent agent, String... messages) { + return runLoop(agent, true, messages); + } + + private static List runLoop(BaseAgent agent, boolean streaming, String... messages) { + ArrayList allEvents = new ArrayList<>(); + + Runner runner = new InMemoryRunner(agent, agent.name()); + Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); + + for (String message : messages) { + Content messageContent = Content.fromParts(Part.fromText(message)); + allEvents.addAll( + runner.runAsync(session, messageContent, + RunConfig.builder() + .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) + .build()) + .blockingStream() + .toList() + ); + } + + return allEvents; + } + + @Schema(description = "Function to get the weather forecast for a given city") + public static Map getWeather( + @Schema(name = "city", description = "The city to get the weather forecast for") + String city, + ToolContext toolContext) { + + return Map.of( + "city", city, + "forecast", "a beautiful and sunny weather", + "temperature", "from 10°C in the morning up to 24°C in the afternoon" + ); + } +} diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index c588a57..f7e6450 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -1,490 +1,575 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ package com.google.adk.models.langchain4j; -import static org.junit.jupiter.api.Assertions.*; - -import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.runner.InMemoryRunner; -import com.google.adk.runner.Runner; -import com.google.adk.sessions.Session; -import com.google.adk.tools.AgentTool; -import com.google.adk.tools.Annotations.Schema; import com.google.adk.tools.FunctionTool; -import com.google.adk.tools.ToolContext; -import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.FunctionResponse; -import com.google.genai.types.Part; -import dev.langchain4j.model.anthropic.AnthropicChatModel; -import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; -import dev.langchain4j.model.openai.OpenAiChatModel; -import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import com.google.genai.types.*; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import io.reactivex.rxjava3.core.Flowable; -import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; -public class LangChain4jTest { +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +class LangChain4jTest { + + private static final String MODEL_NAME = "test-model"; - public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; - public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; - public static final String GPT_4_O_MINI = "gpt-4o-mini"; + private ChatModel chatModel; + private StreamingChatModel streamingChatModel; + private LangChain4j langChain4j; + private LangChain4j streamingLangChain4j; - @BeforeAll - public static void setUp() { - assertNotNull(System.getenv("ANTHROPIC_API_KEY")); - assertNotNull(System.getenv("GOOGLE_API_KEY")); + @BeforeEach + void setUp() { + chatModel = mock(ChatModel.class); + streamingChatModel = mock(StreamingChatModel.class); + + langChain4j = new LangChain4j(chatModel, MODEL_NAME); + streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); } @Test - void testSimpleAgent() { - // given - AnthropicChatModel claudeModel = AnthropicChatModel.builder() - .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) - .build(); - - LlmAgent agent = LlmAgent.builder() - .name("science-app") - .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) - .instruction(""" - You are a helpful science teacher that explains science concepts - to kids and teenagers. - """) - .build(); - - // when - List events = askAgent(agent, "What is a qubit?"); - - // then - assertEquals(1, events.size()); - - Event firstEvent = events.get(0); - assertTrue(firstEvent.content().isPresent()); - - Content content = firstEvent.content().get(); - System.out.println("Answer: " + content.text()); - assertTrue(content.text().contains("quantum")); + @DisplayName("Should generate content using non-streaming chat model") + void testGenerateContentWithChatModel() { + // Given + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) + )) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello, how can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + + // Then + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello, how can I help you?"); + + // Verify the request conversion + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + assertThat(capturedRequest.messages()).hasSize(1); + assertThat(capturedRequest.messages().get(0)).isInstanceOf(UserMessage.class); } @Test - void testSingleAgentWithTools() { - // given - AnthropicChatModel claudeModel = AnthropicChatModel.builder() - .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) - .build(); - - BaseAgent agent = LlmAgent.builder() - .name("friendly-weather-app") - .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) - .instruction(""" - You are a friendly assistant. - - If asked about the weather forecast for a city, - you MUST call the `getWeather` function. - """) - .tools(FunctionTool.create(LangChain4jTest.class, "getWeather")) - .build(); - - // when - List events = askAgent(agent, "What's the weather like in Paris?"); - - // then - assertEquals(3, events.size()); - - events.forEach(event -> { - assertTrue(event.content().isPresent()); - System.out.printf("%nevent: %s%n", event.stringifyContent()); - }); - - Event eventOne = events.get(0); - Event eventTwo = events.get(1); - Event eventThree = events.get(2); - - // assert the first event is a function call - Content contentOne = eventOne.content().get(); - assertTrue(contentOne.parts().isPresent()); - List partsOne = contentOne.parts().get(); - assertEquals(1, partsOne.size()); - Optional functionCall = partsOne.get(0).functionCall(); - assertTrue(functionCall.isPresent()); - assertEquals("getWeather", functionCall.get().name().get()); - assertTrue(functionCall.get().args().get().containsKey("city")); - - // assert the second event is a function response - Content contentTwo = eventTwo.content().get(); - assertTrue(contentTwo.parts().isPresent()); - List partsTwo = contentTwo.parts().get(); - assertEquals(1, partsTwo.size()); - Optional functionResponseTwo = partsTwo.get(0).functionResponse(); - assertTrue(functionResponseTwo.isPresent()); - - // assert the third event is the final text response - assertTrue(eventThree.finalResponse()); - Content contentThree = eventThree.content().get(); - assertTrue(contentThree.parts().isPresent()); - List partsThree = contentThree.parts().get(); - assertEquals(1, partsThree.size()); - assertTrue(partsThree.get(0).text().get().contains("beautiful")); + @DisplayName("Should handle function calls in LLM responses") + void testGenerateContentWithFunctionCall() { + // Given + // Create a mock FunctionTool + final FunctionTool weatherTool = mock(FunctionTool.class); + when(weatherTool.name()).thenReturn("getWeather"); + when(weatherTool.description()).thenReturn("Get weather for a city"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(weatherTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a mock Schema + final Schema schema = mock(Schema.class); + when(functionDeclaration.parameters()).thenReturn(Optional.of(schema)); + + // Create a mock Type + final Type type = mock(Type.class); + when(schema.type()).thenReturn(Optional.of(type)); + when(type.knownEnum()).thenReturn(Type.Known.OBJECT); + + // Create a mock for schema properties and required fields + when(schema.properties()).thenReturn(Optional.of(Map.of("city", schema))); + when(schema.required()).thenReturn(Optional.of(List.of("city"))); + + // Create a real LlmRequest + // We'll use a real LlmRequest but we won't add any tools to it + // This is because we don't know the exact return type of LlmRequest.tools() + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .build(); + + // Mock the AI response with a function call + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + //language=json + .arguments("{\"city\":\"Paris\"}") + .build(); + + final List toolExecutionRequests = List.of(toolExecutionRequest); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(toolExecutionRequests) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + + // Then + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); } @Test - void testAgentTool() { - // given - OpenAiChatModel gptModel = OpenAiChatModel.builder() - .baseUrl("http://langchain4j.dev/demo/openai/v1") - .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) - .modelName(GPT_4_O_MINI) - .build(); - - LlmAgent weatherAgent = LlmAgent.builder() - .name("weather-agent") - .description("Weather agent") - .model(GEMINI_2_0_FLASH) - .instruction(""" - Your role is to always answer that the weather is sunny and 20°C. - """) - .build(); - - BaseAgent agent = LlmAgent.builder() - .name("friendly-weather-app") - .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) - .instruction(""" - You are a friendly assistant. - - If asked about the weather forecast for a city, - you MUST call the `weather-agent` function. - """) - .tools(AgentTool.create(weatherAgent)) - .build(); - - // when - List events = askAgent(agent, "What's the weather like in Paris?"); - - // then - assertEquals(3, events.size()); - events.forEach(event -> { - assertTrue(event.content().isPresent()); - System.out.printf("%nevent: %s%n", event.stringifyContent()); - }); - - assertEquals(1, events.get(0).functionCalls().size()); - assertEquals("weather-agent", events.get(0).functionCalls().get(0).name().get()); - - assertEquals(1, events.get(1).functionResponses().size()); - assertTrue(events.get(1).functionResponses().get(0).response().get().toString().toLowerCase().contains("sunny")); - assertTrue(events.get(1).functionResponses().get(0).response().get().toString().contains("20")); - - assertTrue(events.get(2).finalResponse()); - assertTrue(events.get(2).content().get().text().contains("sunny")); - assertTrue(events.get(2).content().get().text().contains("20")); + @DisplayName("Should handle streaming responses correctly") + void testGenerateContentWithStreamingChatModel() { + // Given + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) + )) + .build(); + + // Create a list to collect the responses + final List responses = new ArrayList<>(); + + // Set up the mock to capture and store the handler + final StreamingChatResponseHandler[] handlerRef = new StreamingChatResponseHandler[1]; + + doAnswer(invocation -> { + // Store the handler for later use + handlerRef[0] = invocation.getArgument(1); + return null; + }).when(streamingChatModel).chat(any(ChatRequest.class), any(StreamingChatResponseHandler.class)); + + // When + final Flowable responseFlowable = streamingLangChain4j.generateContent(llmRequest, true); + + // Subscribe to the flowable to collect responses + final var disposable = responseFlowable.subscribe(responses::add); + + // Verify the streaming model was called + verify(streamingChatModel).chat(any(ChatRequest.class), any(StreamingChatResponseHandler.class)); + + // Get the captured handler + final StreamingChatResponseHandler handler = handlerRef[0]; + + // Simulate streaming responses + handler.onPartialResponse("Hello"); + handler.onPartialResponse(", how"); + handler.onPartialResponse(" can I help"); + handler.onPartialResponse(" you?"); + + // Simulate a function call in the complete response + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + + // Simulate completion with a function call + handler.onCompleteResponse(chatResponse); + + // Then + assertThat(responses).hasSize(5); // 4 partial responses + 1 function call + + // Verify the partial responses + assertThat(responses.get(0).content().orElseThrow().text()).isEqualTo("Hello"); + assertThat(responses.get(1).content().orElseThrow().text()).isEqualTo(", how"); + assertThat(responses.get(2).content().orElseThrow().text()).isEqualTo(" can I help"); + assertThat(responses.get(3).content().orElseThrow().text()).isEqualTo(" you?"); + + // Verify the function call + assertThat(responses.get(4).content().orElseThrow().parts().orElseThrow()).hasSize(1); + assertThat(responses.get(4).content().orElseThrow().parts().orElseThrow().get(0).functionCall()).isPresent(); + final FunctionCall functionCall = responses.get(4).content().orElseThrow().parts().orElseThrow().get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args().orElseThrow()).containsEntry("city", "Paris"); + + disposable.dispose(); } @Test - void testSubAgent() { - // given - OpenAiChatModel gptModel = OpenAiChatModel.builder() - .baseUrl("http://langchain4j.dev/demo/openai/v1") - .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) - .modelName(GPT_4_O_MINI) - .build(); - - LlmAgent greeterAgent = LlmAgent.builder() - .name("greeterAgent") - .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) - .instruction(""" - You are a friendly that greets users. - """) - .build(); - - LlmAgent farewellAgent = LlmAgent.builder() - .name("farewellAgent") - .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) - .instruction(""" - You are a friendly that says goodbye to users. - """) - .build(); - - LlmAgent coordinatorAgent = LlmAgent.builder() - .name("coordinator-agent") - .description("Coordinator agent") - .model(GEMINI_2_0_FLASH) - .instruction(""" - Your role is to coordinate 2 agents: - - `greeterAgent`: should reply to messages saying hello, hi, etc. - - `farewellAgent`: should reply to messages saying bye, goodbye, etc. - """) - .subAgents(greeterAgent, farewellAgent) - .build(); - - // when - List hiEvents = askAgent(coordinatorAgent, "Hi"); - List byeEvents = askAgent(coordinatorAgent, "Goodbye"); - - // then - hiEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); - byeEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); - - // Assertions for hiEvents - assertEquals(3, hiEvents.size()); - - Event hiEvent1 = hiEvents.get(0); - assertTrue(hiEvent1.content().isPresent()); - assertFalse(hiEvent1.functionCalls().isEmpty()); - assertEquals(1, hiEvent1.functionCalls().size()); - FunctionCall hiFunctionCall = hiEvent1.functionCalls().get(0); - assertTrue(hiFunctionCall.id().isPresent()); - assertEquals(Optional.of("transferToAgent"), hiFunctionCall.name()); - assertEquals(Optional.of(Map.of("agentName", "greeterAgent")), hiFunctionCall.args()); - - Event hiEvent2 = hiEvents.get(1); - assertTrue(hiEvent2.content().isPresent()); - assertFalse(hiEvent2.functionResponses().isEmpty()); - assertEquals(1, hiEvent2.functionResponses().size()); - FunctionResponse hiFunctionResponse = hiEvent2.functionResponses().get(0); - assertTrue(hiFunctionResponse.id().isPresent()); - assertEquals(Optional.of("transferToAgent"), hiFunctionResponse.name()); - assertEquals(Optional.of(Map.of()), hiFunctionResponse.response()); // Empty map for response - - Event hiEvent3 = hiEvents.get(2); - assertTrue(hiEvent3.content().isPresent()); - assertTrue(hiEvent3.content().get().text().toLowerCase().contains("hello")); - assertTrue(hiEvent3.finalResponse()); - - // Assertions for byeEvents - assertEquals(3, byeEvents.size()); - - Event byeEvent1 = byeEvents.get(0); - assertTrue(byeEvent1.content().isPresent()); - assertFalse(byeEvent1.functionCalls().isEmpty()); - assertEquals(1, byeEvent1.functionCalls().size()); - FunctionCall byeFunctionCall = byeEvent1.functionCalls().get(0); - assertTrue(byeFunctionCall.id().isPresent()); - assertEquals(Optional.of("transferToAgent"), byeFunctionCall.name()); - assertEquals(Optional.of(Map.of("agentName", "farewellAgent")), byeFunctionCall.args()); - - Event byeEvent2 = byeEvents.get(1); - assertTrue(byeEvent2.content().isPresent()); - assertFalse(byeEvent2.functionResponses().isEmpty()); - assertEquals(1, byeEvent2.functionResponses().size()); - FunctionResponse byeFunctionResponse = byeEvent2.functionResponses().get(0); - assertTrue(byeFunctionResponse.id().isPresent()); - assertEquals(Optional.of("transferToAgent"), byeFunctionResponse.name()); - assertEquals(Optional.of(Map.of()), byeFunctionResponse.response()); // Empty map for response - - Event byeEvent3 = byeEvents.get(2); - assertTrue(byeEvent3.content().isPresent()); - assertTrue(byeEvent3.content().get().text().toLowerCase().contains("goodbye")); - assertTrue(byeEvent3.finalResponse()); + @DisplayName("Should pass configuration options to LangChain4j") + void testGenerateContentWithConfigOptions() { + // Given + final GenerateContentConfig config = GenerateContentConfig.builder() + .temperature(0.7f) + .topP(0.9f) + .topK(40f) + .maxOutputTokens(100) + .presencePenalty(0.5f) + .build(); + + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) + )) + .config(config) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello, how can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final var llmResponse = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Assert the llmResponse + assertThat(llmResponse).isNotNull(); + assertThat(llmResponse.content()).isPresent(); + assertThat(llmResponse.content().get().text()).isEqualTo("Hello, how can I help you?"); + + // Assert the request configuration + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + assertThat(capturedRequest.temperature()).isCloseTo(0.7, offset(0.001)); + assertThat(capturedRequest.topP()).isCloseTo(0.9, offset(0.001)); + assertThat(capturedRequest.topK()).isEqualTo(40); + assertThat(capturedRequest.maxOutputTokens()).isEqualTo(100); + assertThat(capturedRequest.presencePenalty()).isCloseTo(0.5, offset(0.001)); } @Test - void testSimpleStreamingResponse() { - // given - AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() - .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) - .build(); - - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); - - // when - Flowable responses = lc4jClaude.generateContent(LlmRequest.builder() - .contents(List.of(Content.fromParts(Part.fromText("Why is the sky blue?")))) - .build(), true); - - String fullResponse = String.join("", responses.blockingStream() - .map(llmResponse -> llmResponse.content().get().text()) - .toList()); - - // then - assertTrue(fullResponse.contains("blue")); - assertTrue(fullResponse.contains("Rayleigh")); - assertTrue(fullResponse.contains("scatter")); + @DisplayName("Should throw UnsupportedOperationException when connect is called") + void testConnectThrowsUnsupportedOperationException() { + // Given + final LlmRequest llmRequest = LlmRequest.builder().build(); + + // When/Then + assertThatThrownBy(() -> langChain4j.connect(llmRequest)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("Live connection is not supported for LangChain4j models."); } @Test - void testStreamingRunConfig() { - // given - OpenAiStreamingChatModel streamingModel = OpenAiStreamingChatModel.builder() - .baseUrl("http://langchain4j.dev/demo/openai/v1") - .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) - .modelName(GPT_4_O_MINI) - .build(); - -// AnthropicStreamingChatModel streamingModel = AnthropicStreamingChatModel.builder() -// .apiKey(System.getenv("ANTHROPIC_API_KEY")) -// .modelName(CLAUDE_3_7_SONNET_20250219) -// .build(); - -// GoogleAiGeminiStreamingChatModel streamingModel = GoogleAiGeminiStreamingChatModel.builder() -// .apiKey(System.getenv("GOOGLE_API_KEY")) -// .modelName("gemini-2.0-flash") -// .build(); - - LlmAgent agent = LlmAgent.builder() - .name("streaming-agent") - .description("Friendly science teacher agent") - .instruction(""" - You're a friendly science teacher. - You give concise answers about science topics. - - When someone greets you, respond with "Hello". - If someone asks about the weather, call the `getWeather` function. - """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) -// .model(new LangChain4j(streamingModel, CLAUDE_3_7_SONNET_20250219)) - .tools(FunctionTool.create(LangChain4jTest.class, "getWeather")) - .build(); - - // when - List eventsHi = askAgentStreaming(agent, "Hi"); - String responseToHi = String.join("", eventsHi.stream() - .map(event -> event.content().get().text()) - .toList()); - - List eventsQubit = askAgentStreaming(agent, "Tell me about qubits"); - String responseToQubit = String.join("", eventsQubit.stream() - .map(event -> event.content().get().text()) - .toList()); - - List eventsWeather = askAgentStreaming(agent, "What's the weather in Paris?"); - String responseToWeather = String.join("", eventsWeather.stream() - .map(Event::stringifyContent) - .toList()); - - // then - - // Assertions for "Hi" - assertFalse(eventsHi.isEmpty(), "eventsHi should not be empty"); - // Depending on the model and streaming behavior, the number of events can vary. - // If a single "Hello" is expected in one event: - // assertEquals(1, eventsHi.size(), "Expected 1 event for 'Hi'"); - // assertEquals("Hello", responseToHi, "Response to 'Hi' should be 'Hello'"); - // If "Hello" can be streamed in multiple parts: - assertTrue(eventsHi.size() >= 1, "Expected at least 1 event for 'Hi'"); - assertTrue(responseToHi.trim().contains("Hello"), "Response to 'Hi' should be 'Hello'"); - - - // Assertions for "Tell me about qubits" - assertTrue(eventsQubit.size() > 1, "Expected multiple streaming events for 'qubit' question"); - assertTrue(responseToQubit.toLowerCase().contains("qubit"), "Response to 'qubit' should contain 'qubit'"); - assertTrue(responseToQubit.toLowerCase().contains("quantum"), "Response to 'qubit' should contain 'quantum'"); - assertTrue(responseToQubit.toLowerCase().contains("superposition"), "Response to 'qubit' should contain 'superposition'"); - - // Assertions for "What's the weather in Paris?" - assertTrue(eventsWeather.size() > 2, "Expected multiple events for weather question (function call, response, text)"); - - // Check for function call - Optional functionCallEvent = eventsWeather.stream() - .filter(e -> !e.functionCalls().isEmpty()) - .findFirst(); - assertTrue(functionCallEvent.isPresent(), "Should contain a function call event for weather"); - FunctionCall fc = functionCallEvent.get().functionCalls().get(0); - assertEquals(Optional.of("getWeather"), fc.name(), "Function call name should be 'getWeather'"); - assertTrue(fc.args().isPresent() && "Paris".equals(fc.args().get().get("city")), "Function call should be for 'Paris'"); - - // Check for function response - Optional functionResponseEvent = eventsWeather.stream() - .filter(e -> !e.functionResponses().isEmpty()) - .findFirst(); - assertTrue(functionResponseEvent.isPresent(), "Should contain a function response event for weather"); - FunctionResponse fr = functionResponseEvent.get().functionResponses().get(0); - assertEquals(Optional.of("getWeather"), fr.name(), "Function response name should be 'getWeather'"); - assertTrue(fr.response().isPresent()); - Map weatherResponseMap = (Map) fr.response().get(); - assertEquals("Paris", weatherResponseMap.get("city")); - assertTrue(weatherResponseMap.get("forecast").toString().contains("beautiful and sunny")); - - // Check the final aggregated text response - // Consolidate text parts from events that are not function calls or responses - String finalWeatherTextResponse = eventsWeather.stream() - .filter(event -> event.functionCalls().isEmpty() && event.functionResponses().isEmpty() && event.content().isPresent() && event.content().get().text() != null) - .map(event -> event.content().get().text()) - .collect(java.util.stream.Collectors.joining()) - .trim(); - - assertTrue(finalWeatherTextResponse.contains("Paris"), "Final weather response should mention Paris"); - assertTrue(finalWeatherTextResponse.toLowerCase().contains("beautiful and sunny"), "Final weather response should mention 'beautiful and sunny'"); - assertTrue(finalWeatherTextResponse.contains("10"), "Final weather response should mention '10'"); - assertTrue(finalWeatherTextResponse.contains("24"), "Final weather response should mention '24'"); - - // You can also assert on the concatenated `responseToWeather` if it's meant to capture the full interaction text - assertTrue(responseToWeather.contains("Function Call") && responseToWeather.contains("getWeather") && responseToWeather.contains("Paris")); - assertTrue(responseToWeather.contains("Function Response") && responseToWeather.contains("beautiful and sunny weather")); - assertTrue(responseToWeather.contains("sunny")); - assertTrue(responseToWeather.contains("24")); + @DisplayName("Should handle tool calling in LLM responses") + void testGenerateContentWithToolCalling() { + // Given + // Create a mock ChatResponse with a tool execution request + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // Create a LlmRequest with a user message + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .build(); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains the expected function call + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); + + // Verify the ChatModel was called + verify(chatModel).chat(any(ChatRequest.class)); } - private static List askAgent(BaseAgent agent, String... messages) { - return runLoop(agent, false, messages); + + @Test + @DisplayName("Should set ToolChoice to AUTO when FunctionCallingConfig mode is AUTO") + void testGenerateContentWithAutoToolChoice() { + // Given + // Create a FunctionCallingConfig with mode AUTO + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.AUTO); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("It's sunny in Paris"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("It's sunny in Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool choice is AUTO + assertThat(capturedRequest.toolChoice()).isEqualTo(dev.langchain4j.model.chat.request.ToolChoice.AUTO); } - private static List askAgentStreaming(BaseAgent agent, String... messages) { - return runLoop(agent, true, messages); + @Test + @DisplayName("Should set ToolChoice to REQUIRED when FunctionCallingConfig mode is ANY") + void testGenerateContentWithAnyToolChoice() { + // Given + // Create a FunctionCallingConfig with mode ANY and allowed function names + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.ANY); + when(functionCallingConfig.allowedFunctionNames()).thenReturn(Optional.of(List.of("getWeather"))); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response with a function call + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains the expected function call + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool choice is REQUIRED (mapped from ANY) + assertThat(capturedRequest.toolChoice()).isEqualTo(dev.langchain4j.model.chat.request.ToolChoice.REQUIRED); } - private static List runLoop(BaseAgent agent, boolean streaming, String... messages) { - ArrayList allEvents = new ArrayList<>(); - - Runner runner = new InMemoryRunner(agent, agent.name()); - Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); - - for (String message : messages) { - Content messageContent = Content.fromParts(Part.fromText(message)); - allEvents.addAll( - runner.runAsync(session, messageContent, - RunConfig.builder() - .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) - .build()) - .blockingStream() - .toList() - ); - } - - return allEvents; + @Test + @DisplayName("Should disable tool calling when FunctionCallingConfig mode is NONE") + void testGenerateContentWithNoneToolChoice() { + // Given + // Create a FunctionCallingConfig with mode NONE + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.NONE); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response with text (no function call) + final AiMessage aiMessage = AiMessage.from("It's sunny in Paris"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains text (no function call) + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("It's sunny in Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications are empty + assertThat(capturedRequest.toolSpecifications()).isEmpty(); } - @Schema(description = "Function to get the weather forecast for a given city") - public static Map getWeather( - @Schema(name = "city", description = "The city to get the weather forecast for") - String city, - ToolContext toolContext) { - - return Map.of( - "city", city, - "forecast", "a beautiful and sunny weather", - "temperature", "from 10°C in the morning up to 24°C in the afternoon" - ); + @Test + @DisplayName("Should handle structured responses with JSON schema") + void testGenerateContentWithStructuredResponseJsonSchema() { + // Given + // Create a JSON schema for the structured response + final JsonObjectSchema responseSchema = JsonObjectSchema.builder() + .addProperty("name", JsonStringSchema.builder().build()) + .addProperty("age", JsonStringSchema.builder().build()) + .addProperty("city", JsonStringSchema.builder().build()) + .build(); + + // Create a GenerateContentConfig without responseSchema + final GenerateContentConfig config = GenerateContentConfig.builder() + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Give me information about John Doe")) + )) + .config(config) + .build(); + + // Mock the AI response with structured JSON data + final String jsonResponse = """ + { + "name": "John Doe", + "age": "30", + "city": "New York" + } + """; + final AiMessage aiMessage = AiMessage.from(jsonResponse); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains the expected JSON data + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo(jsonResponse); + + // Verify the request was built correctly + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify the request contains the expected messages + assertThat(capturedRequest.messages()).hasSize(1); + assertThat(capturedRequest.messages().get(0)).isInstanceOf(UserMessage.class); + final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); + assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); } } From 29f8bd0780f6214b0e6ad9df1e8574014563dd45 Mon Sep 17 00:00:00 2001 From: kpavlov <1517853+kpavlov@users.noreply.github.com> Date: Mon, 9 Jun 2025 23:12:22 +0300 Subject: [PATCH 15/20] Run integration test conditionally Use `@EnabledIfEnvironmentVariable` to conditionally skip tests --- .../langchain4j/LangChain4jIntegrationTest.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 1406959..7c2176c 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -39,8 +39,8 @@ import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import io.reactivex.rxjava3.core.Flowable; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import java.util.ArrayList; import java.util.List; @@ -54,13 +54,8 @@ class LangChain4jIntegrationTest { public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; public static final String GPT_4_O_MINI = "gpt-4o-mini"; - @BeforeAll - public static void setUp() { - assertNotNull(System.getenv("ANTHROPIC_API_KEY")); - assertNotNull(System.getenv("GOOGLE_API_KEY")); - } - @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") void testSimpleAgent() { // given AnthropicChatModel claudeModel = AnthropicChatModel.builder() @@ -93,6 +88,7 @@ void testSimpleAgent() { } @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") void testSingleAgentWithTools() { // given AnthropicChatModel claudeModel = AnthropicChatModel.builder() @@ -156,6 +152,7 @@ void testSingleAgentWithTools() { } @Test + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") void testAgentTool() { // given OpenAiChatModel gptModel = OpenAiChatModel.builder() @@ -213,6 +210,8 @@ void testAgentTool() { } @Test + @EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = "\\S+") + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") void testSubAgent() { // given OpenAiChatModel gptModel = OpenAiChatModel.builder() @@ -313,6 +312,7 @@ void testSubAgent() { } @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") void testSimpleStreamingResponse() { // given AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() @@ -338,6 +338,7 @@ void testSimpleStreamingResponse() { } @Test + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") void testStreamingRunConfig() { // given OpenAiStreamingChatModel streamingModel = OpenAiStreamingChatModel.builder() From 60107932d57a97eb663581ee22b313feea3c8c9d Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sun, 15 Jun 2025 12:27:12 +0200 Subject: [PATCH 16/20] [WIP] Improve support for multiple modalities, moved run loop in its own class, externalized tool function to avoid reflection issues --- .../adk/models/langchain4j/LangChain4j.java | 139 +++++++++++++----- .../LangChain4jIntegrationTest.java | 60 ++------ .../models/langchain4j/LangChain4jTest.java | 15 ++ .../adk/models/langchain4j/RunLoop.java | 66 +++++++++ .../adk/models/langchain4j/ToolExample.java | 34 +++++ 5 files changed, 224 insertions(+), 90 deletions(-) create mode 100644 contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java create mode 100644 contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 2942dd5..9528040 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -22,6 +22,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionCallingConfigMode; @@ -35,11 +36,20 @@ import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.audio.Audio; +import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.AudioContent; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ImageContent; +import dev.langchain4j.data.message.PdfFileContent; import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.message.VideoContent; +import dev.langchain4j.data.pdf.PdfFile; +import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; @@ -54,10 +64,12 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import io.grpc.netty.shaded.io.netty.handler.codec.base64.Base64Encoder; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -74,7 +86,7 @@ public class LangChain4j extends BaseLlm { private final StreamingChatModel streamingChatModel; private final ObjectMapper objectMapper; - public LangChain4j(ChatModel chatModel) { // TODO + public LangChain4j(ChatModel chatModel) { super(Objects.requireNonNull(chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); @@ -82,7 +94,7 @@ public LangChain4j(ChatModel chatModel) { // TODO this.objectMapper = new ObjectMapper(); } - public LangChain4j(ChatModel chatModel, String modelName) { // TODO + public LangChain4j(ChatModel chatModel, String modelName) { super(Objects.requireNonNull(modelName, "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); @@ -90,7 +102,7 @@ public LangChain4j(ChatModel chatModel, String modelName) { // TODO this.objectMapper = new ObjectMapper(); } - public LangChain4j(StreamingChatModel streamingChatModel) { // TODO + public LangChain4j(StreamingChatModel streamingChatModel) { super(Objects.requireNonNull(streamingChatModel.defaultRequestParameters().modelName(), "streaming chat model name cannot be null")); this.chatModel = null; @@ -98,7 +110,7 @@ public LangChain4j(StreamingChatModel streamingChatModel) { // TODO this.objectMapper = new ObjectMapper(); } - public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { // TODO + public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); this.chatModel = null; this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); @@ -121,7 +133,6 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre ChatRequest chatRequest = toChatRequest(llmRequest); - // TODO is streaming properly implemented? What happens for function calls? return Flowable.create(emitter -> { streamingChatModel.chat(chatRequest, new StreamingChatResponseHandler() { @Override @@ -164,12 +175,12 @@ public void onError(Throwable throwable) { ChatRequest chatRequest = toChatRequest(llmRequest); ChatResponse chatResponse = chatModel.chat(chatRequest); LlmResponse llmResponse = toLlmResponse(chatResponse); + return Flowable.just(llmResponse); } } private ChatRequest toChatRequest(LlmRequest llmRequest) { - // TODO llmRequest.model() ? ChatRequest.Builder requestBuilder = ChatRequest.builder(); List toolSpecifications = toToolSpecifications(llmRequest); @@ -195,7 +206,6 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) { ToolConfig toolConfig = generateContentConfig.toolConfig().get(); toolConfig.functionCallingConfig().ifPresent(functionCallingConfig -> { functionCallingConfig.mode().ifPresent(functionMode -> { - // TODO if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.AUTO)) { requestBuilder.toolChoice(ToolChoice.AUTO); } else if (functionMode.knownEnum().equals(FunctionCallingConfigMode.Known.ANY)) { @@ -221,7 +231,6 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) { return requestBuilder .messages(toMessages(llmRequest)) - // TODO? .build(); } @@ -233,7 +242,7 @@ private List toMessages(LlmRequest llmRequest) { } private ChatMessage toChatMessage(Content content) { - String role = content.role().orElseThrow().toLowerCase(); // TODO + String role = content.role().orElseThrow().toLowerCase(); return switch (role) { case "user" -> toUserOrToolResultMessage(content); case "model", "assistant" -> toAiMessage(content); @@ -242,15 +251,15 @@ private ChatMessage toChatMessage(Content content) { } private ChatMessage toUserOrToolResultMessage(Content content) { - List texts = new ArrayList<>(); ToolExecutionResultMessage toolExecutionResultMessage = null; ToolExecutionRequest toolExecutionRequest = null; + List lc4jContents = new ArrayList<>(); + for (Part part : content.parts().orElse(List.of())) { if (part.text().isPresent()) { - texts.add(part.text().get()); + lc4jContents.add(TextContent.from(part.text().get())); } else if (part.functionResponse().isPresent()) { - // TODO multiple tool calls? should be 1 per part? FunctionResponse functionResponse = part.functionResponse().get(); toolExecutionResultMessage = ToolExecutionResultMessage.from( functionResponse.id().orElseThrow(), @@ -264,8 +273,56 @@ private ChatMessage toUserOrToolResultMessage(Content content) { .name(functionCall.name().orElseThrow()) .arguments(toJson(functionCall.args().orElse(Map.of()))) .build(); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + + if (blob.mimeType().isEmpty() || blob.data().isEmpty()) { + throw new IllegalArgumentException("Mime type and data required"); + } + + byte[] bytes = blob.data().get(); + String mimeType = blob.mimeType().get(); + + Base64.Encoder encoder = Base64.getEncoder(); + + dev.langchain4j.data.message.Content lc4jContent = null; + + if (mimeType.startsWith("audio/")) { + lc4jContent = AudioContent.from(Audio.builder() + .base64Data(encoder.encodeToString(bytes)) + .mimeType(mimeType) + .build()); + } else if (mimeType.startsWith("video/")) { + lc4jContent = VideoContent.from(Video.builder() + .base64Data(encoder.encodeToString(bytes)) + .mimeType(mimeType) + .build()); + } else if (mimeType.startsWith("image/")) { + lc4jContent = ImageContent.from(Image.builder() + .base64Data(encoder.encodeToString(bytes)) + .mimeType(mimeType) + .build()); + } else if (mimeType.startsWith("application/pdf")) { + lc4jContent = PdfFileContent.from(PdfFile.builder() + .base64Data(encoder.encodeToString(bytes)) + .mimeType(mimeType) + .build()); + } else if (mimeType.startsWith("text/") + || mimeType.equals("application/json") + || mimeType.endsWith("+json") + || mimeType.endsWith("+xml")) { + // TODO are there missing text based mime types? + // TODO should we assume UTF_8? + lc4jContents.add(TextContent.from(new String(bytes, java.nio.charset.StandardCharsets.UTF_8))); + } + + if (lc4jContent != null) { + lc4jContents.add(lc4jContent); + } else { + throw new IllegalArgumentException("Unknown or unhandled mime type: " + mimeType); + } } else { - throw new IllegalStateException("Either text or functionCall is expected, but was: " + part); + throw new IllegalStateException("Text, media or functionCall is expected, but was: " + part); } } @@ -274,7 +331,7 @@ private ChatMessage toUserOrToolResultMessage(Content content) { } else if (toolExecutionRequest != null){ return AiMessage.aiMessage(toolExecutionRequest); } else { - return UserMessage.from(String.join("\n", texts)); + return UserMessage.from(lc4jContents); } } @@ -324,7 +381,7 @@ private List toToolSpecifications(LlmRequest llmRequest) { ToolSpecification toolSpecification = ToolSpecification.builder() .name(baseTool.name()) .description(baseTool.description()) - .parameters(toParameters(schema)) // TODO + .parameters(toParameters(schema)) .build(); toolSpecifications.add(toolSpecification); } else { @@ -337,7 +394,7 @@ private List toToolSpecifications(LlmRequest llmRequest) { } }); - return toolSpecifications; // TODO + return toolSpecifications; } private JsonObjectSchema toParameters(Schema schema) { @@ -345,7 +402,7 @@ private JsonObjectSchema toParameters(Schema schema) { return JsonObjectSchema.builder() .addProperties(toProperties(schema)) .required(schema.required().orElse(List.of())) - .build(); // TODO + .build(); } else { throw new UnsupportedOperationException("LangChain4jLlm does not support schema of type: " + schema.type()); } @@ -359,28 +416,32 @@ private Map toProperties(Schema schema) { } private JsonSchemaElement toJsonSchemaElement(Schema schema) { - Type type = schema.type().get(); // TODO - return switch (type.knownEnum()) { - case STRING -> JsonStringSchema.builder() - .description(schema.description().orElse(null)) - .build(); - case NUMBER -> JsonNumberSchema.builder() - .description(schema.description().orElse(null)) - .build(); - case INTEGER -> JsonIntegerSchema.builder() - .description(schema.description().orElse(null)) - .build(); - case BOOLEAN -> JsonBooleanSchema.builder() - .description(schema.description().orElse(null)) - .build(); - case ARRAY -> JsonArraySchema.builder() - .description(schema.description().orElse(null)) - .items(toJsonSchemaElement(schema.items().orElseThrow())) - .build(); - case OBJECT -> toParameters(schema); - case TYPE_UNSPECIFIED -> - throw new UnsupportedFeatureException("LangChain4jLlm does not support schema of type: " + type); - }; + if (schema != null && schema.type().isPresent()) { + Type type = schema.type().get(); + return switch (type.knownEnum()) { + case STRING -> JsonStringSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case NUMBER -> JsonNumberSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case INTEGER -> JsonIntegerSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case BOOLEAN -> JsonBooleanSchema.builder() + .description(schema.description().orElse(null)) + .build(); + case ARRAY -> JsonArraySchema.builder() + .description(schema.description().orElse(null)) + .items(toJsonSchemaElement(schema.items().orElseThrow())) + .build(); + case OBJECT -> toParameters(schema); + case TYPE_UNSPECIFIED -> + throw new UnsupportedFeatureException("LangChain4jLlm does not support schema of type: " + type); + }; + } else { + throw new IllegalArgumentException("Schema type cannot be null or absent"); + } } private LlmResponse toLlmResponse(ChatResponse chatResponse) { diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 7c2176c..beabd66 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -15,21 +15,17 @@ */ package com.google.adk.models.langchain4j; +import static com.google.adk.models.langchain4j.RunLoop.askAgent; +import static com.google.adk.models.langchain4j.RunLoop.askAgentStreaming; import static org.junit.jupiter.api.Assertions.*; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.runner.InMemoryRunner; -import com.google.adk.runner.Runner; -import com.google.adk.sessions.Session; import com.google.adk.tools.AgentTool; -import com.google.adk.tools.Annotations.Schema; import com.google.adk.tools.FunctionTool; -import com.google.adk.tools.ToolContext; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -42,7 +38,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -106,7 +101,7 @@ void testSingleAgentWithTools() { If asked about the weather forecast for a city, you MUST call the `getWeather` function. """) - .tools(FunctionTool.create(LangChain4jIntegrationTest.class, "getWeather")) + .tools(FunctionTool.create(ToolExample.class, "getWeather")) .build(); // when @@ -125,16 +120,20 @@ void testSingleAgentWithTools() { Event eventThree = events.get(2); // assert the first event is a function call + assertTrue(eventOne.content().isPresent()); Content contentOne = eventOne.content().get(); assertTrue(contentOne.parts().isPresent()); List partsOne = contentOne.parts().get(); assertEquals(1, partsOne.size()); Optional functionCall = partsOne.get(0).functionCall(); assertTrue(functionCall.isPresent()); + assertTrue(functionCall.get().name().isPresent()); assertEquals("getWeather", functionCall.get().name().get()); + assertTrue(functionCall.get().args().isPresent()); assertTrue(functionCall.get().args().get().containsKey("city")); // assert the second event is a function response + assertTrue(eventTwo.content().isPresent()); Content contentTwo = eventTwo.content().get(); assertTrue(contentTwo.parts().isPresent()); List partsTwo = contentTwo.parts().get(); @@ -144,10 +143,12 @@ void testSingleAgentWithTools() { // assert the third event is the final text response assertTrue(eventThree.finalResponse()); + assertTrue(eventThree.content().isPresent()); Content contentThree = eventThree.content().get(); assertTrue(contentThree.parts().isPresent()); List partsThree = contentThree.parts().get(); assertEquals(1, partsThree.size()); + assertTrue(partsThree.get(0).text().isPresent()); assertTrue(partsThree.get(0).text().get().contains("beautiful")); } @@ -397,7 +398,6 @@ void testStreamingRunConfig() { // assertEquals(1, eventsHi.size(), "Expected 1 event for 'Hi'"); // assertEquals("Hello", responseToHi, "Response to 'Hi' should be 'Hello'"); // If "Hello" can be streamed in multiple parts: - assertTrue(eventsHi.size() >= 1, "Expected at least 1 event for 'Hi'"); assertTrue(responseToHi.trim().contains("Hello"), "Response to 'Hi' should be 'Hello'"); @@ -450,46 +450,4 @@ void testStreamingRunConfig() { assertTrue(responseToWeather.contains("sunny")); assertTrue(responseToWeather.contains("24")); } - - private static List askAgent(BaseAgent agent, String... messages) { - return runLoop(agent, false, messages); - } - - private static List askAgentStreaming(BaseAgent agent, String... messages) { - return runLoop(agent, true, messages); - } - - private static List runLoop(BaseAgent agent, boolean streaming, String... messages) { - ArrayList allEvents = new ArrayList<>(); - - Runner runner = new InMemoryRunner(agent, agent.name()); - Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); - - for (String message : messages) { - Content messageContent = Content.fromParts(Part.fromText(message)); - allEvents.addAll( - runner.runAsync(session, messageContent, - RunConfig.builder() - .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) - .build()) - .blockingStream() - .toList() - ); - } - - return allEvents; - } - - @Schema(description = "Function to get the weather forecast for a given city") - public static Map getWeather( - @Schema(name = "city", description = "The city to get the weather forecast for") - String city, - ToolContext toolContext) { - - return Map.of( - "city", city, - "forecast", "a beautiful and sunny weather", - "temperature", "from 10°C in the morning up to 24°C in the afternoon" - ); - } } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index f7e6450..75ec6ac 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.google.adk.models.langchain4j; import com.google.adk.models.LlmRequest; diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java new file mode 100644 index 0000000..bfba6d9 --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.models.langchain4j; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.InMemoryRunner; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.genai.types.Content; +import com.google.genai.types.Part; + +import java.util.ArrayList; +import java.util.List; + +public class RunLoop { + public static List askAgent(BaseAgent agent, Object... messages) { + return runLoop(agent, false, messages); + } + + public static List askAgentStreaming(BaseAgent agent, Object... messages) { + return runLoop(agent, true, messages); + } + + public static List runLoop(BaseAgent agent, boolean streaming, Object... messages) { + ArrayList allEvents = new ArrayList<>(); + + Runner runner = new InMemoryRunner(agent, agent.name()); + Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); + + for (Object message : messages) { + Content messageContent = null; + if (message instanceof String) { + messageContent = Content.fromParts(Part.fromText((String) message)); + } else if (message instanceof Part) { + messageContent = Content.fromParts((Part) message); + } else if (message instanceof Content) { + messageContent = (Content) message; + } + allEvents.addAll( + runner.runAsync(session, messageContent, + RunConfig.builder() + .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) + .build()) + .blockingStream() + .toList() + ); + } + + return allEvents; + } +} diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java new file mode 100644 index 0000000..a6f92b7 --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.models.langchain4j; + +import com.google.adk.tools.Annotations; + +import java.util.Map; + +public class ToolExample { + @Annotations.Schema(description = "Function to get the weather forecast for a given city") + public static Map getWeather( + @Annotations.Schema(name = "city", description = "The city to get the weather forecast for") + String city) { + + return Map.of( + "city", city, + "forecast", "a beautiful and sunny weather", + "temperature", "from 10°C in the morning up to 24°C in the afternoon" + ); + } +} From 7ac470c2954b89ed6ce44d230dadd5c7f28c9061 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 18 Jun 2025 14:05:30 +0200 Subject: [PATCH 17/20] [WIP] Depend on parent pom.xml --- contrib/langchain4j/pom.xml | 18 ++++++++++++------ .../adk/models/langchain4j/LangChain4j.java | 1 - .../LangChain4jIntegrationTest.java | 2 +- pom.xml | 1 + 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index c9ca06c..1bdd55a 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -18,11 +18,17 @@ 4.0.0 - com.google.adk - google-adk-contrib - 0.1.0 + + + com.google.adk + google-adk-parent + 0.1.1-SNAPSHOT + + + google-adk-contrib-langchain4j jar - Agent Development Kit - Contributions + + Agent Development Kit - Contributions - LangChain4j https://github.com/google/adk-java @@ -91,12 +97,12 @@ com.google.adk google-adk - 0.1.0 + 0.1.1-SNAPSHOT com.google.adk google-adk-dev - 0.1.0 + 0.1.1-SNAPSHOT com.google.genai diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 9528040..ba94100 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -64,7 +64,6 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; -import io.grpc.netty.shaded.io.netty.handler.codec.base64.Base64Encoder; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index beabd66..d7f1007 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -370,7 +370,7 @@ void testStreamingRunConfig() { """) .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) // .model(new LangChain4j(streamingModel, CLAUDE_3_7_SONNET_20250219)) - .tools(FunctionTool.create(LangChain4jIntegrationTest.class, "getWeather")) + .tools(FunctionTool.create(ToolExample.class, "getWeather")) .build(); // when diff --git a/pom.xml b/pom.xml index 92bced1..4e2e737 100644 --- a/pom.xml +++ b/pom.xml @@ -28,6 +28,7 @@ core dev + contrib/langchain4j From 069ea36f2f803dc1809187527ce15136f9c2ea66 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 18 Jun 2025 14:59:52 +0200 Subject: [PATCH 18/20] [WIP] Don't use the demo key for OpenAI --- .../adk/models/langchain4j/LangChain4jIntegrationTest.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index d7f1007..09b2406 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -157,8 +157,7 @@ void testSingleAgentWithTools() { void testAgentTool() { // given OpenAiChatModel gptModel = OpenAiChatModel.builder() - .baseUrl("http://langchain4j.dev/demo/openai/v1") - .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .apiKey(System.getenv("OPENAI_API_KEY")) .modelName(GPT_4_O_MINI) .build(); From 85f52efba5fc2563cab211e4f583766f675d9c58 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 18 Jun 2025 17:14:43 +0200 Subject: [PATCH 19/20] [WIP] Use ${project.version} in pom.xml --- contrib/langchain4j/pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 1bdd55a..e58f9f3 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -97,12 +97,12 @@ com.google.adk google-adk - 0.1.1-SNAPSHOT + ${project.version} com.google.adk google-adk-dev - 0.1.1-SNAPSHOT + ${project.version} com.google.genai From 6aeb27e7025e2d631cfa555e9e3aa5c2519c489e Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 18 Jun 2025 17:49:31 +0200 Subject: [PATCH 20/20] [WIP] Removing langchain4j mention from core --- core/pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/core/pom.xml b/core/pom.xml index 91e3139..05fc5d9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -41,7 +41,6 @@ 1.6.0 2.19.0 4.12.0 - 1.0.1