From 07b033a40107a59180962b8572e4a3cd97b06cfa Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Oct 10 2024 13:45:14 +0000 Subject: [PATCH 1/4] match longest archivetype extension first --- diff --git a/kojihub/kojihub.py b/kojihub/kojihub.py index 28d4c74..73b4e64 100644 --- a/kojihub/kojihub.py +++ b/kojihub/kojihub.py @@ -7796,13 +7796,15 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False): else: raise koji.GenericError('one of filename, type_name, or type_id must be specified') - parts = filename.split('.') + # otherwise match the filename query = QueryProcessor( tables=['archivetypes'], columns=['id', 'name', 'description', 'extensions', 'compression_type'], clauses=['extensions ~* %(pattern)s'], ) - for start in range(len(parts) - 1, -1, -1): + # match longest extension first. e.g. .tar.gz before .gz + parts = filename.split('.') + for start in range(len(parts)): ext = '.'.join(parts[start:]) query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext results = query.execute() @@ -7814,7 +7816,7 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False): raise koji.GenericError('multiple matches for file extension: %s' % ext) # otherwise if strict: - raise koji.GenericError('unsupported file extension: %s' % ext) + raise koji.GenericError('unsupported file extension: %s' % filename) else: return None diff --git a/tests/test_hub/test_get_archive_type.py b/tests/test_hub/test_get_archive_type.py index 5318a68..ec965cf 100644 --- a/tests/test_hub/test_get_archive_type.py +++ b/tests/test_hub/test_get_archive_type.py @@ -61,7 +61,8 @@ class TestGetArchiveType(DBQueryTestCase): archive_info = [{'id': 1, 'name': 'archive-type-1', 'extensions': 'ext'}, {'id': 2, 'name': 'archive-type-2', 'extensions': 'ext'}] filename = 'test-filename.ext' - self.qp_execute_return_value = archive_info + self.qp_execute_side_effect = [[], archive_info] + # no matches for full name, multiple matches for .ext with self.assertRaises(koji.GenericError) as ex: kojihub.get_archive_type(filename=filename) self.assertEqual("multiple matches for file extension: ext", str(ex.exception)) From fb228a107a69d6ed605a10a839036f10bbb78351 Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Oct 10 2024 18:11:27 +0000 Subject: [PATCH 2/4] simplify extension match this avoids errors if the ext value contains special characters --- diff --git a/kojihub/kojihub.py b/kojihub/kojihub.py index 73b4e64..928e99d 100644 --- a/kojihub/kojihub.py +++ b/kojihub/kojihub.py @@ -7800,13 +7800,13 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False): query = QueryProcessor( tables=['archivetypes'], columns=['id', 'name', 'description', 'extensions', 'compression_type'], - clauses=['extensions ~* %(pattern)s'], + clauses=[r"%(ext)s = ANY(regexp_split_to_array(extensions, '\s+'))"], ) # match longest extension first. e.g. .tar.gz before .gz parts = filename.split('.') for start in range(len(parts)): ext = '.'.join(parts[start:]) - query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext + query.values['ext'] = ext results = query.execute() if len(results) == 1: diff --git a/tests/test_hub/test_get_archive_type.py b/tests/test_hub/test_get_archive_type.py index ec965cf..39a8fb4 100644 --- a/tests/test_hub/test_get_archive_type.py +++ b/tests/test_hub/test_get_archive_type.py @@ -71,7 +71,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -92,7 +92,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -112,7 +112,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -130,7 +130,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() From 7cf1bff0992760b8a2f65eb4fee271b2017c96af Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Oct 11 2024 04:49:57 +0000 Subject: [PATCH 3/4] make sure extension matching is case insensitive --- diff --git a/kojihub/kojihub.py b/kojihub/kojihub.py index 928e99d..34c499f 100644 --- a/kojihub/kojihub.py +++ b/kojihub/kojihub.py @@ -7800,10 +7800,11 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False): query = QueryProcessor( tables=['archivetypes'], columns=['id', 'name', 'description', 'extensions', 'compression_type'], - clauses=[r"%(ext)s = ANY(regexp_split_to_array(extensions, '\s+'))"], + clauses=[r"%(ext)s IN (SELECT lower(s)" + r" FROM unnest(regexp_split_to_array(extensions, '\s+')) AS s)"], ) # match longest extension first. e.g. .tar.gz before .gz - parts = filename.split('.') + parts = filename.lower().split('.') for start in range(len(parts)): ext = '.'.join(parts[start:]) query.values['ext'] = ext From eed1854390cb50a28121631ea9c52052186ef546 Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Oct 11 2024 13:49:39 +0000 Subject: [PATCH 4/4] fix unit test --- diff --git a/tests/test_hub/test_get_archive_type.py b/tests/test_hub/test_get_archive_type.py index 39a8fb4..83f321f 100644 --- a/tests/test_hub/test_get_archive_type.py +++ b/tests/test_hub/test_get_archive_type.py @@ -71,7 +71,10 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) + _clauses = [ + "%(ext)s IN (SELECT lower(s) FROM " + "unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"] + self.assertEqual(query.clauses, _clauses) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -92,7 +95,10 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) + _clauses = [ + "%(ext)s IN (SELECT lower(s) FROM " + "unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"] + self.assertEqual(query.clauses, _clauses) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -112,7 +118,10 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) + _clauses = [ + "%(ext)s IN (SELECT lower(s) FROM " + "unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"] + self.assertEqual(query.clauses, _clauses) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -130,7 +139,10 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) + _clauses = [ + "%(ext)s IN (SELECT lower(s) FROM " + "unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"] + self.assertEqual(query.clauses, _clauses) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called()