Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3003,10 +3003,9 @@ def combiner(x, y, needs_i8_conversion=False):
return self.combine(other, combiner, overwrite=False)

def update(self, other, join='left', overwrite=True, filter_func=None,
raise_conflict=False):
raise_conflict=False, on=None):
"""
Modify DataFrame in place using non-NA values from passed
DataFrame. Aligns on indices
Modify DataFrame in place using non-NA values from passed DataFrame.

Parameters
----------
Expand All @@ -3020,6 +3019,10 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
raise_conflict : boolean
If True, will raise an error if the DataFrame and other both
contain data in the same place.
on : label or list, optional
Column label(s) used to match up observations in other and self.
If None, other is reindexed to the index of self, so the indexes
must match to get a meaningful result.
"""
# TODO: Support other joins
if join != 'left': # pragma: no cover
Expand All @@ -3028,31 +3031,55 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
if not isinstance(other, DataFrame):
other = DataFrame(other)

other = other.reindex_like(self)
if on is None:
other = other.reindex(index=self.index)
else:
try:
old_index = self.index
col_order = self.columns
self.set_index(on, inplace=True)
other.set_index(on, inplace=True)
other = other.reindex(index=self.index)
except Exception, err:
self.reset_index(inplace=True)
self.set_index(old_index)
raise(err)

for col in self.columns:
this = self[col].values
that = other[col].values
if filter_func is not None:
mask = -filter_func(this) | isnull(that)
else:
if raise_conflict:
mask_this = notnull(that)
mask_that = notnull(this)
if any(mask_this & mask_that):
raise ValueError("Data overlaps.")
try:
for col in other.columns:
if col not in self: # don't update what doesn't exist
continue
this = self[col].values
that = other[col].values
if filter_func is not None:
mask = -filter_func(this) | isnull(that)
else:
if raise_conflict:
mask_this = notnull(that)
mask_that = notnull(this)
if any(mask_this & mask_that):
raise ValueError("Data overlaps.")

if overwrite:
mask = isnull(that)

# don't overwrite columns unnecessarily
if mask.all():
continue
else:
mask = notnull(this)

if overwrite:
mask = isnull(that)
self[col] = expressions.where(
mask, this, that, raise_on_error=True)

# don't overwrite columns unnecessarily
if mask.all():
continue
else:
mask = notnull(this)
except Exception, err:
raise(err)

self[col] = expressions.where(
mask, this, that, raise_on_error=True)
finally:
if on is not None:
self.reset_index(inplace=True)
self.set_index(old_index)
self = self[col_order]

#----------------------------------------------------------------------
# Misc methods
Expand Down
39 changes: 38 additions & 1 deletion pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from numpy.random import randn
import numpy as np
import numpy.ma as ma
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_
import numpy.ma.mrecords as mrecords

import pandas.core.nanops as nanops
Expand Down Expand Up @@ -9974,6 +9974,43 @@ def test_update(self):
[1.5, nan, 7.]])
assert_frame_equal(df, expected)

def test_update_on(self):
    """Exercise ``DataFrame.update`` with rows matched via ``on='name'``."""
    names = ['A', 'A', 'A', 'B', 'C', 'C', 'B']

    def frame_of(numbers):
        # Pair each number with its name to form the two-column fixture.
        rows = [[number, name] for number, name in zip(numbers, names)]
        return DataFrame(rows, columns=['number', 'name'])

    missing_for_a = [np.nan, np.nan, np.nan, 1.5, 2.2, 3.1, 1.2]

    # A single donor row for 'A' should fill every NaN in the 'A' rows.
    target = frame_of(missing_for_a)
    donor = DataFrame([[3.5, 'A']], columns=['number', 'name'])
    wanted = frame_of([3.5, 3.5, 3.5, 1.5, 2.2, 3.1, 1.2])

    target.update(donor, on='name')
    assert_frame_equal(target, wanted)

    # Two donor rows share the key 'A'; the update is expected to raise.
    target = frame_of(missing_for_a)
    donor = DataFrame([[3.5, 'A'], [2.5, 'A']],
                      columns=['number', 'name'])

    assertRaises(ValueError, target.update, donor, on='name')

    ## and the index should be reset
    assert_(target.index.equals(pd.Index(range(7))))

def test_update_dtypes(self):

# gh 3016
Expand Down