@@ -222,7 +222,9 @@ def tune_model(
222222 if eval_spec .evaluation_data :
223223 if isinstance (eval_spec .evaluation_data , str ):
224224 if eval_spec .evaluation_data .startswith ("gs://" ):
225- tuning_parameters ["evaluation_data_uri" ] = eval_spec .evaluation_data
225+ tuning_parameters [
226+ "evaluation_data_uri"
227+ ] = eval_spec .evaluation_data
226228 else :
227229 raise ValueError ("evaluation_data should be a GCS URI" )
228230 else :
@@ -627,7 +629,7 @@ def count_tokens(
627629 ) -> CountTokensResponse :
628630 """Counts the tokens and billable characters for a given prompt.
629631
630- Note: this does not make a request to the model, it only counts the tokens
632+ Note: this does not make a prediction request to the model, it only counts the tokens
631633 in the request.
632634
633635 Args:
@@ -802,7 +804,9 @@ def predict(
802804 parameters = prediction_request .parameters ,
803805 )
804806
805- return _parse_text_generation_model_multi_candidate_response (prediction_response )
807+ return _parse_text_generation_model_multi_candidate_response (
808+ prediction_response
809+ )
806810
807811 async def predict_async (
808812 self ,
@@ -844,7 +848,9 @@ async def predict_async(
844848 parameters = prediction_request .parameters ,
845849 )
846850
847- return _parse_text_generation_model_multi_candidate_response (prediction_response )
851+ return _parse_text_generation_model_multi_candidate_response (
852+ prediction_response
853+ )
848854
849855 def predict_streaming (
850856 self ,
@@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
15871593
15881594 _LAUNCH_STAGE = _model_garden_models ._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
15891595
1596+ def start_chat (
1597+ self ,
1598+ * ,
1599+ context : Optional [str ] = None ,
1600+ examples : Optional [List [InputOutputTextPair ]] = None ,
1601+ max_output_tokens : Optional [int ] = None ,
1602+ temperature : Optional [float ] = None ,
1603+ top_k : Optional [int ] = None ,
1604+ top_p : Optional [float ] = None ,
1605+ message_history : Optional [List [ChatMessage ]] = None ,
1606+ stop_sequences : Optional [List [str ]] = None ,
1607+ ) -> "_PreviewChatSession" :
1608+ """Starts a chat session with the model.
1609+
1610+ Args:
1611+ context: Context shapes how the model responds throughout the conversation.
1612+ For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
1613+ examples: List of structured messages to the model to learn how to respond to the conversation.
1614+ A list of `InputOutputTextPair` objects.
1615+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
1616+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
1617+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
1618+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
1619+ message_history: A list of previously sent and received messages.
1620+ stop_sequences: Customized stop sequences to stop the decoding process.
1621+
1622+ Returns:
1623+ A `ChatSession` object.
1624+ """
1625+ return _PreviewChatSession (
1626+ model = self ,
1627+ context = context ,
1628+ examples = examples ,
1629+ max_output_tokens = max_output_tokens ,
1630+ temperature = temperature ,
1631+ top_k = top_k ,
1632+ top_p = top_p ,
1633+ message_history = message_history ,
1634+ stop_sequences = stop_sequences ,
1635+ )
1636+
15901637
15911638class CodeChatModel (_ChatModelBase ):
15921639 """CodeChatModel represents a model that is capable of completing code.
@@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):
16461693
16471694 _LAUNCH_STAGE = _model_garden_models ._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
16481695
1696+ def start_chat (
1697+ self ,
1698+ * ,
1699+ context : Optional [str ] = None ,
1700+ examples : Optional [List [InputOutputTextPair ]] = None ,
1701+ max_output_tokens : Optional [int ] = None ,
1702+ temperature : Optional [float ] = None ,
1703+ top_k : Optional [int ] = None ,
1704+ top_p : Optional [float ] = None ,
1705+ message_history : Optional [List [ChatMessage ]] = None ,
1706+ stop_sequences : Optional [List [str ]] = None ,
1707+ ) -> "_PreviewCodeChatSession" :
1708+ """Starts a chat session with the model.
1709+
1710+ Args:
1711+ context: Context shapes how the model responds throughout the conversation.
1712+ For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
1713+ examples: List of structured messages to the model to learn how to respond to the conversation.
1714+ A list of `InputOutputTextPair` objects.
1715+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
1716+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
1717+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
1718+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
1719+ message_history: A list of previously sent and received messages.
1720+ stop_sequences: Customized stop sequences to stop the decoding process.
1721+
1722+ Returns:
1723+ A `ChatSession` object.
1724+ """
1725+ return _PreviewCodeChatSession (
1726+ model = self ,
1727+ context = context ,
1728+ examples = examples ,
1729+ max_output_tokens = max_output_tokens ,
1730+ temperature = temperature ,
1731+ top_k = top_k ,
1732+ top_p = top_p ,
1733+ message_history = message_history ,
1734+ stop_sequences = stop_sequences ,
1735+ )
1736+
16491737
16501738class _ChatSessionBase :
16511739 """_ChatSessionBase is a base class for all chat sessions."""
@@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
20712159 )
20722160
20732161
class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
    """A mixin class that adds `count_tokens` to a chat session."""

    def count_tokens(self, message: str) -> CountTokensResponse:
        """Counts the tokens and billable characters for the provided chat
        message and any message history, context, or examples set on the chat
        session.

        If you've called `send_message()` in the current chat session before
        calling `count_tokens()`, the response will include the total tokens
        and characters for the previously sent message and the one in the
        `count_tokens()` request. To count the tokens for a single message,
        call `count_tokens()` right after calling `start_chat()` before
        calling `send_message()`.

        Note: this does not make a prediction request to the model, it only
        counts the tokens in the request.

        Examples::

            model = ChatModel.from_pretrained("chat-bison@001")
            chat_session = model.start_chat()
            count_tokens_response = chat_session.count_tokens("How's it going?")

            count_tokens_response.total_tokens
            count_tokens_response.total_billable_characters

        Args:
            message (str):
                Required. A chat message to count tokens for. For example:
                "How's it going?"

        Returns:
            A `CountTokensResponse` object that contains the number of tokens
            in the text and the number of billable characters.
        """
        # Reuse the session's request builder so history/context/examples are
        # folded into the counted instance exactly as they would be for a
        # send_message() call.
        request = self._prepare_request(message=message)

        # count_tokens is only available on the v1beta1 API surface.
        versioned_client = self._model._endpoint._prediction_client.select_version(
            "v1beta1"
        )
        response = versioned_client.count_tokens(
            endpoint=self._model._endpoint_name,
            instances=[request.instance],
        )

        return CountTokensResponse(
            total_tokens=response.total_tokens,
            total_billable_characters=response.total_billable_characters,
            _count_tokens_response=response,
        )
2212+
class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):
    """Preview chat session; gains `count_tokens` from the mixin base."""

    # Present this class under the public preview namespace in reprs/docs.
    __module__ = "vertexai.preview.language_models"
2217+
class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):
    """Preview code-chat session; gains `count_tokens` from the mixin base."""

    # Present this class under the public preview namespace in reprs/docs.
    __module__ = "vertexai.preview.language_models"
2222+
20742223class ChatSession (_ChatSessionBase ):
20752224 """ChatSession represents a chat session with a language model.
20762225
@@ -2361,7 +2510,9 @@ def predict(
23612510 instances = [prediction_request .instance ],
23622511 parameters = prediction_request .parameters ,
23632512 )
2364- return _parse_text_generation_model_multi_candidate_response (prediction_response )
2513+ return _parse_text_generation_model_multi_candidate_response (
2514+ prediction_response
2515+ )
23652516
23662517 async def predict_async (
23672518 self ,
@@ -2400,7 +2551,9 @@ async def predict_async(
24002551 instances = [prediction_request .instance ],
24012552 parameters = prediction_request .parameters ,
24022553 )
2403- return _parse_text_generation_model_multi_candidate_response (prediction_response )
2554+ return _parse_text_generation_model_multi_candidate_response (
2555+ prediction_response
2556+ )
24042557
24052558 def predict_streaming (
24062559 self ,
0 commit comments