Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.auto.value.AutoValue;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionCallingConfigMode;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.ToolConfig;
import com.google.genai.types.Type;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.audio.Audio;
Expand All @@ -52,6 +53,7 @@
import dev.langchain4j.data.pdf.PdfFile;
import dev.langchain4j.data.video.Video;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
Expand All @@ -65,128 +67,167 @@
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import org.jspecify.annotations.Nullable;

@Experimental
public class LangChain4j extends BaseLlm {
@AutoValue
public abstract class LangChain4j extends BaseLlm {

// Reusable Jackson type token for deserializing tool-call argument JSON into a Map (see toArgs).
private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE =
    new TypeReference<>() {};

private final ChatModel chatModel;
private final StreamingChatModel streamingChatModel;
private final ObjectMapper objectMapper;
// No-arg constructor required by the generated AutoValue subclass. The empty model name passed
// to the superclass is effectively replaced by the modelName() property, since model() is
// overridden to return modelName().
LangChain4j() {
  super("");
}

/** The blocking chat model; null when only streaming is configured. */
@Nullable
public abstract ChatModel chatModel();

/** The streaming chat model; null when only blocking calls are configured. */
@Nullable
public abstract StreamingChatModel streamingChatModel();

/** Jackson mapper used to serialize and deserialize tool-call arguments. */
public abstract ObjectMapper objectMapper();

/** The model name reported to callers via {@link #model()}. */
public abstract String modelName();

/** Optional estimator used to compute usage metadata; null means use provider-reported usage. */
@Nullable
public abstract TokenCountEstimator tokenCountEstimator();

/**
 * Returns the model name. Overrides the superclass value, which this class initializes to the
 * empty string in its no-arg constructor.
 */
@Override
public String model() {
  return modelName();
}

/** Returns a new {@link Builder} preconfigured with a default Jackson {@link ObjectMapper}. */
public static Builder builder() {
  return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper());
}

/**
 * Builder for {@link LangChain4j}.
 *
 * <p>Set at least one of {@code chatModel} / {@code streamingChatModel} along with a
 * {@code modelName} — {@code generateContent} returns an error if the model required for the
 * requested mode is absent. {@code objectMapper} is preset by {@link LangChain4j#builder()}.
 */
@AutoValue.Builder
public abstract static class Builder {
  public abstract Builder chatModel(ChatModel chatModel);

  public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel);

  public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator);

  public abstract Builder objectMapper(ObjectMapper objectMapper);

  public abstract Builder modelName(String modelName);

  public abstract LangChain4j build();
}

/**
 * Creates an instance backed by a blocking {@link ChatModel}, deriving the model name from the
 * model's default request parameters.
 *
 * @throws NullPointerException if {@code chatModel} or its default model name is null
 */
public LangChain4j(ChatModel chatModel) {
  this(
      Objects.requireNonNull(chatModel, "chatModel cannot be null"),
      null,
      null,
      Objects.requireNonNull(
          chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"),
      null);
}

/**
 * Creates an instance backed by a blocking {@link ChatModel} with an explicit model name.
 *
 * @throws NullPointerException if {@code chatModel} or {@code modelName} is null
 */
public LangChain4j(ChatModel chatModel, String modelName) {
  this(
      Objects.requireNonNull(chatModel, "chatModel cannot be null"),
      null,
      null,
      Objects.requireNonNull(modelName, "chat model name cannot be null"),
      null);
}

/**
 * Creates an instance backed by a {@link StreamingChatModel}, deriving the model name from the
 * model's default request parameters.
 *
 * @throws NullPointerException if {@code streamingChatModel} or its default model name is null
 */
public LangChain4j(StreamingChatModel streamingChatModel) {
  this(
      null,
      Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"),
      null,
      Objects.requireNonNull(
          streamingChatModel.defaultRequestParameters().modelName(),
          "streaming chat model name cannot be null"),
      null);
}

/**
 * Creates an instance backed by a {@link StreamingChatModel} with an explicit model name.
 *
 * @throws NullPointerException if {@code streamingChatModel} or {@code modelName} is null
 */
public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
  this(
      null,
      Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"),
      null,
      Objects.requireNonNull(modelName, "streaming chat model name cannot be null"),
      null);
}

/**
 * Creates an instance backed by both a blocking and a streaming model with an explicit model
 * name, supporting both streaming and non-streaming calls.
 *
 * @throws NullPointerException if any argument is null
 */
public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
  this(
      Objects.requireNonNull(chatModel, "chatModel cannot be null"),
      Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"),
      null,
      Objects.requireNonNull(modelName, "model name cannot be null"),
      null);
}

/**
 * Canonical private constructor targeted by the public convenience constructors.
 *
 * <p>NOTE(review): this calls {@code this()} (initializing the superclass with an empty model
 * name) and then builds a second, independent AutoValue instance via {@link #builder()} whose
 * result is immediately discarded — none of the supplied values appear to reach this instance's
 * abstract accessors. Confirm whether direct construction is still a supported path, or whether
 * callers should be migrated to {@link #builder()}.
 */
private LangChain4j(
    ChatModel chatModel,
    StreamingChatModel streamingChatModel,
    ObjectMapper objectMapper,
    String modelName,
    TokenCountEstimator tokenCountEstimator) {
  this();
  // NOTE(review): the built instance below is discarded; see constructor Javadoc.
  LangChain4j.builder()
      .chatModel(chatModel)
      .streamingChatModel(streamingChatModel)
      .objectMapper(objectMapper)
      .modelName(modelName)
      .tokenCountEstimator(tokenCountEstimator)
      .build();
}

/**
 * Generates content using the configured LangChain4j model.
 *
 * <p>When {@code stream} is true the {@link StreamingChatModel} is used: each partial text chunk
 * is emitted as its own {@link LlmResponse}, and any function calls carried on the final message
 * are emitted before completion. Otherwise a single response is produced from the blocking
 * {@link ChatModel}.
 *
 * @param llmRequest the ADK request to translate and send
 * @param stream whether to stream partial responses
 * @return a {@link Flowable} of responses, or an error if the required model is not configured
 */
@Override
public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
  if (stream) {
    if (this.streamingChatModel() == null) {
      return Flowable.error(new IllegalStateException("StreamingChatModel is not configured"));
    }

    ChatRequest chatRequest = toChatRequest(llmRequest);

    return Flowable.create(
        emitter -> {
          streamingChatModel()
              .chat(
                  chatRequest,
                  new StreamingChatResponseHandler() {
                    @Override
                    public void onPartialResponse(String s) {
                      // Each partial text chunk becomes its own response.
                      emitter.onNext(
                          LlmResponse.builder()
                              .content(Content.fromParts(Part.fromText(s)))
                              .build());
                    }

                    @Override
                    public void onCompleteResponse(ChatResponse chatResponse) {
                      // Tool-execution requests only arrive on the completed message; surface
                      // each one as a function-call part before completing the stream.
                      if (chatResponse.aiMessage().hasToolExecutionRequests()) {
                        AiMessage aiMessage = chatResponse.aiMessage();
                        toParts(aiMessage).stream()
                            .map(Part::functionCall)
                            .forEach(
                                functionCall ->
                                    functionCall.ifPresent(
                                        function ->
                                            emitter.onNext(
                                                LlmResponse.builder()
                                                    .content(
                                                        Content.fromParts(
                                                            Part.fromFunctionCall(
                                                                function.name().orElse(""),
                                                                function
                                                                    .args()
                                                                    .orElse(Map.of()))))
                                                    .build())));
                      }
                      emitter.onComplete();
                    }

                    @Override
                    public void onError(Throwable throwable) {
                      emitter.onError(throwable);
                    }
                  });
        },
        BackpressureStrategy.BUFFER);
  } else {
    if (this.chatModel() == null) {
      return Flowable.error(new IllegalStateException("ChatModel is not configured"));
    }

    ChatRequest chatRequest = toChatRequest(llmRequest);
    ChatResponse chatResponse = chatModel().chat(chatRequest);
    LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

    return Flowable.just(llmResponse);
  }
Expand Down Expand Up @@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) {

/**
 * Serializes an object to a JSON string using the configured mapper.
 *
 * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
 */
private String toJson(Object object) {
  try {
    return objectMapper().writeValueAsString(object);
  } catch (JsonProcessingException e) {
    throw new RuntimeException(e);
  }
Expand Down Expand Up @@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
}
}

private LlmResponse toLlmResponse(ChatResponse chatResponse) {
private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
Content content =
Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();

return LlmResponse.builder().content(content).build();
LlmResponse.Builder builder = LlmResponse.builder().content(content);
TokenUsage tokenUsage = chatResponse.tokenUsage();
if (tokenCountEstimator() != null) {
try {
int estimatedInput =
tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages());
int estimatedOutput =
tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text());
int estimatedTotal = estimatedInput + estimatedOutput;
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(estimatedInput)
.candidatesTokenCount(estimatedOutput)
.totalTokenCount(estimatedTotal)
.build());
} catch (Exception e) {
e.printStackTrace();
}
} else if (tokenUsage != null) {
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(tokenUsage.inputTokenCount())
.candidatesTokenCount(tokenUsage.outputTokenCount())
.totalTokenCount(tokenUsage.totalTokenCount())
.build());
}

return builder.build();
}

private List<Part> toParts(AiMessage aiMessage) {
Expand Down Expand Up @@ -546,7 +614,7 @@ private List<Part> toParts(AiMessage aiMessage) {

/**
 * Deserializes a tool-execution request's JSON argument string into a map.
 *
 * @throws RuntimeException wrapping the {@link JsonProcessingException} on malformed JSON
 */
private Map<String, Object> toArgs(ToolExecutionRequest toolExecutionRequest) {
  try {
    return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
  } catch (JsonProcessingException e) {
    throw new RuntimeException(e);
  }
Expand Down
Loading