From 6dc51e73f9d97434280bf7494a0c99d7820fec6a Mon Sep 17 00:00:00 2001 From: Will Woods Date: Jun 05 2020 17:55:35 +0000 Subject: countme.ItemReader: add .fields and ._get_fields() Refactor the field-reading / field-check from ItemReader._get_reader() to make things a little clearer, and add a `.fields` property so callers can examine what fields are in the thing they're reading. --- diff --git a/countme/__init__.py b/countme/__init__.py index 1e82f1e..e3504f8 100644 --- a/countme/__init__.py +++ b/countme/__init__.py @@ -334,15 +334,21 @@ class ItemReader: self._itemtuple = itemtuple self._itemfields = itemtuple._fields self._itemfactory = itemtuple._make - self._filefields = None self._get_reader(**kwargs) - if not self._filefields: + filefields = self._get_fields() + if not filefields: raise ReaderError("no field names found") - if self._filefields != self._itemfields: - raise ReaderError(f"field mismatch: expected {self._itemfields}, got {self._filefields}") + if filefields != self._itemfields: + raise ReaderError(f"field mismatch: expected {self._itemfields}, got {filefields}") + @property + def fields(self): + return self._itemfields def _get_reader(self): - '''Set up the ItemReader. - Should set self._filefields to a tuple of the fields found in fp.''' + '''Set up the ItemReader.''' + raise NotImplementedError + def _get_fields(self): + '''Called immediately after _get_reader(). + Should return a tuple of the fieldnames found in self._fp.''' raise NotImplementedError def _iter_rows(self): '''Return an iterator/generator that produces a row for each item.''' @@ -355,11 +361,13 @@ class CSVReader(ItemReader): def _get_reader(self, **kwargs): import csv self._reader = csv.reader(self._fp) - self._filefields = tuple(next(self._reader)) - # If we have numbers in our fieldnames, probably there was no header - if any(name.isnumeric() for name in self._filefields): - header = ','.join(fields) - raise ReaderError(f"header bad/missing, got: {header}") + def _get_fields(self): + filefields = tuple(next(self._reader)) + # Sanity check: if any field is a number, this isn't a header + if any(name.isnumeric() for name in filefields): + header = ','.join(filefields) + raise ReaderError(f"header bad/missing: expected {self._itemfields}, got {header!r}") + return filefields def _iter_rows(self): return self._reader @@ -372,12 +380,10 @@ class SQLiteReader(ItemReader): # TODO: self._con.set_progress_handler(handler, call_interval) self._cur = self._con.cursor() self._tablename = tablename - if False and sqlite3.sqlite_version_info >= (3,16,0): - fields_sql = f"SELECT name FROM pragma_table_info(?)" - self._filefields = tuple(r[0] for r in self._cur.execute(fields_sql, (tablename,))) - else: - fields_sql = f"PRAGMA table_info('{tablename}')" - self._filefields = tuple(r[1] for r in self._cur.execute(fields_sql)) + def _get_fields(self): + fields_sql = f"PRAGMA table_info('{self._tablename}')" + filefields = tuple(r[1] for r in self._cur.execute(fields_sql)) + return filefields def _iter_rows(self): fields = ",".join(self._itemfields) return self._cur.execute(f"SELECT {fields} FROM {self._tablename}")