@@ -1270,6 +1270,35 @@ def combine(
def combine_first(self, other: DataFrame):
    """Combine this frame with ``other`` using ``ops.fillna_op``.

    The work is delegated to ``_apply_dataframe_binop``, which applies the
    fill operator between the two frames.
    # NOTE(review): presumably values from ``self`` win and ``other`` only
    # fills gaps, as in pandas ``combine_first`` -- confirm against the helper.
    """
    fill_operator = ops.fillna_op
    return self._apply_dataframe_binop(other, fill_operator)
12721272
def _fast_stat_matrix(self, op: agg_ops.BinaryAggregateOp) -> DataFrame:
    """Compute a pairwise statistic matrix (e.g. corr, cov) in one aggregation.

    This path is faster than the general one, but it emits one aggregation
    expression per ordered pair of columns and therefore more SQL text, so
    it cannot scale to frames with many columns.

    Args:
        op: Binary aggregate operation applied to every (left, right)
            pair of value columns.

    Returns:
        A DataFrame holding the result of ``op`` for every pair of columns,
        labelled on both axes with the original column labels.
    """
    # The number of generated aggregations is quadratic in the column count.
    assert len(self.columns) * len(self.columns) < bigframes.constants.MAX_COLUMNS

    original_labels = self.columns
    column_count = len(original_labels)

    # Work on a copy whose columns are renamed 0..n-1: this keeps the column
    # order and sidesteps any trouble caused by duplicated column names.
    working = self.copy()
    working.columns = pandas.Index(range(column_count))
    working = working.astype(bigframes.dtypes.FLOAT_DTYPE)
    source_block = working._block

    # One aggregation per ordered pair; left column varies slowest.
    pair_aggregations = []
    for left_id in source_block.value_columns:
        for right_id in source_block.value_columns:
            pair_aggregations.append(
                ex.BinaryAggregation(op, ex.deref(left_id), ex.deref(right_id))
            )

    # Pair every original label with its position so that the labels stay
    # unique even when the original column names repeat.
    unique_labels = utils.combine_indices(
        original_labels, pandas.Index(range(column_count))
    )
    pair_labels = utils.cross_indices(unique_labels, unique_labels)

    aggregated_block, _unused = source_block.aggregate(
        aggregations=pair_aggregations, column_labels=pair_labels
    )
    stacked_block = aggregated_block.stack(levels=original_labels.nlevels + 1)

    matrix = DataFrame(stacked_block)
    # The aggregate operation created an index level containing just 0; drop it.
    matrix = matrix.droplevel(0)
    # Drop the last level on each axis, which was only added to guarantee
    # label uniqueness.
    matrix = matrix.droplevel(-1, axis=0)
    matrix = matrix.droplevel(-1, axis=1)
    return matrix
1301+
12731302 def corr (self , method = "pearson" , min_periods = None , numeric_only = False ) -> DataFrame :
12741303 if method != "pearson" :
12751304 raise NotImplementedError (
@@ -1285,6 +1314,10 @@ def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFr
12851314 else :
12861315 frame = self ._drop_non_numeric ()
12871316
1317+ if len (frame .columns ) <= 30 :
1318+ return frame ._fast_stat_matrix (agg_ops .CorrOp ())
1319+
1320+ frame = frame .copy ()
12881321 orig_columns = frame .columns
12891322 # Replace column names with 0 to n - 1 to keep order
12901323 # and avoid the influence of duplicated column name
@@ -1393,6 +1426,10 @@ def cov(self, *, numeric_only: bool = False) -> DataFrame:
13931426 else :
13941427 frame = self ._drop_non_numeric ()
13951428
1429+ if len (frame .columns ) <= 30 :
1430+ return frame ._fast_stat_matrix (agg_ops .CovOp ())
1431+
1432+ frame = frame .copy ()
13961433 orig_columns = frame .columns
13971434 # Replace column names with 0 to n - 1 to keep order
13981435 # and avoid the influence of duplicated column name
0 commit comments