Skip to content

Commit 14f9825

Browse files
feat: [vertexai] infer location and project when user doesn't specify them. (#10868)
PiperOrigin-RevId: 635997756 Co-authored-by: Yvonne Yu <yyyu@google.com>
1 parent 3e7752f commit 14f9825

File tree

7 files changed

+435
-56
lines changed

7 files changed

+435
-56
lines changed

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
public final class Constants {
2121
// Constants for VertexAI class
2222
public static final String USER_AGENT_HEADER = "model-builder";
23+
static final String DEFAULT_LOCATION = "us-central1";
24+
static final String GOOGLE_CLOUD_REGION = "GOOGLE_CLOUD_REGION";
25+
static final String CLOUD_ML_REGION = "CLOUD_ML_REGION";
26+
static final String GOOGLE_CLOUD_PROJECT = "GOOGLE_CLOUD_PROJECT";
2327

2428
private Constants() {}
2529
}

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

+59-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.google.api.gax.rpc.FixedHeaderProvider;
2828
import com.google.api.gax.rpc.HeaderProvider;
2929
import com.google.auth.Credentials;
30+
import com.google.auth.oauth2.GoogleCredentials;
3031
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
3132
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
3233
import com.google.cloud.vertexai.api.PredictionServiceClient;
@@ -67,6 +68,11 @@ public class VertexAI implements AutoCloseable {
6768
private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
6869
private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;
6970

71+
@InternalApi
72+
static Optional<String> getEnvironmentVariable(String envKey) {
73+
return Optional.ofNullable(System.getenv(envKey));
74+
}
75+
7076
/**
7177
* Constructs a VertexAI instance.
7278
*
@@ -85,6 +91,29 @@ public VertexAI(String projectId, String location) {
8591
/* llmClientSupplierOpt= */ Optional.empty());
8692
}
8793

94+
/**
95+
* Constructs a VertexAI instance.
96+
*
97+
* <p><b>Note:</b> SDK infers location from runtime environment first. If there is no location
98+
* inferred from runtime environment, SDK will default location to {@code us-central1}.
99+
*
100+
* <p>SDK will infer projectId from runtime environment first, then from GoogleCredentials.
101+
*
102+
* @throws java.lang.IllegalArgumentException If no projectId can be inferred from either the
103+
* runtime environment or GoogleCredentials
104+
*/
105+
public VertexAI() {
106+
this(
107+
null,
108+
null,
109+
Transport.GRPC,
110+
ImmutableList.of(),
111+
/* credentials= */ Optional.empty(),
112+
/* apiEndpoint= */ Optional.empty(),
113+
/* predictionClientSupplierOpt= */ Optional.empty(),
114+
/* llmClientSupplierOpt= */ Optional.empty());
115+
}
116+
88117
private VertexAI(
89118
String projectId,
90119
String location,
@@ -98,12 +127,8 @@ private VertexAI(
98127
throw new IllegalArgumentException(
99128
"At most one of Credentials and scopes should be specified.");
100129
}
101-
checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty");
102-
checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty");
103130
checkNotNull(transport, "transport can't be null");
104-
105-
this.projectId = projectId;
106-
this.location = location;
131+
this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location;
107132
this.transport = transport;
108133

109134
if (credentials.isPresent()) {
@@ -118,13 +143,15 @@ private VertexAI(
118143
.build();
119144
}
120145

146+
this.projectId = Strings.isNullOrEmpty(projectId) ? inferProjectId() : projectId;
121147
this.predictionClientSupplier =
122148
Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient));
123149

124150
this.llmClientSupplier =
125151
Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient));
126152

127-
this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location));
153+
this.apiEndpoint =
154+
apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", this.location));
128155
}
129156

130157
/** Builder for {@link VertexAI}. */
@@ -141,8 +168,6 @@ public static class Builder {
141168
private Supplier<LlmUtilityServiceClient> llmClientSupplier;
142169

143170
public VertexAI build() {
144-
checkNotNull(projectId, "projectId must be set.");
145-
checkNotNull(location, "location must be set.");
146171

147172
return new VertexAI(
148173
projectId,
@@ -339,6 +364,32 @@ private LlmUtilityServiceClient newLlmUtilityClient() {
339364
}
340365
}
341366

367+
private String inferProjectId() {
368+
final String projectNotFoundErrorMessage =
369+
("Unable to infer your project. Please provide a project Id by one of the following:"
370+
+ "\n- Passing a constructor argument by using new VertexAI(String projectId, String"
371+
+ " location)"
372+
+ "\n- Setting project using 'gcloud config set project my-project'");
373+
final Optional<String> projectIdOptional =
374+
getEnvironmentVariable(Constants.GOOGLE_CLOUD_PROJECT);
375+
if (projectIdOptional.isPresent()) {
376+
return projectIdOptional.get();
377+
}
378+
try {
379+
return Optional.ofNullable((GoogleCredentials) this.credentialsProvider.getCredentials())
380+
.map((credentials) -> credentials.getQuotaProjectId())
381+
.orElseThrow(() -> new IllegalArgumentException(projectNotFoundErrorMessage));
382+
} catch (IOException e) {
383+
throw new IllegalArgumentException(projectNotFoundErrorMessage, e);
384+
}
385+
}
386+
387+
private String inferLocation() {
388+
return getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION)
389+
.orElse(
390+
getEnvironmentVariable(Constants.CLOUD_ML_REGION).orElse(Constants.DEFAULT_LOCATION));
391+
}
392+
342393
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
343394
LlmUtilityServiceSettings.Builder settingsBuilder;
344395
if (transport == Transport.REST) {

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java

+72-36
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,24 @@
2929
import com.google.cloud.vertexai.api.GenerationConfig;
3030
import com.google.cloud.vertexai.api.SafetySetting;
3131
import com.google.cloud.vertexai.api.Tool;
32+
import com.google.cloud.vertexai.api.ToolConfig;
3233
import com.google.common.collect.ImmutableList;
3334
import java.io.IOException;
3435
import java.util.ArrayList;
3536
import java.util.List;
3637
import java.util.Optional;
3738

38-
/** Represents a conversation between the user and the model */
39+
/**
40+
* Represents a conversation between the user and the model.
41+
*
42+
* <p>Note: this class is NOT thread-safe.
43+
*/
3944
public final class ChatSession {
4045
private final GenerativeModel model;
4146
private final Optional<ChatSession> rootChatSession;
4247
private final Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder;
43-
private List<Content> history = new ArrayList<>();
44-
private int previousHistorySize = 0;
48+
private List<Content> history;
49+
private int previousHistorySize;
4550
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
4651
private Optional<GenerateContentResponse> currentResponse;
4752

@@ -50,14 +55,17 @@ public final class ChatSession {
5055
* GenerationConfig) inherits from the model.
5156
*/
5257
public ChatSession(GenerativeModel model) {
53-
this(model, Optional.empty(), Optional.empty());
58+
this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty());
5459
}
5560

5661
/**
5762
* Creates a new chat session given a GenerativeModel instance and a root chat session.
5863
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
5964
*
6065
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
66+
* @param history a list of {@link Content} containing interleaving conversation between "user"
67+
* and "model".
68+
* @param previousHistorySize the size of the previous history.
6169
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
6270
* chat session will be merged to the root chat session.
6371
* @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance
@@ -66,10 +74,14 @@ public ChatSession(GenerativeModel model) {
6674
*/
6775
private ChatSession(
6876
GenerativeModel model,
77+
List<Content> history,
78+
int previousHistorySize,
6979
Optional<ChatSession> rootChatSession,
7080
Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder) {
7181
checkNotNull(model, "model should not be null");
7282
this.model = model;
83+
this.history = history;
84+
this.previousHistorySize = previousHistorySize;
7385
this.rootChatSession = rootChatSession;
7486
this.automaticFunctionCallingResponder = automaticFunctionCallingResponder;
7587
currentResponseStream = Optional.empty();
@@ -84,15 +96,12 @@ private ChatSession(
8496
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
8597
*/
8698
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
87-
ChatSession rootChat = rootChatSession.orElse(this);
88-
ChatSession newChatSession =
89-
new ChatSession(
90-
model.withGenerationConfig(generationConfig),
91-
Optional.of(rootChat),
92-
automaticFunctionCallingResponder);
93-
newChatSession.history = history;
94-
newChatSession.previousHistorySize = previousHistorySize;
95-
return newChatSession;
99+
return new ChatSession(
100+
model.withGenerationConfig(generationConfig),
101+
history,
102+
previousHistorySize,
103+
Optional.of(rootChatSession.orElse(this)),
104+
automaticFunctionCallingResponder);
96105
}
97106

98107
/**
@@ -103,15 +112,12 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
103112
* @return a new {@link ChatSession} instance with the specified SafetySettings.
104113
*/
105114
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
106-
ChatSession rootChat = rootChatSession.orElse(this);
107-
ChatSession newChatSession =
108-
new ChatSession(
109-
model.withSafetySettings(safetySettings),
110-
Optional.of(rootChat),
111-
automaticFunctionCallingResponder);
112-
newChatSession.history = history;
113-
newChatSession.previousHistorySize = previousHistorySize;
114-
return newChatSession;
115+
return new ChatSession(
116+
model.withSafetySettings(safetySettings),
117+
history,
118+
previousHistorySize,
119+
Optional.of(rootChatSession.orElse(this)),
120+
automaticFunctionCallingResponder);
115121
}
116122

117123
/**
@@ -122,13 +128,44 @@ public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
122128
* @return a new {@link ChatSession} instance with the specified Tools.
123129
*/
124130
public ChatSession withTools(List<Tool> tools) {
125-
ChatSession rootChat = rootChatSession.orElse(this);
126-
ChatSession newChatSession =
127-
new ChatSession(
128-
model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder);
129-
newChatSession.history = history;
130-
newChatSession.previousHistorySize = previousHistorySize;
131-
return newChatSession;
131+
return new ChatSession(
132+
model.withTools(tools),
133+
history,
134+
previousHistorySize,
135+
Optional.of(rootChatSession.orElse(this)),
136+
automaticFunctionCallingResponder);
137+
}
138+
139+
/**
140+
* Creates a copy of the current ChatSession with updated ToolConfig.
141+
*
142+
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
143+
* new ChatSession.
144+
* @return a new {@link ChatSession} instance with the specified ToolConfig.
145+
*/
146+
public ChatSession withToolConfig(ToolConfig toolConfig) {
147+
return new ChatSession(
148+
model.withToolConfig(toolConfig),
149+
history,
150+
previousHistorySize,
151+
Optional.of(rootChatSession.orElse(this)),
152+
automaticFunctionCallingResponder);
153+
}
154+
155+
/**
156+
* Creates a copy of the current ChatSession with updated SystemInstruction.
157+
*
158+
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
159+
* instructions.
160+
* @return a new {@link ChatSession} instance with the specified SystemInstruction.
161+
*/
162+
public ChatSession withSystemInstruction(Content systemInstruction) {
163+
return new ChatSession(
164+
model.withSystemInstruction(systemInstruction),
165+
history,
166+
previousHistorySize,
167+
Optional.of(rootChatSession.orElse(this)),
168+
automaticFunctionCallingResponder);
132169
}
133170

134171
/**
@@ -141,13 +178,12 @@ public ChatSession withTools(List<Tool> tools) {
141178
*/
142179
public ChatSession withAutomaticFunctionCallingResponder(
143180
AutomaticFunctionCallingResponder automaticFunctionCallingResponder) {
144-
ChatSession rootChat = rootChatSession.orElse(this);
145-
ChatSession newChatSession =
146-
new ChatSession(
147-
model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder));
148-
newChatSession.history = history;
149-
newChatSession.previousHistorySize = previousHistorySize;
150-
return newChatSession;
181+
return new ChatSession(
182+
model,
183+
history,
184+
previousHistorySize,
185+
Optional.of(rootChatSession.orElse(this)),
186+
Optional.of(automaticFunctionCallingResponder));
151187
}
152188

153189
/**

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@
3939
import java.util.List;
4040
import java.util.Optional;
4141

42-
/** This class holds a generative model that can complete what you provided. */
42+
/**
43+
* This class holds a generative model that can complete what you provided. This class is
44+
* thread-safe.
45+
*
46+
* <p>Note: The instances of {@link ChatSession} returned by {@link GenerativeModel#startChat()} are
47+
* NOT thread-safe.
48+
*/
4349
public final class GenerativeModel {
4450
private final String modelName;
4551
private final String resourceName;
@@ -645,6 +651,11 @@ public Optional<Content> getSystemInstruction() {
645651
return systemInstruction;
646652
}
647653

654+
/**
655+
* Returns a new {@link ChatSession} instance that can be used to start a chat with this model.
656+
*
657+
* <p>Note: the returned {@link ChatSession} instance is NOT thread-safe.
658+
*/
648659
public ChatSession startChat() {
649660
return new ChatSession(this);
650661
}

0 commit comments

Comments
 (0)