1616
1717from __future__ import annotations
1818
19- from typing import cast , Literal , Optional
19+ from typing import Callable , cast , Literal , Mapping , Optional
2020import warnings
2121
2222import bigframes_vendored .constants as constants
@@ -616,7 +616,7 @@ def to_gbq(
616616
617617
618618@log_adapter .class_logger
619- class TextEmbeddingGenerator (base .BaseEstimator ):
619+ class TextEmbeddingGenerator (base .RetriableRemotePredictor ):
620620 """Text embedding generator LLM model.
621621
622622 Args:
@@ -715,18 +715,33 @@ def _from_bq(
715715 model ._bqml_model = core .BqmlModel (session , bq_model )
716716 return model
717717
718- def predict (self , X : utils .ArrayType ) -> bpd .DataFrame :
    @property
    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
        """Prediction callable used by the shared retry machinery.

        Returns the bound ``generate_embedding`` method of the underlying
        BQML model. NOTE(review): presumably consumed by
        ``base.RetriableRemotePredictor._predict_and_retry`` (the base class
        of this estimator) together with ``_status_col`` — confirm against
        the base-class definition, which is not visible in this chunk.
        """
        return self._bqml_model.generate_embedding
721+
    @property
    def _status_col(self) -> str:
        """Name of the per-row status column in embedding prediction output.

        Rows whose value in this column is non-empty failed prediction;
        the retry machinery uses it to select rows to re-submit.
        """
        return _ML_GENERATE_EMBEDDING_STATUS
725+
726+ def predict (self , X : utils .ArrayType , * , max_retries : int = 0 ) -> bpd .DataFrame :
719727 """Predict the result from input DataFrame.
720728
721729 Args:
722730 X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
723731 Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
724732
733+ max_retries (int, default 0):
734+ Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
735+ Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
736+
725737 Returns:
726738 bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
727739 """
740+ if max_retries < 0 :
741+ raise ValueError (
742+ f"max_retries must be larger than or equal to 0, but is { max_retries } ."
743+ )
728744
729- # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
730745 (X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
731746
732747 if len (X .columns ) == 1 :
@@ -738,15 +753,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
738753 "flatten_json_output" : True ,
739754 }
740755
741- df = self ._bqml_model .generate_embedding (X , options )
742-
743- if (df [_ML_GENERATE_EMBEDDING_STATUS ] != "" ).any ():
744- warnings .warn (
745- f"Some predictions failed. Check column { _ML_GENERATE_EMBEDDING_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
746- RuntimeWarning ,
747- )
748-
749- return df
756+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
750757
751758 def to_gbq (self , model_name : str , replace : bool = False ) -> TextEmbeddingGenerator :
752759 """Save the model to BigQuery.
@@ -765,7 +772,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
765772
766773
767774@log_adapter .class_logger
768- class GeminiTextGenerator (base .BaseEstimator ):
775+ class GeminiTextGenerator (base .RetriableRemotePredictor ):
769776 """Gemini text generator LLM model.
770777
771778 Args:
@@ -891,6 +898,14 @@ def _bqml_options(self) -> dict:
891898 }
892899 return options
893900
    @property
    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
        """Prediction callable used by the shared retry machinery.

        Returns the bound ``generate_text`` method of the underlying BQML
        model. NOTE(review): presumably invoked by
        ``base.RetriableRemotePredictor._predict_and_retry`` (this class's
        base) with the options mapping built in ``predict`` — confirm
        against the base-class definition, not visible in this chunk.
        """
        return self._bqml_model.generate_text
904+
    @property
    def _status_col(self) -> str:
        """Name of the per-row status column in text-generation output.

        Rows whose value in this column is non-empty failed prediction;
        the retry machinery uses it to select rows to re-submit.
        """
        return _ML_GENERATE_TEXT_STATUS
908+
894909 def fit (
895910 self ,
896911 X : utils .ArrayType ,
@@ -1028,41 +1043,7 @@ def predict(
10281043 "ground_with_google_search" : ground_with_google_search ,
10291044 }
10301045
1031- df_result = bpd .DataFrame (session = self ._bqml_model .session ) # placeholder
1032- df_fail = X
1033- for _ in range (max_retries + 1 ):
1034- df = self ._bqml_model .generate_text (df_fail , options )
1035-
1036- success = df [_ML_GENERATE_TEXT_STATUS ].str .len () == 0
1037- df_succ = df [success ]
1038- df_fail = df [~ success ]
1039-
1040- if df_succ .empty :
1041- if max_retries > 0 :
1042- warnings .warn (
1043- "Can't make any progress, stop retrying." , RuntimeWarning
1044- )
1045- break
1046-
1047- df_result = (
1048- bpd .concat ([df_result , df_succ ]) if not df_result .empty else df_succ
1049- )
1050-
1051- if df_fail .empty :
1052- break
1053-
1054- if not df_fail .empty :
1055- warnings .warn (
1056- f"Some predictions failed. Check column { _ML_GENERATE_TEXT_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
1057- RuntimeWarning ,
1058- )
1059-
1060- df_result = cast (
1061- bpd .DataFrame ,
1062- bpd .concat ([df_result , df_fail ]) if not df_result .empty else df_fail ,
1063- )
1064-
1065- return df_result
1046+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
10661047
10671048 def score (
10681049 self ,
@@ -1144,7 +1125,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
11441125
11451126
11461127@log_adapter .class_logger
1147- class Claude3TextGenerator (base .BaseEstimator ):
1128+ class Claude3TextGenerator (base .RetriableRemotePredictor ):
11481129 """Claude3 text generator LLM model.
11491130
11501131 Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
@@ -1273,13 +1254,22 @@ def _bqml_options(self) -> dict:
12731254 }
12741255 return options
12751256
    @property
    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
        """Prediction callable used by the shared retry machinery.

        Returns the bound ``generate_text`` method of the underlying BQML
        model — the same BQML entry point as ``GeminiTextGenerator``.
        NOTE(review): presumably invoked by
        ``base.RetriableRemotePredictor._predict_and_retry``; confirm
        against the base-class definition, not visible in this chunk.
        """
        return self._bqml_model.generate_text
1260+
    @property
    def _status_col(self) -> str:
        """Name of the per-row status column in text-generation output.

        Rows whose value in this column is non-empty failed prediction;
        the retry machinery uses it to select rows to re-submit.
        """
        return _ML_GENERATE_TEXT_STATUS
1264+
12761265 def predict (
12771266 self ,
12781267 X : utils .ArrayType ,
12791268 * ,
12801269 max_output_tokens : int = 128 ,
12811270 top_k : int = 40 ,
12821271 top_p : float = 0.95 ,
1272+ max_retries : int = 0 ,
12831273 ) -> bpd .DataFrame :
12841274 """Predict the result from input DataFrame.
12851275
@@ -1307,6 +1297,10 @@ def predict(
13071297 Specify a lower value for less random responses and a higher value for more random responses.
13081298 Default 0.95. Possible values [0.0, 1.0].
13091299
1300+ max_retries (int, default 0):
1301+ Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
1302+ Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
1303+
13101304
13111305 Returns:
13121306 bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -1324,6 +1318,11 @@ def predict(
13241318 if top_p < 0.0 or top_p > 1.0 :
13251319 raise ValueError (f"top_p must be [0.0, 1.0], but is { top_p } ." )
13261320
1321+ if max_retries < 0 :
1322+ raise ValueError (
1323+ f"max_retries must be larger than or equal to 0, but is { max_retries } ."
1324+ )
1325+
13271326 (X ,) = utils .batch_convert_to_dataframe (X , session = self ._bqml_model .session )
13281327
13291328 if len (X .columns ) == 1 :
@@ -1338,15 +1337,7 @@ def predict(
13381337 "flatten_json_output" : True ,
13391338 }
13401339
1341- df = self ._bqml_model .generate_text (X , options )
1342-
1343- if (df [_ML_GENERATE_TEXT_STATUS ] != "" ).any ():
1344- warnings .warn (
1345- f"Some predictions failed. Check column { _ML_GENERATE_TEXT_STATUS } for detailed status. You may want to filter the failed rows and retry." ,
1346- RuntimeWarning ,
1347- )
1348-
1349- return df
1340+ return self ._predict_and_retry (X , options = options , max_retries = max_retries )
13501341
13511342 def to_gbq (self , model_name : str , replace : bool = False ) -> Claude3TextGenerator :
13521343 """Save the model to BigQuery.
0 commit comments