1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517import re
1618import typing
17- from typing import List , Optional
19+ from typing import Dict , List , Optional
1820import warnings
1921
2022import numpy as np
@@ -34,7 +36,13 @@ def __init__(self, df) -> None:
3436
3537 self ._df : bigframes .dataframe .DataFrame = df
3638
37- def filter (self , instruction : str , model , ground_with_google_search : bool = False ):
39+ def filter (
40+ self ,
41+ instruction : str ,
42+ model ,
43+ ground_with_google_search : bool = False ,
44+ attach_logprobs : bool = False ,
45+ ):
3846 """
3947 Filters the DataFrame with the semantics of the user instruction.
4048
@@ -74,6 +82,10 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = Fals
7482 page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
7583 The default is `False`.
7684
85+ attach_logprobs (bool, default False):
86+ Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
87+ of the LLM for its responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
88+
7789 Returns:
7890 bigframes.pandas.DataFrame: DataFrame filtered by the instruction.
7991
@@ -82,72 +94,27 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = Fals
8294 ValueError: when the instruction refers to a non-existing column, or when no
8395 columns are referred to.
8496 """
85- import bigframes .dataframe
86- import bigframes .series
8797
88- self ._validate_model (model )
89- columns = self ._parse_columns (instruction )
90- for column in columns :
91- if column not in self ._df .columns :
92- raise ValueError (f"Column { column } not found." )
98+ answer_col = "answer"
9399
94- if ground_with_google_search :
95- msg = exceptions .format_message (
96- "Enables Grounding with Google Search may impact billing cost. See pricing "
97- "details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models"
98- )
99- warnings .warn (msg , category = UserWarning )
100-
101- self ._confirm_operation (len (self ._df ))
102-
103- df : bigframes .dataframe .DataFrame = self ._df [columns ].copy ()
104- has_blob_column = False
105- for column in columns :
106- if df [column ].dtype == dtypes .OBJ_REF_DTYPE :
107- # Don't cast blob columns to string
108- has_blob_column = True
109- continue
110-
111- if df [column ].dtype != dtypes .STRING_DTYPE :
112- df [column ] = df [column ].astype (dtypes .STRING_DTYPE )
113-
114- user_instruction = self ._format_instruction (instruction , columns )
115- output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
116-
117- if has_blob_column :
118- results = typing .cast (
119- bigframes .dataframe .DataFrame ,
120- model .predict (
121- df ,
122- prompt = self ._make_multimodel_prompt (
123- df , columns , user_instruction , output_instruction
124- ),
125- temperature = 0.0 ,
126- ground_with_google_search = ground_with_google_search ,
127- ),
128- )
129- else :
130- results = typing .cast (
131- bigframes .dataframe .DataFrame ,
132- model .predict (
133- self ._make_text_prompt (
134- df , columns , user_instruction , output_instruction
135- ),
136- temperature = 0.0 ,
137- ground_with_google_search = ground_with_google_search ,
138- ),
139- )
100+ output_schema = {answer_col : "bool" }
101+ result = self .map (
102+ instruction ,
103+ model ,
104+ output_schema ,
105+ ground_with_google_search ,
106+ attach_logprobs ,
107+ )
140108
141- return self ._df [
142- results ["ml_generate_text_llm_result" ].str .lower ().str .contains ("true" )
143- ]
109+ return result [result [answer_col ]].drop (answer_col , axis = 1 )
144110
145111 def map (
146112 self ,
147113 instruction : str ,
148- output_column : str ,
149114 model ,
115+ output_schema : Dict [str , str ] | None = None ,
150116 ground_with_google_search : bool = False ,
117+ attach_logprobs = False ,
151118 ):
152119 """
153120 Maps the DataFrame with the semantics of the user instruction.
@@ -163,7 +130,7 @@ def map(
163130 >>> model = llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001")
164131
165132 >>> df = bpd.DataFrame({"ingredient_1": ["Burger Bun", "Soy Bean"], "ingredient_2": ["Beef Patty", "Bittern"]})
166- >>> df.ai.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", output_column= "food", model=model )
133+ >>> df.ai.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", model=model, output_schema={ "food": "string"} )
167134 ingredient_1 ingredient_2 food
168135 0 Burger Bun Beef Patty Burger
169136 <BLANKLINE>
@@ -180,12 +147,14 @@ def map(
180147 in the instructions like:
181148 "Get the ingredients of {food}."
182149
183- output_column (str):
184- The column name of the mapping result.
185-
186150 model (bigframes.ml.llm.GeminiTextGenerator):
187151 A GeminiTextGenerator provided by Bigframes ML package.
188152
153+ output_schema (Dict[str, str] or None, default None):
154+ The schema used to generate structured output as a bigframes DataFrame. The schema is a string key-value pair of <column_name>:<type>.
155+ Supported types are int64, float64, bool, string, array<type> and struct<column type>. If None, generate string result under the column
156+ "ml_generate_text_llm_result".
157+
189158 ground_with_google_search (bool, default False):
190159 Enables Grounding with Google Search for the GeminiTextGenerator model.
191160 When set to True, the model incorporates relevant information from Google
@@ -194,6 +163,11 @@ def map(
194163 page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
195164 The default is `False`.
196165
166+ attach_logprobs (bool, default False):
167+ Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
168+ of the LLM for its responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
169+
170+
197171 Returns:
198172 bigframes.pandas.DataFrame: DataFrame with attached mapping results.
199173
@@ -236,6 +210,9 @@ def map(
236210 "Based on the provided contenxt, answer the following instruction:"
237211 )
238212
213+ if output_schema is None :
214+ output_schema = {"ml_generate_text_llm_result" : "string" }
215+
239216 if has_blob_column :
240217 results = typing .cast (
241218 bigframes .series .Series ,
@@ -246,7 +223,8 @@ def map(
246223 ),
247224 temperature = 0.0 ,
248225 ground_with_google_search = ground_with_google_search ,
249- )["ml_generate_text_llm_result" ],
226+ output_schema = output_schema ,
227+ ),
250228 )
251229 else :
252230 results = typing .cast (
@@ -257,19 +235,36 @@ def map(
257235 ),
258236 temperature = 0.0 ,
259237 ground_with_google_search = ground_with_google_search ,
260- )["ml_generate_text_llm_result" ],
238+ output_schema = output_schema ,
239+ ),
240+ )
241+
242+ attach_columns = [results [col ] for col , _ in output_schema .items ()]
243+
244+ def extract_logprob (s : bigframes .series .Series ) -> bigframes .series .Series :
245+ from bigframes import bigquery as bbq
246+
247+ logprob_jsons = bbq .json_extract_array (s , "$.candidates" ).list [0 ]
248+ logprobs = bbq .json_extract (logprob_jsons , "$.avg_logprobs" ).astype (
249+ "Float64"
261250 )
251+ logprobs .name = "logprob"
252+ return logprobs
253+
254+ if attach_logprobs :
255+ attach_columns .append (extract_logprob (results ["full_response" ]))
262256
263257 from bigframes .core .reshape .api import concat
264258
265- return concat ([self ._df , results . rename ( output_column ) ], axis = 1 )
259+ return concat ([self ._df , * attach_columns ], axis = 1 )
266260
267261 def join (
268262 self ,
269263 other ,
270264 instruction : str ,
271265 model ,
272266 ground_with_google_search : bool = False ,
267+ attach_logprobs = False ,
273268 ):
274269 """
275270 Joines two dataframes by applying the instruction over each pair of rows from
@@ -313,10 +308,6 @@ def join(
313308 model (bigframes.ml.llm.GeminiTextGenerator):
314309 A GeminiTextGenerator provided by Bigframes ML package.
315310
316- max_rows (int, default 1000):
317- The maximum number of rows allowed to be sent to the model per call. If the result is too large, the method
318- call will end early with an error.
319-
320311 ground_with_google_search (bool, default False):
321312 Enables Grounding with Google Search for the GeminiTextGenerator model.
322313 When set to True, the model incorporates relevant information from Google
@@ -325,6 +316,10 @@ def join(
325316 page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
326317 The default is `False`.
327318
319+ attach_logprobs (bool, default False):
320+ Controls whether to attach an additional "logprob" column for each result. Logprobs are floating-point values reflecting the confidence level
321+ of the LLM for its responses. Higher values indicate more confidence. The value is in the range between negative infinity and 0.
322+
328323 Returns:
329324 bigframes.pandas.DataFrame: The joined dataframe.
330325
@@ -400,7 +395,10 @@ def join(
400395 joined_df = self ._df .merge (other , how = "cross" , suffixes = ("_left" , "_right" ))
401396
402397 return joined_df .ai .filter (
403- instruction , model , ground_with_google_search = ground_with_google_search
398+ instruction ,
399+ model ,
400+ ground_with_google_search = ground_with_google_search ,
401+ attach_logprobs = attach_logprobs ,
404402 ).reset_index (drop = True )
405403
406404 def search (
0 commit comments