Skip to content

Commit ec9dd00

Browse files
feat: [vertexai] add GenerateContentConfig to generateContentStream method (#10424)
PiperOrigin-RevId: 609300701 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent 04e9574 commit ec9dd00

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

+66
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,21 @@ public ResponseStream<GenerateContentResponse> generateContentStream(String text
646646
return generateContentStream(text, null, null);
647647
}
648648

649+
/**
650+
* Generate content with streaming support from generative model given a text and configs.
651+
*
652+
* @param text a text message to send to the generative model
653+
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
654+
* generate content api call
655+
* @return a {@link ResponseStream} that contains a streaming of {@link
656+
* com.google.cloud.vertexai.api.GenerateContentResponse}
657+
* @throws IOException if an I/O error occurs while making the API call
658+
*/
659+
public ResponseStream<GenerateContentResponse> generateContentStream(
660+
String text, GenerateContentConfig config) throws IOException {
661+
return generateContentStream(ContentMaker.fromString(text), config);
662+
}
663+
649664
/**
650665
* Generate content with streaming support from generative model given a text and generation
651666
* config.
@@ -716,6 +731,22 @@ public ResponseStream<GenerateContentResponse> generateContentStream(Content con
716731
return generateContentStream(content, null, null);
717732
}
718733

734+
/**
735+
* Generate content with streaming support from generative model given a single content and
736+
* configs.
737+
*
738+
* @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model
739+
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
740+
* generate content api call
741+
* @return a {@link ResponseStream} that contains a streaming of {@link
742+
* com.google.cloud.vertexai.api.GenerateContentResponse}
743+
* @throws IOException if an I/O error occurs while making the API call
744+
*/
745+
public ResponseStream<GenerateContentResponse> generateContentStream(
746+
Content content, GenerateContentConfig config) throws IOException {
747+
return generateContentStream(Arrays.asList(content), config);
748+
}
749+
719750
/**
720751
* Generate content with streaming support from generative model given a single Content and
721752
* generation config.
@@ -856,6 +887,41 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
856887
return generateContentStream(requestBuilder);
857888
}
858889

890+
/**
891+
* Generate content with streaming support from generative model given a list of contents and
892+
* configs.
893+
*
894+
* @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the
895+
* generative model
896+
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
897+
* generate content api call
898+
* @return a {@link ResponseStream} that contains a streaming of {@link
899+
* com.google.cloud.vertexai.api.GenerateContentResponse}
900+
* @throws IOException if an I/O error occurs while making the API call
901+
*/
902+
public ResponseStream<GenerateContentResponse> generateContentStream(
903+
List<Content> contents, GenerateContentConfig config) throws IOException {
904+
GenerateContentRequest.Builder requestBuilder =
905+
GenerateContentRequest.newBuilder().addAllContents(contents);
906+
if (config.getGenerationConfig() != null) {
907+
requestBuilder.setGenerationConfig(config.getGenerationConfig());
908+
} else if (this.generationConfig != null) {
909+
requestBuilder.setGenerationConfig(this.generationConfig);
910+
}
911+
if (config.getSafetySettings().isEmpty() == false) {
912+
requestBuilder.addAllSafetySettings(config.getSafetySettings());
913+
} else if (this.safetySettings != null) {
914+
requestBuilder.addAllSafetySettings(this.safetySettings);
915+
}
916+
if (config.getTools().isEmpty() == false) {
917+
requestBuilder.addAllTools(config.getTools());
918+
} else if (this.tools != null) {
919+
requestBuilder.addAllTools(this.tools);
920+
}
921+
922+
return generateContentStream(requestBuilder);
923+
}
924+
859925
/**
860926
* A base generateContentStream method that will be used internally.
861927
*

‎java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

+31
Original file line numberDiff line numberDiff line change
@@ -639,4 +639,35 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
639639
verify(mockServerStreamCallable).call(request.capture());
640640
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
641641
}
642+
643+
@Test
644+
public void testGenerateContentStreamwithGenerateContentConfig() throws Exception {
645+
model = new GenerativeModel(MODEL_NAME, vertexAi);
646+
GenerateContentConfig config =
647+
GenerateContentConfig.newBuilder()
648+
.setGenerationConfig(GENERATION_CONFIG)
649+
.setSafetySettings(safetySettings)
650+
.setTools(tools)
651+
.build();
652+
653+
Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
654+
field.setAccessible(true);
655+
field.set(vertexAi, mockPredictionServiceClient);
656+
657+
when(mockPredictionServiceClient.streamGenerateContentCallable())
658+
.thenReturn(mockServerStreamCallable);
659+
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
660+
.thenReturn(mockServerStream);
661+
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);
662+
663+
ResponseStream unused = model.generateContentStream(TEXT, config);
664+
665+
ArgumentCaptor<GenerateContentRequest> request =
666+
ArgumentCaptor.forClass(GenerateContentRequest.class);
667+
verify(mockServerStreamCallable).call(request.capture());
668+
669+
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
670+
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
671+
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
672+
}
642673
}

0 commit comments

Comments
 (0)