Skip to content

Commit bbe54ce

Browse files
authored
Merge pull request #424 from MIT-LCP/wrsamp-checksum-spf
Set correct checksums and samps_per_frame in Record.wrsamp
2 parents 7b361e2 + 489aa62 commit bbe54ce

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

tests/test_record.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,22 @@ def test_4d(self):
633633

634634
assert np.array_equal(sig_round, sig_target)
635635

636+
def test_write_smoothed(self):
637+
"""
638+
Test writing a record after reading with smooth_frames
639+
"""
640+
record = wfdb.rdrecord(
641+
"sample-data/drive02",
642+
physical=False,
643+
smooth_frames=True,
644+
)
645+
record.wrsamp(write_dir=self.temp_path)
646+
record2 = wfdb.rdrecord(
647+
os.path.join(self.temp_path, "drive02"),
648+
physical=False,
649+
)
650+
np.testing.assert_array_equal(record.d_signal, record2.d_signal)
651+
636652
def test_to_dataframe(self):
637653
record = wfdb.rdrecord("sample-data/test01_00s")
638654
df = record.to_dataframe()

wfdb/io/_header.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def set_defaults(self):
278278
for f in sfields:
279279
self.set_default(f)
280280

281-
def wrheader(self, write_dir=""):
281+
def wrheader(self, write_dir="", expanded=True):
282282
"""
283283
Write a WFDB header file. The signals are not used. Before
284284
writing:
@@ -290,6 +290,10 @@ def wrheader(self, write_dir=""):
290290
----------
291291
write_dir : str, optional
292292
The output directory in which the header is written.
293+
expanded : bool, optional
294+
Whether the header file should include `samps_per_frame` (this
295+
should only be true if the signal files are written using
296+
`expanded=True`).
293297
294298
Returns
295299
-------
@@ -305,6 +309,8 @@ def wrheader(self, write_dir=""):
305309
# sig_write_fields is a dictionary of
306310
# {field_name:required_channels}
307311
rec_write_fields, sig_write_fields = self.get_write_fields()
312+
if not expanded:
313+
sig_write_fields.pop("samps_per_frame", None)
308314

309315
# Check the validity of individual fields used to write the header
310316
# Record specification fields (and comments)

wfdb/io/_signal.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def wr_dats(self, expanded, write_dir):
142142
# Get all the fields used to write the header
143143
# Assuming this method was called through wrsamp,
144144
# these will have already been checked in wrheader()
145-
write_fields = self.get_write_fields()
145+
_, _ = self.get_write_fields()
146146

147147
if expanded:
148148
# Using list of arrays e_d_signal
@@ -152,8 +152,10 @@ def wr_dats(self, expanded, write_dir):
152152
self.check_field("d_signal")
153153

154154
# Check the cohesion of the d_signal field against the other
155-
# fields used to write the header
156-
self.check_sig_cohesion(write_fields, expanded)
155+
# fields used to write the header. (Note that for historical
156+
# reasons, this doesn't actually check any of the optional
157+
# header fields.)
158+
self.check_sig_cohesion([], expanded)
157159

158160
# Write each of the specified dat files
159161
self.wr_dat_files(expanded=expanded, write_dir=write_dir)
@@ -192,10 +194,8 @@ def check_sig_cohesion(self, write_fields, expanded):
192194
for ch in range(self.n_sig):
193195
if len(self.e_d_signal[ch]) != spf[ch] * self.sig_len:
194196
raise ValueError(
195-
"Length of channel "
196-
+ str(ch)
197-
+ "does not match samps_per_frame["
198-
+ str(ch + "]*sig_len")
197+
f"Length of channel {ch} does not match "
198+
f"samps_per_frame[{ch}]*sig_len"
199199
)
200200

201201
# For each channel (if any), make sure the digital format has no values out of bounds

wfdb/io/record.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,9 +922,19 @@ def wrsamp(self, expanded=False, write_dir=""):
922922
N/A
923923
924924
"""
925+
# Update the checksum field (except for channels that did not have
926+
# a checksum to begin with, or where the checksum was already
927+
# valid.)
928+
if self.checksum is not None:
929+
checksums = self.calc_checksum(expanded=expanded)
930+
for ch, old_val in enumerate(self.checksum):
931+
if old_val is None or (checksums[ch] - old_val) % 65536 == 0:
932+
checksums[ch] = old_val
933+
self.checksum = checksums
934+
925935
# Perform field validity and cohesion checks, and write the
926936
# header file.
927-
self.wrheader(write_dir=write_dir)
937+
self.wrheader(write_dir=write_dir, expanded=expanded)
928938
if self.n_sig > 0:
929939
# Perform signal validity and cohesion checks, and write the
930940
# associated dat files.

0 commit comments

Comments
 (0)