I implemented the following code to calculate the YTD sum in Pandas:

    import pandas as pd

    def calculateYTDSum(df: pd.DataFrame) -> pd.DataFrame:
        '''Calculates the YTD sum of numeric values in a dataframe.
        This assumes the input dataframe contains a "quarter" column of type "Quarter"
        '''
        ans = (df
            .sort_values(by='quarter', ascending=True)
            # helper column holding the year of each quarter
            .assign(_year=lambda x: x['quarter'].apply(lambda q: q.year))
            .groupby('_year')
            .apply(lambda x: x
                .set_index('quarter')
                .cumsum()
            )
            # drop the _year column first so that reset_index can restore the group key
            .drop(columns=['_year'])
            .reset_index()
            # drop the _year group key that reset_index brought back
            .drop(columns=['_year'])
            .sort_values(by='quarter', ascending=False)
        )
        return ans

To see it in action consider the following:

    from dataclasses import dataclass

    @dataclass
    class Quarter:  # This class is used elsewhere in the codebase
        year: int
        quarter: int

        def __repr__(self):
            return f'{self.year} Q{self.quarter}'

        def __hash__(self) -> int:
            return self.year*4 + self.quarter

        def __lt__(self, other):
            return hash(self) < hash(other)

    df = pd.DataFrame({
        'quarter': [Quarter(2020, 4),
                    Quarter(2020, 3),
                    Quarter(2020, 2),
                    Quarter(2020, 1),
                    Quarter(2019, 4),
                    Quarter(2019, 3),
                    Quarter(2019, 2),
                    Quarter(2019, 1)],
        'quantity1': [1, 1, 1, 1, 1, 1, 1, 1],
        'quantity2': [2, 2, 2, 2, 3, 3, 3, 3]
    })

Then you have:
df =
| | quarter | quantity1 | quantity2 |
|---|---|---|---|
| 0 | 2020 Q4 | 1 | 2 |
| 1 | 2020 Q3 | 1 | 2 |
| 2 | 2020 Q2 | 1 | 2 |
| 3 | 2020 Q1 | 1 | 2 |
| 4 | 2019 Q4 | 1 | 3 |
| 5 | 2019 Q3 | 1 | 3 |
| 6 | 2019 Q2 | 1 | 3 |
| 7 | 2019 Q1 | 1 | 3 |
and df.pipe(calculateYTDSum) =
| | quarter | quantity1 | quantity2 |
|---|---|---|---|
| 4 | 2020 Q4 | 4 | 8 |
| 5 | 2020 Q3 | 3 | 6 |
| 6 | 2020 Q2 | 2 | 4 |
| 7 | 2020 Q1 | 1 | 2 |
| 0 | 2019 Q4 | 4 | 12 |
| 1 | 2019 Q3 | 3 | 9 |
| 2 | 2019 Q2 | 2 | 6 |
| 3 | 2019 Q1 | 1 | 3 |
However, even for a small sample like the above, the calculation takes ~4 ms, and tbh it looks unmaintainable.
I welcome any recommendations on Python tooling, libraries, Pandas extensions, or code changes that would improve the performance and/or simplicity of the code.
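As a starting point for discussion, here is a rough sketch of one direction that might be simpler: group the numeric columns by a year Series and call cumsum directly, instead of going through groupby.apply and the two drop calls. The name calculate_ytd_sum_sketch is hypothetical, the Quarter class and sample data are copied from above, and the sketch assumes every column other than quarter is numeric. I have not benchmarked it, so treat any performance benefit as unverified.

```python
from dataclasses import dataclass

import pandas as pd


@dataclass
class Quarter:
    # Copy of the Quarter class from the question, repeated so the sketch runs on its own.
    year: int
    quarter: int

    def __repr__(self):
        return f'{self.year} Q{self.quarter}'

    def __hash__(self) -> int:
        return self.year * 4 + self.quarter

    def __lt__(self, other):
        return hash(self) < hash(other)


def calculate_ytd_sum_sketch(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical alternative: per-year cumulative sum without groupby.apply.

    Assumes a "quarter" column of Quarter objects and that all other
    columns are numeric.
    """
    # Order chronologically so the cumulative sum runs Q1 -> Q4 within each year.
    ordered = df.sort_values(by='quarter', ascending=True)
    # Year of each row, kept as a separate Series instead of a helper column.
    years = ordered['quarter'].map(lambda q: q.year).rename('_year')
    value_cols = [c for c in ordered.columns if c != 'quarter']
    # Grouping by an index-aligned Series; cumsum keeps the original row index.
    ytd = ordered[value_cols].groupby(years).cumsum()
    # Re-attach the quarter column via the shared index, then restore descending order.
    result = ordered[['quarter']].join(ytd)
    return result.sort_values(by='quarter', ascending=False)


df = pd.DataFrame({
    'quarter': [Quarter(2020, 4), Quarter(2020, 3), Quarter(2020, 2), Quarter(2020, 1),
                Quarter(2019, 4), Quarter(2019, 3), Quarter(2019, 2), Quarter(2019, 1)],
    'quantity1': [1, 1, 1, 1, 1, 1, 1, 1],
    'quantity2': [2, 2, 2, 2, 3, 3, 3, 3],
})

result = df.pipe(calculate_ytd_sum_sketch)
print(result)
```

Because the year lives in a separate Series rather than in the frame, there is no _year column to drop afterwards, and the row index of the input is preserved in the result.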
_year column from the groupby is retained (in fact I do the first drop to remove _year so that I can reset index)