From a1b03036752546691275844adca9a0589a6f3e76 Mon Sep 17 00:00:00 2001 From: Tristan Crockett Date: Wed, 8 May 2019 14:24:29 -0500 Subject: [PATCH 1/5] Cleanup pass 1 --- src/tests/architect_tests/utils.py | 7 - src/tests/catwalk_tests/utils.py | 4 - src/triage/component/architect/utils.py | 184 ------------------------ 3 files changed, 195 deletions(-) diff --git a/src/tests/architect_tests/utils.py b/src/tests/architect_tests/utils.py index b56f25b32..9c4baccc0 100644 --- a/src/tests/architect_tests/utils.py +++ b/src/tests/architect_tests/utils.py @@ -2,12 +2,9 @@ import shutil import sys import tempfile -import random from contextlib import contextmanager import pandas as pd -import yaml -import numpy def convert_string_column_to_date(column): @@ -138,10 +135,6 @@ def TemporaryDirectory(): shutil.rmtree(name) -def fake_labels(length): - return numpy.array([random.choice([True, False]) for i in range(0, length)]) - - def assert_index(engine, table, column): """Assert that a table has an index on a given column diff --git a/src/tests/catwalk_tests/utils.py b/src/tests/catwalk_tests/utils.py index da20734c0..6b1d5bc47 100644 --- a/src/tests/catwalk_tests/utils.py +++ b/src/tests/catwalk_tests/utils.py @@ -14,10 +14,6 @@ from triage.util.structs import FeatureNameList -def fake_labels(length): - return numpy.array([random.choice([True, False]) for i in range(0, length)]) - - @pytest.fixture def sample_metadata(): return { diff --git a/src/triage/component/architect/utils.py b/src/triage/component/architect/utils.py index b13498085..23dcb6ffc 100644 --- a/src/triage/component/architect/utils.py +++ b/src/triage/component/architect/utils.py @@ -1,19 +1,8 @@ -import datetime -import shutil -import sys -import random -from contextlib import contextmanager import functools import operator -import tempfile import sqlalchemy -import pandas as pd -import numpy -from sqlalchemy.orm import sessionmaker - -from triage.component.results_schema import Model from 
triage.util.structs import FeatureNameList @@ -38,178 +27,5 @@ def feature_list(feature_dictionary): )) -def convert_string_column_to_date(column): - return [datetime.datetime.strptime(date, "%Y-%m-%d").date() for date in column] - - -def create_features_table(table_number, table, engine): - engine.execute( - """ - create table features.features{} ( - entity_id int, as_of_date date, f{} int, f{} int - ) - """.format( - table_number, (table_number * 2) + 1, (table_number * 2) + 2 - ) - ) - for row in table: - engine.execute( - """ - insert into features.features{} values (%s, %s, %s, %s) - """.format( - table_number - ), - row, - ) - - -def create_entity_date_df( - labels, - states, - as_of_dates, - state_one, - state_two, - label_name, - label_type, - label_timespan, -): - """ This function makes a pandas DataFrame that mimics the entity-date table - for testing against. - """ - 0, "2016-02-01", "1 month", "booking", "binary", 0 - labels_table = pd.DataFrame( - labels, - columns=[ - "entity_id", - "as_of_date", - "label_timespan", - "label_name", - "label_type", - "label", - ], - ) - states_table = pd.DataFrame( - states, columns=["entity_id", "as_of_date", "state_one", "state_two"] - ).set_index(["entity_id", "as_of_date"]) - as_of_dates = [date.date() for date in as_of_dates] - labels_table = labels_table[labels_table["label_name"] == label_name] - labels_table = labels_table[labels_table["label_type"] == label_type] - labels_table = labels_table[labels_table["label_timespan"] == label_timespan] - labels_table = labels_table.join(other=states_table, on=("entity_id", "as_of_date")) - labels_table = labels_table[labels_table["state_one"] & labels_table["state_two"]] - ids_dates = labels_table[["entity_id", "as_of_date"]] - ids_dates = ids_dates.sort_values(["entity_id", "as_of_date"]) - ids_dates["as_of_date"] = [ - datetime.datetime.strptime(date, "%Y-%m-%d").date() - for date in ids_dates["as_of_date"] - ] - ids_dates = 
ids_dates[ids_dates["as_of_date"].isin(as_of_dates)] - print(ids_dates) - - return ids_dates.reset_index(drop=True) - - -def NamedTempFile(): - if sys.version_info >= (3, 0, 0): - return tempfile.NamedTemporaryFile(mode="w+", newline="") - else: - return tempfile.NamedTemporaryFile() - - -@contextmanager -def TemporaryDirectory(): - name = tempfile.mkdtemp() - try: - yield name - finally: - shutil.rmtree(name) - - -def fake_labels(length): - return numpy.array([random.choice([True, False]) for i in range(0, length)]) - - -class MockTrainedModel(object): - def predict_proba(self, dataset): - return numpy.random.rand(len(dataset), len(dataset)) - - -def fake_trained_model(project_path, model_storage_engine, db_engine): - """Creates and stores a trivial trained model - - Args: - project_path (string) a desired fs/s3 project path - model_storage_engine (triage.storage.ModelStorageEngine) - db_engine (sqlalchemy.engine) - - Returns: - (int) model id for database retrieval - """ - trained_model = MockTrainedModel() - model_storage_engine.write(trained_model, "abcd") - session = sessionmaker(db_engine)() - db_model = Model(model_hash="abcd") - session.add(db_model) - session.commit() - return trained_model, db_model.model_id - - -def assert_index(engine, table, column): - """Assert that a table has an index on a given column - - Does not care which position the column is in the index - Modified from https://www.gab.lc/articles/index_on_id_with_postgresql - - Args: - engine (sqlalchemy.engine) a database engine - table (string) the name of a table - column (string) the name of a column - """ - query = """ - SELECT 1 - FROM pg_class t - JOIN pg_index ix ON t.oid = ix.indrelid - JOIN pg_class i ON i.oid = ix.indexrelid - JOIN pg_attribute a ON a.attrelid = t.oid - WHERE - a.attnum = ANY(ix.indkey) AND - t.relkind = 'r' AND - t.relname = '{table_name}' AND - a.attname = '{column_name}' - """.format( - table_name=table, column_name=column - ) - num_results = len([row for row 
in engine.execute(query)]) - assert num_results >= 1 - - -def create_dense_state_table(db_engine, table_name, data): - db_engine.execute( - """create table {} ( - entity_id int, - state text, - start_time timestamp, - end_time timestamp - )""".format( - table_name - ) - ) - - for row in data: - db_engine.execute( - "insert into {} values (%s, %s, %s, %s)".format(table_name), row - ) - - -def create_binary_outcome_events(db_engine, table_name, events_data): - db_engine.execute( - "create table events (entity_id int, outcome_date date, outcome bool)" - ) - for event in events_data: - db_engine.execute( - "insert into {} values (%s, %s, %s::bool)".format(table_name), event - ) - - def retry_if_db_error(exception): return isinstance(exception, sqlalchemy.exc.OperationalError) From 9be51ffbf6cf8312272a975102a38dd3539ce84e Mon Sep 17 00:00:00 2001 From: Tristan Crockett Date: Wed, 8 May 2019 15:41:42 -0500 Subject: [PATCH 2/5] Phase 2 --- src/tests/catwalk_tests/utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/tests/catwalk_tests/utils.py b/src/tests/catwalk_tests/utils.py index 6b1d5bc47..0a8fff57b 100644 --- a/src/tests/catwalk_tests/utils.py +++ b/src/tests/catwalk_tests/utils.py @@ -1,16 +1,8 @@ import datetime -import random -import tempfile -from contextlib import contextmanager import pytest -import numpy import pandas -import yaml -from triage.component.catwalk.storage import ( - ProjectStorage, -) from triage.util.structs import FeatureNameList @@ -42,13 +34,3 @@ def sample_df(): "label": ["good", "bad"], } ).set_index("entity_id") - - -@pytest.fixture -def sample_matrix_store(): - with tempfile.TemporaryDirectory() as tempdir: - project_storage = ProjectStorage(tempdir) - store = project_storage.matrix_storage_engine().get_store("1234") - store.matrix = sample_df() - store.metadata = sample_metadata() - return store From c5b5696f57680d58fd667519c3d6ff7912412f32 Mon Sep 17 00:00:00 2001 From: Tristan Crockett Date: Wed, 8 May 
2019 16:15:44 -0500 Subject: [PATCH 3/5] Finish removing unused utils --- src/triage/component/architect/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/triage/component/architect/utils.py b/src/triage/component/architect/utils.py index 23dcb6ffc..f81f4f61e 100644 --- a/src/triage/component/architect/utils.py +++ b/src/triage/component/architect/utils.py @@ -1,8 +1,6 @@ import functools import operator -import sqlalchemy - from triage.util.structs import FeatureNameList @@ -25,7 +23,3 @@ def feature_list(feature_dictionary): (feature_dictionary[key] for key in feature_dictionary.keys()), ) )) - - -def retry_if_db_error(exception): - return isinstance(exception, sqlalchemy.exc.OperationalError) From 0b57f843bc13b7a260a76f318a9811e860fec9de Mon Sep 17 00:00:00 2001 From: Tristan Crockett Date: Wed, 8 May 2019 17:33:01 -0500 Subject: [PATCH 4/5] Move around some utils --- .../test_database_reflection.py | 50 ---- .../architect/database_reflection.py | 115 -------- .../architect/entity_date_table_generators.py | 3 +- .../architect/feature_dictionary_creator.py | 2 +- src/triage/component/architect/utils.py | 4 - src/triage/component/architect/validations.py | 151 ---------- .../component/audition/distance_from_best.py | 2 +- .../audition/model_group_performance.py | 2 +- .../component/audition/selection_rule_grid.py | 2 +- src/triage/component/audition/utils.py | 6 - src/triage/component/catwalk/subsetters.py | 11 +- src/triage/component/catwalk/utils.py | 261 ------------------ src/triage/component/collate/collate.py | 5 +- src/triage/component/results_schema/utils.py | 110 ++++++++ src/triage/experiments/validate.py | 6 +- src/triage/util/conf.py | 4 + src/triage/util/db.py | 55 ++++ src/triage/util/hash.py | 14 + src/triage/util/iteration.py | 27 ++ src/triage/util/sorting.py | 49 ++++ src/triage/util/sql.py | 3 + 21 files changed, 278 insertions(+), 604 deletions(-) delete mode 100644 src/tests/architect_tests/test_database_reflection.py delete 
mode 100644 src/triage/component/architect/database_reflection.py delete mode 100644 src/triage/component/architect/validations.py delete mode 100644 src/triage/component/catwalk/utils.py create mode 100644 src/triage/component/results_schema/utils.py create mode 100644 src/triage/util/hash.py create mode 100644 src/triage/util/iteration.py create mode 100644 src/triage/util/sorting.py create mode 100644 src/triage/util/sql.py diff --git a/src/tests/architect_tests/test_database_reflection.py b/src/tests/architect_tests/test_database_reflection.py deleted file mode 100644 index 9a8aa5fa1..000000000 --- a/src/tests/architect_tests/test_database_reflection.py +++ /dev/null @@ -1,50 +0,0 @@ -from sqlalchemy import Table -from sqlalchemy import create_engine -from sqlalchemy.types import VARCHAR -from testing.postgresql import Postgresql -from unittest import TestCase - -from triage.component.architect import database_reflection as dbreflect - - -class TestDatabaseReflection(TestCase): - def setUp(self): - self.postgresql = Postgresql() - self.engine = create_engine(self.postgresql.url()) - - def tearDown(self): - self.postgresql.stop() - - def test_split_table(self): - assert dbreflect.split_table("staging.incidents") == ("staging", "incidents") - assert dbreflect.split_table("incidents") == (None, "incidents") - with self.assertRaises(ValueError): - dbreflect.split_table("blah.staging.incidents") - - def test_table_object(self): - assert isinstance(dbreflect.table_object("incidents", self.engine), Table) - - def test_reflected_table(self): - self.engine.execute("create table incidents (col1 varchar)") - assert dbreflect.reflected_table("incidents", self.engine).exists() - - def test_table_exists(self): - self.engine.execute("create table incidents (col1 varchar)") - assert dbreflect.table_exists("incidents", self.engine) - assert not dbreflect.table_exists("compliments", self.engine) - - def test_table_has_data(self): - self.engine.execute("create table incidents 
(col1 varchar)") - self.engine.execute("create table compliments (col1 varchar)") - self.engine.execute("insert into compliments values ('good job')") - assert dbreflect.table_has_data("compliments", self.engine) - assert not dbreflect.table_has_data("incidents", self.engine) - - def test_table_has_column(self): - self.engine.execute("create table incidents (col1 varchar)") - assert dbreflect.table_has_column("incidents", "col1", self.engine) - assert not dbreflect.table_has_column("incidents", "col2", self.engine) - - def test_column_type(self): - self.engine.execute("create table incidents (col1 varchar)") - assert dbreflect.column_type("incidents", "col1", self.engine) == VARCHAR diff --git a/src/triage/component/architect/database_reflection.py b/src/triage/component/architect/database_reflection.py deleted file mode 100644 index f56697379..000000000 --- a/src/triage/component/architect/database_reflection.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Functions to retrieve basic information about tables in a Postgres database""" -from sqlalchemy import MetaData, Table - - -def split_table(table_name): - """Split a fully-qualified table name into schema and table - - Args: - table_name (string) A table name, either with or without a schema prefix - - Returns: (tuple) of schema and table name - """ - table_parts = table_name.split(".") - if len(table_parts) == 2: - return tuple(table_parts) - elif len(table_parts) == 1: - return (None, table_parts[0]) - else: - raise ValueError("Table name in unknown format") - - -def table_object(table_name, db_engine): - """Produce a table object for the given table name - - This does not load data about the table from the engine yet, - so it is safe to call for a table that doesn't exist. 
- - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Returns: (sqlalchemy.Table) - """ - schema, table = split_table(table_name) - meta = MetaData(schema=schema, bind=db_engine) - return Table(table, meta) - - -def reflected_table(table_name, db_engine): - """Produce a loaded table object for the given table name - - Will attempt to load the metadata about the table from the database - So this will fail if the table doesn't exist. - - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Returns: (sqlalchemy.Table) A loaded table object - """ - schema, table = split_table(table_name) - meta = MetaData(schema=schema, bind=db_engine) - return Table(table, meta, autoload=True, autoload_from=db_engine) - - -def table_exists(table_name, db_engine): - """Checks whether the table exists - - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Returns: (boolean) Whether or not the table exists in the database - """ - return table_object(table_name, db_engine).exists() - - -def table_has_data(table_name, db_engine): - """Check whether the table contains any data - - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Returns: (boolean) Whether or not the table has any data - """ - if not table_exists(table_name, db_engine): - return False - results = [ - row for row in db_engine.execute("select * from {} limit 1".format(table_name)) - ] - - return len(results) > 0 - - -def table_has_column(table_name, column, db_engine): - """Check whether the table contains a column of the given name - - The table is expected to exist. 
- - Args: - table_name (string) A table name (with schema) - column (string) A column name - db_engine (sqlalchemy.engine) - - Returns: (boolean) Whether or not the table contains the column - """ - return column in reflected_table(table_name, db_engine).columns - - -def column_type(table_name, column, db_engine): - """Find the database type of the given column in the given table - - The table is expected to exist, and contain a column of the given name - - Args: - table_name (string) A table name (with schema) - column (string) A column name - db_engine (sqlalchemy.engine) - - Returns: (sqlalchemy.types) The DDL type of the column; For instance, - sqlalchemy.types.BOOLEAN instead of - sqlalchemy.types.Boolean - """ - return type(reflected_table(table_name, db_engine).columns[column].type) diff --git a/src/triage/component/architect/entity_date_table_generators.py b/src/triage/component/architect/entity_date_table_generators.py index 728f14aa5..1a77ab34d 100644 --- a/src/triage/component/architect/entity_date_table_generators.py +++ b/src/triage/component/architect/entity_date_table_generators.py @@ -1,7 +1,6 @@ import logging -from triage.component.architect.database_reflection import table_has_data -from triage.database_reflection import table_row_count, table_exists +from triage.database_reflection import table_row_count, table_exists, table_has_data DEFAULT_ACTIVE_STATE = "active" diff --git a/src/triage/component/architect/feature_dictionary_creator.py b/src/triage/component/architect/feature_dictionary_creator.py index 4d0082001..7974662dd 100644 --- a/src/triage/component/architect/feature_dictionary_creator.py +++ b/src/triage/component/architect/feature_dictionary_creator.py @@ -1,5 +1,5 @@ import logging -from triage.component.architect.utils import str_in_sql +from triage.util.sql import str_in_sql from triage.util.structs import FeatureNameList diff --git a/src/triage/component/architect/utils.py b/src/triage/component/architect/utils.py index 
f81f4f61e..a912279a4 100644 --- a/src/triage/component/architect/utils.py +++ b/src/triage/component/architect/utils.py @@ -4,10 +4,6 @@ from triage.util.structs import FeatureNameList -def str_in_sql(values): - return ",".join(map(lambda x: "'{}'".format(x), values)) - - def feature_list(feature_dictionary): """Convert a feature dictionary to a sorted list diff --git a/src/triage/component/architect/validations.py b/src/triage/component/architect/validations.py deleted file mode 100644 index 4749411eb..000000000 --- a/src/triage/component/architect/validations.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Functions for validating input, mostly around database schema and state""" -from triage.component.architect.database_reflection import ( - table_exists, - table_has_column, - column_type, - table_has_data, -) -from sqlalchemy.types import ( - BIGINT, - BOOLEAN, - DATE, - DATETIME, - INTEGER, - SMALLINT, - TEXT, - TIMESTAMP, - VARCHAR, -) -from sqlalchemy.dialects.postgresql.base import TIMESTAMP as POSTGRES_TIMESTAMP - - -def table_should_exist(table_name, db_engine): - """Ensures that the table exists in the given database - - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Raises: ValueError if the table does not exist - """ - if not table_exists(table_name, db_engine): - raise ValueError("{} table does not exist".format(table_name)) - - -def table_should_have_column(table_name, column, db_engine): - """Ensures that the table has the given column - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - db_engine (sqlalchemy.engine) - - Raises: ValueError if the table does not contain the column - """ - table_should_exist(table_name, db_engine) - if not table_has_column(table_name, column, db_engine): - raise ValueError("{} table does not have {} column".format(table_name, column)) - - -def table_should_have_data(table_name, db_engine): - """Ensures that the table has at least one row 
- - Args: - table_name (string) A table name (with schema) - db_engine (sqlalchemy.engine) - - Raises: ValueError if the table does not have at least one row - """ - table_should_exist(table_name, db_engine) - if not table_has_data(table_name, db_engine): - raise ValueError("{} table does not have any data".format(table_name)) - - -def column_should_be_in_types(table_name, column, valid_types, db_engine): - """Ensures that the given column is one of the given types - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - valid_types (list) A list of SQLAlchemy DDL types, like sqlalchemy.types.BOOLEAN - db_engine (sqlalchemy.engine) - - Raises: ValueError if the column is not one of the given types - """ - table_should_have_column(table_name, column, db_engine) - reflected_type = column_type(table_name, column, db_engine) - if reflected_type not in valid_types: - raise ValueError( - "{}.{} should be in types {} but was {}".format( - table_name, column, valid_types, reflected_type - ) - ) - - -def column_should_be_booleanlike(table_name, column, db_engine): - """Ensures that the given column can be casted to a boolean - - Allows BOOLEAN, SMALLINT, and INTEGER, as these are commonly used. - It does not check that the data in a SMALLINT column all conforms to 0/1 - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - db_engine (sqlalchemy.engine) - - Raises: ValueError if the column is not a recognized boolean-compatible type - """ - table_should_have_column(table_name, column, db_engine) - column_should_be_in_types( - table_name, column, [BOOLEAN, SMALLINT, INTEGER], db_engine - ) - - -def column_should_be_timelike(table_name, column, db_engine): - """Ensures that the given column can be used for temporal data - - Many date/time operations are fairly compatible with each other, - so this routine is fairly permissive. 
If you want to be more strict, - call column_should_be_in_types directly - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - db_engine (sqlalchemy.engine) - - Raises: ValueError if the column is not a recognized temporal type - """ - table_should_have_column(table_name, column, db_engine) - column_should_be_in_types( - table_name, column, [DATE, DATETIME, TIMESTAMP, POSTGRES_TIMESTAMP], db_engine - ) - - -def column_should_be_intlike(table_name, column, db_engine): - """Ensures that the given column can act as an integer - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - db_engine (sqlalchemy.engine) - - Raises: ValueError if the column is not a recognized integer type - """ - table_should_have_column(table_name, column, db_engine) - column_should_be_in_types( - table_name, column, [BIGINT, SMALLINT, INTEGER], db_engine - ) - - -def column_should_be_stringlike(table_name, column, db_engine): - """Ensures that the given column can act as an string - - Args: - table_name (string) A table name (with schema) - column (string) The name of a column - db_engine (sqlalchemy.engine) - - Raises: ValueError if the column is not a recognized string type - """ - table_should_have_column(table_name, column, db_engine) - column_should_be_in_types(table_name, column, [VARCHAR, TEXT], db_engine) diff --git a/src/triage/component/audition/distance_from_best.py b/src/triage/component/audition/distance_from_best.py index 856856c6c..34e449c4b 100644 --- a/src/triage/component/audition/distance_from_best.py +++ b/src/triage/component/audition/distance_from_best.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from .utils import str_in_sql +from triage.util.sql import str_in_sql from .metric_directionality import sql_rank_order from .plotting import plot_cats, plot_bounds diff --git a/src/triage/component/audition/model_group_performance.py 
b/src/triage/component/audition/model_group_performance.py index 9f8820c29..f22b8029a 100644 --- a/src/triage/component/audition/model_group_performance.py +++ b/src/triage/component/audition/model_group_performance.py @@ -4,7 +4,7 @@ import numpy as np from .plotting import plot_cats -from .utils import str_in_sql +from triage.util.sql import str_in_sql class ModelGroupPerformancePlotter(object): diff --git a/src/triage/component/audition/selection_rule_grid.py b/src/triage/component/audition/selection_rule_grid.py index a3681f27e..f7dfb2ca3 100644 --- a/src/triage/component/audition/selection_rule_grid.py +++ b/src/triage/component/audition/selection_rule_grid.py @@ -2,7 +2,7 @@ import logging from .selection_rules import BoundSelectionRule -from .utils import make_list +from triage.util.conf import make_list def _expand_param_sets(rule_instances, values): diff --git a/src/triage/component/audition/utils.py b/src/triage/component/audition/utils.py index 0d1c6411d..e69de29bb 100644 --- a/src/triage/component/audition/utils.py +++ b/src/triage/component/audition/utils.py @@ -1,6 +0,0 @@ -def make_list(a): - return [a] if not isinstance(a, list) else a - - -def str_in_sql(values): - return ",".join(map(lambda x: "'{}'".format(x), values)) diff --git a/src/triage/component/catwalk/subsetters.py b/src/triage/component/catwalk/subsetters.py index db8d290c5..f95f368ac 100644 --- a/src/triage/component/catwalk/subsetters.py +++ b/src/triage/component/catwalk/subsetters.py @@ -1,12 +1,19 @@ import logging -import pandas from sqlalchemy.orm import sessionmaker from triage.component.architect.entity_date_table_generators import EntityDateTableGenerator -from triage.component.catwalk.utils import (filename_friendly_hash, get_subset_table_name) +from triage.util.hash import filename_friendly_hash from triage.component.results_schema import Subset + +def get_subset_table_name(subset_config): + return "subset_{}_{}".format( + subset_config.get("name", "default"), + 
filename_friendly_hash(subset_config), + ) + + class Subsetter(object): def __init__( self, diff --git a/src/triage/component/catwalk/utils.py b/src/triage/component/catwalk/utils.py deleted file mode 100644 index 465a9e50f..000000000 --- a/src/triage/component/catwalk/utils.py +++ /dev/null @@ -1,261 +0,0 @@ -import csv -import datetime -import hashlib -import numpy -import json -import logging -import random -from itertools import chain -from functools import partial - -import postgres_copy -import sqlalchemy -from retrying import retry -from sqlalchemy.orm import sessionmaker -from ohio import PipeTextIO - -from triage.component.results_schema import ( - Experiment, - Matrix, - Model, - ExperimentMatrix, - ExperimentModel, -) - - -def filename_friendly_hash(inputs): - def dt_handler(x): - if isinstance(x, datetime.datetime) or isinstance(x, datetime.date): - return x.isoformat() - raise TypeError("Unknown type") - - return hashlib.md5( - json.dumps(inputs, default=dt_handler, sort_keys=True).encode("utf-8") - ).hexdigest() - - -def get_subset_table_name(subset_config): - return "subset_{}_{}".format( - subset_config.get("name", "default"), - filename_friendly_hash(subset_config), - ) - - -def retry_if_db_error(exception): - return isinstance(exception, sqlalchemy.exc.OperationalError) - - -DEFAULT_RETRY_KWARGS = { - "retry_on_exception": retry_if_db_error, - "wait_exponential_multiplier": 1000, # wait 2^x*1000ms between each retry - "stop_max_attempt_number": 14, - # with this configuration, last wait will be ~2 hours - # for a total of ~4.5 hours waiting -} - - -db_retry = retry(**DEFAULT_RETRY_KWARGS) - - -@db_retry -def save_experiment_and_get_hash(config, db_engine): - experiment_hash = filename_friendly_hash(config) - session = sessionmaker(bind=db_engine)() - session.merge(Experiment(experiment_hash=experiment_hash, config=config)) - session.commit() - session.close() - return experiment_hash - - -@db_retry -def 
associate_matrices_with_experiment(experiment_hash, matrix_uuids, db_engine): - session = sessionmaker(bind=db_engine)() - for matrix_uuid in matrix_uuids: - session.merge(ExperimentMatrix(experiment_hash=experiment_hash, matrix_uuid=matrix_uuid)) - session.commit() - session.close() - logging.info("Associated matrices with experiment in database") - - -@db_retry -def associate_models_with_experiment(experiment_hash, model_hashes, db_engine): - session = sessionmaker(bind=db_engine)() - for model_hash in model_hashes: - session.merge(ExperimentModel(experiment_hash=experiment_hash, model_hash=model_hash)) - session.commit() - session.close() - logging.info("Associated models with experiment in database") - - -@db_retry -def missing_matrix_uuids(experiment_hash, db_engine): - """Compare the contents of the experiment_matrices table with that of the - matrices table to produce a list of matrix_uuids that the experiment wants - but are not available. - """ - query = f""" - select experiment_matrices.matrix_uuid - from {ExperimentMatrix.__table__.fullname} experiment_matrices - left join {Matrix.__table__.fullname} matrices - on (experiment_matrices.matrix_uuid = matrices.matrix_uuid) - where experiment_hash = %s - and matrices.matrix_uuid is null - """ - return [row[0] for row in db_engine.execute(query, experiment_hash)] - - -@db_retry -def missing_model_hashes(experiment_hash, db_engine): - """Compare the contents of the experiment_models table with that of the - models table to produce a list of model hashes the experiment wants - but are not available. 
- """ - query = f""" - select experiment_models.model_hash - from {ExperimentModel.__table__.fullname} experiment_models - left join {Model.__table__.fullname} models - on (experiment_models.model_hash = models.model_hash) - where experiment_hash = %s - and models.model_hash is null - """ - return [row[0] for row in db_engine.execute(query, experiment_hash)] - - -class Batch: - # modified from - # http://codereview.stackexchange.com/questions/118883/split-up-an-iterable-into-batches - def __init__(self, iterable, limit=None): - self.iterator = iter(iterable) - self.limit = limit - try: - self.current = next(self.iterator) - except StopIteration: - self.on_going = False - else: - self.on_going = True - - def group(self): - yield self.current - # start enumerate at 1 because we already yielded the last saved item - for num, item in enumerate(self.iterator, 1): - self.current = item - if num == self.limit: - break - yield item - else: - self.on_going = False - - def __iter__(self): - while self.on_going: - yield self.group() - - -AVAILABLE_TIEBREAKERS = {'random', 'best', 'worst'} - -def sort_predictions_and_labels(predictions_proba, labels, tiebreaker='random', sort_seed=None, parallel_arrays=()): - """Sort predictions and labels with a configured tiebreaking rule - - Args: - predictions_proba (numpy.array) The predicted scores - labels (numpy.array) The numeric labels (1/0, not True/False) - tiebreaker (string) The tiebreaking method ('best', 'worst', 'random') - sort_seed (signed int) The sort seed. Needed if 'random' tiebreaking is picked. - parallel_arrays (tuple of numpy.array) Any other arrays, understood to be the same size - as the predictions and labels, that should be sorted alongside them. 
- - Returns: - (tuple) (predictions_proba, labels), sorted - """ - if len(labels) == 0: - logging.debug("No labels present, skipping sorting.") - if parallel_arrays: - return (predictions_proba, labels, parallel_arrays) - else: - return (predictions_proba, labels) - mask = None - if tiebreaker == 'random': - if not sort_seed: - raise ValueError("If random tiebreaker is used, a sort seed must be given") - random.seed(sort_seed) - numpy.random.seed(sort_seed) - random_arr = numpy.random.rand(*predictions_proba.shape) - mask = numpy.lexsort((random_arr, predictions_proba)) - elif tiebreaker == 'worst': - mask = numpy.lexsort((-labels, predictions_proba)) - elif tiebreaker == 'best': - mask = numpy.lexsort((labels, predictions_proba)) - else: - raise ValueError("Unknown tiebreaker") - - return_value = [ - numpy.flip(predictions_proba[mask]), - numpy.flip(labels[mask]), - ] - if parallel_arrays: - return_value.append(tuple(numpy.flip(arr[mask]) for arr in parallel_arrays)) - return return_value - -@db_retry -def retrieve_model_id_from_hash(db_engine, model_hash): - """Retrieves a model id from the database that matches the given hash - - Args: - db_engine (sqlalchemy.engine) A database engine - model_hash (str) The model hash to lookup - - Returns: (int) The model id (if found in DB), None (if not) - """ - session = sessionmaker(bind=db_engine)() - try: - saved = session.query(Model).filter_by(model_hash=model_hash).one_or_none() - return saved.model_id if saved else None - finally: - session.close() - - -@db_retry -def retrieve_model_hash_from_id(db_engine, model_id): - """Retrieves the model hash associated with a given model id - - Args: - model_id (int) The id of a given model in the database - - Returns: (str) the stored hash of the model - """ - session = sessionmaker(bind=db_engine)() - try: - return session.query(Model).get(model_id).model_hash - finally: - session.close() - - -def _write_csv(file_like, db_objects, type_of_object): - writer = 
csv.writer(file_like, quoting=csv.QUOTE_MINIMAL, lineterminator='\n') - for db_object in db_objects: - if type(db_object) != type_of_object: - raise TypeError("Cannot copy collection of objects to db as they are not all " - f"of the same type. First object was {type_of_object} " - f"and later encountered a {type(db_object)}") - writer.writerow( - [getattr(db_object, col.name) for col in db_object.__table__.columns] - ) - - -@db_retry -def save_db_objects(db_engine, db_objects): - """Saves a collection of SQLAlchemy model objects to the database using a COPY command - - Args: - db_engine (sqlalchemy.engine) - db_objects (iterable) SQLAlchemy model objects, corresponding to a valid table - """ - db_objects = iter(db_objects) - first_object = next(db_objects) - type_of_object = type(first_object) - - with PipeTextIO(partial( - _write_csv, - db_objects=chain((first_object,), db_objects), - type_of_object=type_of_object - )) as pipe: - postgres_copy.copy_from(pipe, type_of_object, db_engine, format="csv") diff --git a/src/triage/component/collate/collate.py b/src/triage/component/collate/collate.py index 06f916c4b..036b2a908 100644 --- a/src/triage/component/collate/collate.py +++ b/src/triage/component/collate/collate.py @@ -5,6 +5,7 @@ import sqlalchemy.sql.expression as ex import re from descriptors import cachedproperty +from triage.util.conf import make_list from .sql import make_sql_clause, to_sql_name, CreateTableAs, InsertFromSelect from .imputations import ( @@ -28,10 +29,6 @@ } -def make_list(a): - return [a] if not isinstance(a, list) else a - - def make_tuple(a): return (a,) if not isinstance(a, tuple) else a diff --git a/src/triage/component/results_schema/utils.py b/src/triage/component/results_schema/utils.py new file mode 100644 index 000000000..1888dea2b --- /dev/null +++ b/src/triage/component/results_schema/utils.py @@ -0,0 +1,110 @@ +import logging +from sqlalchemy.orm import sessionmaker +from triage.util.db import db_retry +from triage.util.hash
import filename_friendly_hash + +from .schema import ( + Experiment, + Matrix, + Model, + ExperimentMatrix, + ExperimentModel, +) + + +@db_retry +def save_experiment_and_get_hash(config, db_engine): + experiment_hash = filename_friendly_hash(config) + session = sessionmaker(bind=db_engine)() + session.merge(Experiment(experiment_hash=experiment_hash, config=config)) + session.commit() + session.close() + return experiment_hash + + +@db_retry +def associate_matrices_with_experiment(experiment_hash, matrix_uuids, db_engine): + session = sessionmaker(bind=db_engine)() + for matrix_uuid in matrix_uuids: + session.merge(ExperimentMatrix(experiment_hash=experiment_hash, matrix_uuid=matrix_uuid)) + session.commit() + session.close() + logging.info("Associated matrices with experiment in database") + + +@db_retry +def associate_models_with_experiment(experiment_hash, model_hashes, db_engine): + session = sessionmaker(bind=db_engine)() + for model_hash in model_hashes: + session.merge(ExperimentModel(experiment_hash=experiment_hash, model_hash=model_hash)) + session.commit() + session.close() + logging.info("Associated models with experiment in database") + + +@db_retry +def missing_matrix_uuids(experiment_hash, db_engine): + """Compare the contents of the experiment_matrices table with that of the + matrices table to produce a list of matrix_uuids that the experiment wants + but are not available. 
+ """ + query = f""" + select experiment_matrices.matrix_uuid + from {ExperimentMatrix.__table__.fullname} experiment_matrices + left join {Matrix.__table__.fullname} matrices + on (experiment_matrices.matrix_uuid = matrices.matrix_uuid) + where experiment_hash = %s + and matrices.matrix_uuid is null + """ + return [row[0] for row in db_engine.execute(query, experiment_hash)] + + +@db_retry +def missing_model_hashes(experiment_hash, db_engine): + """Compare the contents of the experiment_models table with that of the + models table to produce a list of model hashes the experiment wants + but are not available. + """ + query = f""" + select experiment_models.model_hash + from {ExperimentModel.__table__.fullname} experiment_models + left join {Model.__table__.fullname} models + on (experiment_models.model_hash = models.model_hash) + where experiment_hash = %s + and models.model_hash is null + """ + return [row[0] for row in db_engine.execute(query, experiment_hash)] + + +@db_retry +def retrieve_model_id_from_hash(db_engine, model_hash): + """Retrieves a model id from the database that matches the given hash + + Args: + db_engine (sqlalchemy.engine) A database engine + model_hash (str) The model hash to lookup + + Returns: (int) The model id (if found in DB), None (if not) + """ + session = sessionmaker(bind=db_engine)() + try: + saved = session.query(Model).filter_by(model_hash=model_hash).one_or_none() + return saved.model_id if saved else None + finally: + session.close() + + +@db_retry +def retrieve_model_hash_from_id(db_engine, model_id): + """Retrieves the model hash associated with a given model id + + Args: + model_id (int) The id of a given model in the database + + Returns: (str) the stored hash of the model + """ + session = sessionmaker(bind=db_engine)() + try: + return session.query(Model).get(model_id).model_hash + finally: + session.close() diff --git a/src/triage/experiments/validate.py b/src/triage/experiments/validate.py index a9932dcc7..764057244 
100644 --- a/src/triage/experiments/validate.py +++ b/src/triage/experiments/validate.py @@ -1,7 +1,6 @@ import importlib import logging from itertools import permutations -from datetime import datetime from textwrap import dedent from sklearn.model_selection import ParameterGrid @@ -10,7 +9,7 @@ from triage.component import catwalk from triage.component.timechop import Timechop -from triage.util.conf import convert_str_to_relativedelta +from triage.util.conf import convert_str_to_relativedelta, dt_from_str from triage.validation_primitives import string_is_tablesafe @@ -34,9 +33,6 @@ def run(self, *args, **kwargs): class TemporalValidator(Validator): def _run(self, temporal_config): - def dt_from_str(dt_str): - return datetime.strptime(dt_str, "%Y-%m-%d") - splits = [] try: chopper = Timechop( diff --git a/src/triage/util/conf.py b/src/triage/util/conf.py index be4552776..e48ba6fce 100644 --- a/src/triage/util/conf.py +++ b/src/triage/util/conf.py @@ -4,6 +4,10 @@ from datetime import datetime +def make_list(a): + return [a] if not isinstance(a, list) else a + + def dt_from_str(dt_str): if isinstance(dt_str, datetime): return dt_str diff --git a/src/triage/util/db.py b/src/triage/util/db.py index 585622868..c80dc94a5 100644 --- a/src/triage/util/db.py +++ b/src/triage/util/db.py @@ -1,11 +1,18 @@ # coding: utf-8 +import csv import sqlalchemy import wrapt from contextlib import contextmanager from sqlalchemy.orm import Session from sqlalchemy.engine.url import make_url +from retrying import retry +from functools import partial +from itertools import chain +import postgres_copy + +from ohio import PipeTextIO class SerializableDbEngine(wrapt.ObjectProxy): """A sqlalchemy engine that can be serialized across process boundaries. 
@@ -58,3 +65,51 @@ def get_for_update(db_engine, orm_class, primary_key): obj = session.query(orm_class).get(primary_key) yield obj session.merge(obj) + + +def retry_if_db_error(exception): + return isinstance(exception, sqlalchemy.exc.OperationalError) + + +DEFAULT_RETRY_KWARGS = { + "retry_on_exception": retry_if_db_error, + "wait_exponential_multiplier": 1000, # wait 2^x*1000ms between each retry + "stop_max_attempt_number": 14, + # with this configuration, last wait will be ~2 hours + # for a total of ~4.5 hours waiting +} + + +db_retry = retry(**DEFAULT_RETRY_KWARGS) + + +def _write_csv(file_like, db_objects, type_of_object): + writer = csv.writer(file_like, quoting=csv.QUOTE_MINIMAL, lineterminator='\n') + for db_object in db_objects: + if type(db_object) != type_of_object: + raise TypeError("Cannot copy collection of objects to db as they are not all " + f"of the same type. First object was {type_of_object} " + f"and later encountered a {type(db_object)}") + writer.writerow( + [getattr(db_object, col.name) for col in db_object.__table__.columns] + ) + + +@db_retry +def save_db_objects(db_engine, db_objects): + """Saves a collection of SQLAlchemy model objects to the database using a COPY command + + Args: + db_engine (sqlalchemy.engine) + db_objects (iterable) SQLAlchemy model objects, corresponding to a valid table + """ + db_objects = iter(db_objects) + first_object = next(db_objects) + type_of_object = type(first_object) + + with PipeTextIO(partial( + _write_csv, + db_objects=chain((first_object,), db_objects), + type_of_object=type_of_object + )) as pipe: + postgres_copy.copy_from(pipe, type_of_object, db_engine, format="csv") diff --git a/src/triage/util/hash.py b/src/triage/util/hash.py new file mode 100644 index 000000000..2311df9e6 --- /dev/null +++ b/src/triage/util/hash.py @@ -0,0 +1,14 @@ +import datetime +import hashlib +import json + + +def filename_friendly_hash(inputs): + def dt_handler(x): + if isinstance(x, datetime.datetime) or 
isinstance(x, datetime.date): + return x.isoformat() + raise TypeError("Unknown type") + + return hashlib.md5( + json.dumps(inputs, default=dt_handler, sort_keys=True).encode("utf-8") + ).hexdigest() diff --git a/src/triage/util/iteration.py b/src/triage/util/iteration.py new file mode 100644 index 000000000..38fccc855 --- /dev/null +++ b/src/triage/util/iteration.py @@ -0,0 +1,27 @@ +class Batch: + # modified from + # http://codereview.stackexchange.com/questions/118883/split-up-an-iterable-into-batches + def __init__(self, iterable, limit=None): + self.iterator = iter(iterable) + self.limit = limit + try: + self.current = next(self.iterator) + except StopIteration: + self.on_going = False + else: + self.on_going = True + + def group(self): + yield self.current + # start enumerate at 1 because we already yielded the last saved item + for num, item in enumerate(self.iterator, 1): + self.current = item + if num == self.limit: + break + yield item + else: + self.on_going = False + + def __iter__(self): + while self.on_going: + yield self.group() diff --git a/src/triage/util/sorting.py b/src/triage/util/sorting.py new file mode 100644 index 000000000..95eecd078 --- /dev/null +++ b/src/triage/util/sorting.py @@ -0,0 +1,49 @@ +import logging +import numpy +import random + + +AVAILABLE_TIEBREAKERS = {'random', 'best', 'worst'} + +def sort_predictions_and_labels(predictions_proba, labels, tiebreaker='random', sort_seed=None, parallel_arrays=()): + """Sort predictions and labels with a configured tiebreaking rule + + Args: + predictions_proba (numpy.array) The predicted scores + labels (numpy.array) The numeric labels (1/0, not True/False) + tiebreaker (string) The tiebreaking method ('best', 'worst', 'random') + sort_seed (signed int) The sort seed. Needed if 'random' tiebreaking is picked. + parallel_arrays (tuple of numpy.array) Any other arrays, understood to be the same size + as the predictions and labels, that should be sorted alongside them. 
+ + Returns: + (tuple) (predictions_proba, labels), sorted + """ + if len(labels) == 0: + logging.debug("No labels present, skipping sorting.") + if parallel_arrays: + return (predictions_proba, labels, parallel_arrays) + else: + return (predictions_proba, labels) + mask = None + if tiebreaker == 'random': + if not sort_seed: + raise ValueError("If random tiebreaker is used, a sort seed must be given") + random.seed(sort_seed) + numpy.random.seed(sort_seed) + random_arr = numpy.random.rand(*predictions_proba.shape) + mask = numpy.lexsort((random_arr, predictions_proba)) + elif tiebreaker == 'worst': + mask = numpy.lexsort((-labels, predictions_proba)) + elif tiebreaker == 'best': + mask = numpy.lexsort((labels, predictions_proba)) + else: + raise ValueError("Unknown tiebreaker") + + return_value = [ + numpy.flip(predictions_proba[mask]), + numpy.flip(labels[mask]), + ] + if parallel_arrays: + return_value.append(tuple(numpy.flip(arr[mask]) for arr in parallel_arrays)) + return return_value diff --git a/src/triage/util/sql.py b/src/triage/util/sql.py new file mode 100644 index 000000000..cc5de18a7 --- /dev/null +++ b/src/triage/util/sql.py @@ -0,0 +1,3 @@ +def str_in_sql(values): + """Create SQL suitable for the content of an IN clause from a list""" + return ",".join(map(lambda x: "'{}'".format(x), values)) From 972964164fb746d8ec4c35da220e208ca5b4406d Mon Sep 17 00:00:00 2001 From: Tristan Crockett Date: Thu, 9 May 2019 10:49:21 -0500 Subject: [PATCH 5/5] Reorganization --- src/tests/architect_tests/test_builders.py | 2 +- src/tests/catwalk_tests/test_evaluation.py | 3 +- src/tests/catwalk_tests/test_integration.py | 2 +- src/tests/catwalk_tests/test_ranking.py | 48 ++++++ src/tests/catwalk_tests/test_utils.py | 155 ------------------ src/tests/results_tests/test_utils.py | 66 ++++++++ src/tests/test_utils_hash.py | 44 +++++ src/tests/utils.py | 2 +- src/triage/component/architect/planner.py | 2 +- src/triage/component/catwalk/__init__.py | 2 +- 
src/triage/component/catwalk/evaluation.py | 11 +- .../catwalk/individual_importance/__init__.py | 2 +- .../component/catwalk/model_trainers.py | 11 +- src/triage/component/catwalk/predictors.py | 5 +- .../catwalk/ranking.py} | 0 src/triage/component/results_schema/utils.py | 2 +- src/triage/experiments/base.py | 4 +- src/triage/experiments/multicore.py | 2 +- src/triage/experiments/rq.py | 2 +- src/triage/experiments/validate.py | 4 +- 20 files changed, 184 insertions(+), 185 deletions(-) create mode 100644 src/tests/catwalk_tests/test_ranking.py create mode 100644 src/tests/results_tests/test_utils.py create mode 100644 src/tests/test_utils_hash.py rename src/triage/{util/sorting.py => component/catwalk/ranking.py} (100%) diff --git a/src/tests/architect_tests/test_builders.py b/src/tests/architect_tests/test_builders.py index 100af0dc6..3600abd04 100644 --- a/src/tests/architect_tests/test_builders.py +++ b/src/tests/architect_tests/test_builders.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine from contextlib import contextmanager -from triage.component.catwalk.utils import filename_friendly_hash +from triage.util.hash import filename_friendly_hash from triage.component.architect.feature_group_creator import FeatureGroup from triage.component.architect.builders import MatrixBuilder from triage.component.catwalk.db import ensure_db diff --git a/src/tests/catwalk_tests/test_evaluation.py b/src/tests/catwalk_tests/test_evaluation.py index 6003db906..97e020987 100644 --- a/src/tests/catwalk_tests/test_evaluation.py +++ b/src/tests/catwalk_tests/test_evaluation.py @@ -6,6 +6,8 @@ subset_labels_and_predictions, ) from triage.component.catwalk.metrics import Metric +from triage.component.catwalk.subsetters import get_subset_table_name +from triage.util.hash import filename_friendly_hash import testing.postgresql import datetime import re @@ -15,7 +17,6 @@ from numpy.testing import assert_almost_equal, assert_array_equal import pandas from 
sqlalchemy.sql.expression import text -from triage.component.catwalk.utils import filename_friendly_hash, get_subset_table_name from tests.utils import fake_labels, fake_trained_model, MockMatrixStore from tests.results_tests.factories import ( ModelFactory, diff --git a/src/tests/catwalk_tests/test_integration.py b/src/tests/catwalk_tests/test_integration.py index 687f5a5e1..22fda7dc0 100644 --- a/src/tests/catwalk_tests/test_integration.py +++ b/src/tests/catwalk_tests/test_integration.py @@ -1,5 +1,5 @@ from triage.component.catwalk import ModelTrainTester, Predictor, ModelTrainer, ModelEvaluator, IndividualImportanceCalculator -from triage.component.catwalk.utils import save_experiment_and_get_hash +from triage.component.results_schema.utils import save_experiment_and_get_hash from triage.component.catwalk.model_trainers import flatten_grid_config from triage.component.catwalk.storage import ( ModelStorageEngine, diff --git a/src/tests/catwalk_tests/test_ranking.py b/src/tests/catwalk_tests/test_ranking.py new file mode 100644 index 000000000..be4eef991 --- /dev/null +++ b/src/tests/catwalk_tests/test_ranking.py @@ -0,0 +1,48 @@ +import numpy +from numpy.testing import assert_array_equal +import pytest + +from triage.component.catwalk.ranking import sort_predictions_and_labels + + +def test_sort_predictions_and_labels(): + predictions = numpy.array([0.5, 0.4, 0.6, 0.5]) + + labels = numpy.array([0, 0, 1, 1]) + + # best sort + sorted_predictions, sorted_labels = sort_predictions_and_labels( + predictions, labels, tiebreaker='best' + ) + assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) + assert_array_equal(sorted_labels, numpy.array([1, 1, 0, 0])) + + # worst wort + sorted_predictions, sorted_labels = sort_predictions_and_labels( + predictions, labels, tiebreaker='worst' + ) + assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) + assert_array_equal(sorted_labels, numpy.array([1, 0, 1, 0])) + + # random tiebreaker 
needs a seed + with pytest.raises(ValueError): + sort_predictions_and_labels(predictions, labels, tiebreaker='random') + + # random tiebreaker respects the seed + sorted_predictions, sorted_labels = sort_predictions_and_labels( + predictions, + labels, + tiebreaker='random', + sort_seed=1234 + ) + assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) + assert_array_equal(sorted_labels, numpy.array([1, 1, 0, 0])) + + sorted_predictions, sorted_labels = sort_predictions_and_labels( + predictions, + labels, + tiebreaker='random', + sort_seed=24376234 + ) + assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) + assert_array_equal(sorted_labels, numpy.array([1, 0, 1, 0])) diff --git a/src/tests/catwalk_tests/test_utils.py b/src/tests/catwalk_tests/test_utils.py index 6f0f03136..e69de29bb 100644 --- a/src/tests/catwalk_tests/test_utils.py +++ b/src/tests/catwalk_tests/test_utils.py @@ -1,155 +0,0 @@ -from triage.component.catwalk.utils import ( - filename_friendly_hash, - save_experiment_and_get_hash, - associate_models_with_experiment, - associate_matrices_with_experiment, - missing_model_hashes, - missing_matrix_uuids, - sort_predictions_and_labels, -) -from triage.component.results_schema.schema import Matrix, Model -from triage.component.catwalk.db import ensure_db -from sqlalchemy import create_engine -import testing.postgresql -import datetime -import re -import numpy -from numpy.testing import assert_array_equal -import pytest - - -def test_filename_friendly_hash(): - data = { - "stuff": "stuff", - "other_stuff": "more_stuff", - "a_datetime": datetime.datetime(2015, 1, 1), - "a_date": datetime.date(2016, 1, 1), - "a_number": 5.0, - } - output = filename_friendly_hash(data) - assert isinstance(output, str) - assert re.match("^[\w]+$", output) is not None - - # make sure ordering keys differently doesn't change the hash - new_output = filename_friendly_hash( - { - "other_stuff": "more_stuff", - "stuff": "stuff", - 
"a_datetime": datetime.datetime(2015, 1, 1), - "a_date": datetime.date(2016, 1, 1), - "a_number": 5.0, - } - ) - assert new_output == output - - # make sure new data hashes to something different - new_output = filename_friendly_hash({"stuff": "stuff", "a_number": 5.0}) - assert new_output != output - - -def test_filename_friendly_hash_stability(): - nested_data = {"one": "two", "three": {"four": "five", "six": "seven"}} - output = filename_friendly_hash(nested_data) - # 1. we want to make sure this is stable across different runs - # so hardcode an expected value - assert output == "9a844a7ebbfd821010b1c2c13f7391e6" - other_nested_data = {"one": "two", "three": {"six": "seven", "four": "five"}} - new_output = filename_friendly_hash(other_nested_data) - assert output == new_output - - -def test_save_experiment_and_get_hash(): - # no reason to make assertions on the config itself, use a basic dict - experiment_config = {"one": "two"} - with testing.postgresql.Postgresql() as postgresql: - engine = create_engine(postgresql.url()) - ensure_db(engine) - exp_hash = save_experiment_and_get_hash(experiment_config, engine) - assert isinstance(exp_hash, str) - new_hash = save_experiment_and_get_hash(experiment_config, engine) - assert new_hash == exp_hash - - -def test_missing_model_hashes(): - with testing.postgresql.Postgresql() as postgresql: - db_engine = create_engine(postgresql.url()) - ensure_db(db_engine) - - experiment_hash = save_experiment_and_get_hash({}, db_engine) - model_hashes = ['abcd', 'bcde', 'cdef'] - - # if we associate model hashes with an experiment but don't actually train the models - # they should show up as missing - associate_models_with_experiment(experiment_hash, model_hashes, db_engine) - assert missing_model_hashes(experiment_hash, db_engine) == model_hashes - - # if we insert a model row they should no longer be considered missing - db_engine.execute( - f"insert into {Model.__table__.fullname} (model_hash) values (%s)", - model_hashes[0] - ) 
- assert missing_model_hashes(experiment_hash, db_engine) == model_hashes[1:] - - -def test_missing_matrix_uuids(): - with testing.postgresql.Postgresql() as postgresql: - db_engine = create_engine(postgresql.url()) - ensure_db(db_engine) - - experiment_hash = save_experiment_and_get_hash({}, db_engine) - matrix_uuids = ['abcd', 'bcde', 'cdef'] - - # if we associate matrix uuids with an experiment but don't actually build the matrices - # they should show up as missing - associate_matrices_with_experiment(experiment_hash, matrix_uuids, db_engine) - assert missing_matrix_uuids(experiment_hash, db_engine) == matrix_uuids - - # if we insert a matrix row they should no longer be considered missing - db_engine.execute( - f"insert into {Matrix.__table__.fullname} (matrix_uuid) values (%s)", - matrix_uuids[0] - ) - assert missing_matrix_uuids(experiment_hash, db_engine) == matrix_uuids[1:] - - -def test_sort_predictions_and_labels(): - predictions = numpy.array([0.5, 0.4, 0.6, 0.5]) - - labels = numpy.array([0, 0, 1, 1]) - - # best sort - sorted_predictions, sorted_labels = sort_predictions_and_labels( - predictions, labels, tiebreaker='best' - ) - assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) - assert_array_equal(sorted_labels, numpy.array([1, 1, 0, 0])) - - # worst wort - sorted_predictions, sorted_labels = sort_predictions_and_labels( - predictions, labels, tiebreaker='worst' - ) - assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) - assert_array_equal(sorted_labels, numpy.array([1, 0, 1, 0])) - - # random tiebreaker needs a seed - with pytest.raises(ValueError): - sort_predictions_and_labels(predictions, labels, tiebreaker='random') - - # random tiebreaker respects the seed - sorted_predictions, sorted_labels = sort_predictions_and_labels( - predictions, - labels, - tiebreaker='random', - sort_seed=1234 - ) - assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) - assert_array_equal(sorted_labels, 
numpy.array([1, 1, 0, 0])) - - sorted_predictions, sorted_labels = sort_predictions_and_labels( - predictions, - labels, - tiebreaker='random', - sort_seed=24376234 - ) - assert_array_equal(sorted_predictions, numpy.array([0.6, 0.5, 0.5, 0.4])) - assert_array_equal(sorted_labels, numpy.array([1, 0, 1, 0])) diff --git a/src/tests/results_tests/test_utils.py b/src/tests/results_tests/test_utils.py new file mode 100644 index 000000000..777fd3883 --- /dev/null +++ b/src/tests/results_tests/test_utils.py @@ -0,0 +1,66 @@ +from triage.component.results_schema.utils import ( + save_experiment_and_get_hash, + associate_models_with_experiment, + associate_matrices_with_experiment, + missing_model_hashes, + missing_matrix_uuids, +) + +from triage.component.results_schema.schema import Matrix, Model +from triage.component.catwalk.db import ensure_db +from sqlalchemy import create_engine +import testing.postgresql + + +def test_save_experiment_and_get_hash(): + # no reason to make assertions on the config itself, use a basic dict + experiment_config = {"one": "two"} + with testing.postgresql.Postgresql() as postgresql: + engine = create_engine(postgresql.url()) + ensure_db(engine) + exp_hash = save_experiment_and_get_hash(experiment_config, engine) + assert isinstance(exp_hash, str) + new_hash = save_experiment_and_get_hash(experiment_config, engine) + assert new_hash == exp_hash + + +def test_missing_model_hashes(): + with testing.postgresql.Postgresql() as postgresql: + db_engine = create_engine(postgresql.url()) + ensure_db(db_engine) + + experiment_hash = save_experiment_and_get_hash({}, db_engine) + model_hashes = ['abcd', 'bcde', 'cdef'] + + # if we associate model hashes with an experiment but don't actually train the models + # they should show up as missing + associate_models_with_experiment(experiment_hash, model_hashes, db_engine) + assert missing_model_hashes(experiment_hash, db_engine) == model_hashes + + # if we insert a model row they should no longer be 
considered missing + db_engine.execute( + f"insert into {Model.__table__.fullname} (model_hash) values (%s)", + model_hashes[0] + ) + assert missing_model_hashes(experiment_hash, db_engine) == model_hashes[1:] + + +def test_missing_matrix_uuids(): + with testing.postgresql.Postgresql() as postgresql: + db_engine = create_engine(postgresql.url()) + ensure_db(db_engine) + + experiment_hash = save_experiment_and_get_hash({}, db_engine) + matrix_uuids = ['abcd', 'bcde', 'cdef'] + + # if we associate matrix uuids with an experiment but don't actually build the matrices + # they should show up as missing + associate_matrices_with_experiment(experiment_hash, matrix_uuids, db_engine) + assert missing_matrix_uuids(experiment_hash, db_engine) == matrix_uuids + + # if we insert a matrix row they should no longer be considered missing + db_engine.execute( + f"insert into {Matrix.__table__.fullname} (matrix_uuid) values (%s)", + matrix_uuids[0] + ) + assert missing_matrix_uuids(experiment_hash, db_engine) == matrix_uuids[1:] diff --git a/src/tests/test_utils_hash.py b/src/tests/test_utils_hash.py new file mode 100644 index 000000000..58fcee439 --- /dev/null +++ b/src/tests/test_utils_hash.py @@ -0,0 +1,44 @@ +import datetime +import re + +from triage.util.hash import filename_friendly_hash + + +def test_filename_friendly_hash(): + data = { + "stuff": "stuff", + "other_stuff": "more_stuff", + "a_datetime": datetime.datetime(2015, 1, 1), + "a_date": datetime.date(2016, 1, 1), + "a_number": 5.0, + } + output = filename_friendly_hash(data) + assert isinstance(output, str) + assert re.match("^[\w]+$", output) is not None + + # make sure ordering keys differently doesn't change the hash + new_output = filename_friendly_hash( + { + "other_stuff": "more_stuff", + "stuff": "stuff", + "a_datetime": datetime.datetime(2015, 1, 1), + "a_date": datetime.date(2016, 1, 1), + "a_number": 5.0, + } + ) + assert new_output == output + + # make sure new data hashes to something different + 
new_output = filename_friendly_hash({"stuff": "stuff", "a_number": 5.0}) + assert new_output != output + + +def test_filename_friendly_hash_stability(): + nested_data = {"one": "two", "three": {"four": "five", "six": "seven"}} + output = filename_friendly_hash(nested_data) + # 1. we want to make sure this is stable across different runs + # so hardcode an expected value + assert output == "9a844a7ebbfd821010b1c2c13f7391e6" + other_nested_data = {"one": "two", "three": {"six": "seven", "four": "five"}} + new_output = filename_friendly_hash(other_nested_data) + assert output == new_output diff --git a/src/tests/utils.py b/src/tests/utils.py index 07372034d..cb84997a0 100644 --- a/src/tests/utils.py +++ b/src/tests/utils.py @@ -9,7 +9,7 @@ from sqlalchemy import create_engine import testing.postgresql from triage.component.catwalk.db import ensure_db -from triage.component.catwalk.utils import filename_friendly_hash +from triage.util.hash import filename_friendly_hash from triage.component.catwalk.storage import MatrixStore, ProjectStorage from triage.component.results_schema import Model, Matrix from triage.experiments import CONFIG_VERSION diff --git a/src/triage/component/architect/planner.py b/src/triage/component/architect/planner.py index 3a2220895..ed02c2339 100644 --- a/src/triage/component/architect/planner.py +++ b/src/triage/component/architect/planner.py @@ -2,7 +2,7 @@ import itertools import logging -from triage.component.catwalk.utils import filename_friendly_hash +from triage.util.hash import filename_friendly_hash from . 
import utils, entity_date_table_generators diff --git a/src/triage/component/catwalk/__init__.py b/src/triage/component/catwalk/__init__.py index 0dc709bda..a811eda04 100644 --- a/src/triage/component/catwalk/__init__.py +++ b/src/triage/component/catwalk/__init__.py @@ -5,7 +5,7 @@ from .individual_importance import IndividualImportanceCalculator from .model_grouping import ModelGrouper from .subsetters import Subsetter -from .utils import filename_friendly_hash +from triage.util.hash import filename_friendly_hash import logging from collections import namedtuple diff --git a/src/triage/component/catwalk/evaluation.py b/src/triage/component/catwalk/evaluation.py index 9452669ac..2a0b03f9c 100644 --- a/src/triage/component/catwalk/evaluation.py +++ b/src/triage/component/catwalk/evaluation.py @@ -12,13 +12,10 @@ from sqlalchemy.orm import sessionmaker from . import metrics -from .utils import ( - db_retry, - sort_predictions_and_labels, - get_subset_table_name, - filename_friendly_hash -) -from triage.util.db import scoped_session +from triage.component.catwalk.ranking import sort_predictions_and_labels +from triage.component.catwalk.subsetters import get_subset_table_name +from triage.util.db import scoped_session, db_retry +from triage.util.hash import filename_friendly_hash from triage.util.random import generate_python_random_seed from triage.component.catwalk.storage import MatrixStore diff --git a/src/triage/component/catwalk/individual_importance/__init__.py b/src/triage/component/catwalk/individual_importance/__init__.py index 7d23ae08a..af881170e 100644 --- a/src/triage/component/catwalk/individual_importance/__init__.py +++ b/src/triage/component/catwalk/individual_importance/__init__.py @@ -1,6 +1,6 @@ import logging -from triage.component.catwalk.utils import save_db_objects +from triage.util.db import save_db_objects from triage.component.results_schema import IndividualImportance from .uniform import uniform_distribution diff --git 
a/src/triage/component/catwalk/model_trainers.py b/src/triage/component/catwalk/model_trainers.py index b4e4061e2..8e544a7f2 100644 --- a/src/triage/component/catwalk/model_trainers.py +++ b/src/triage/component/catwalk/model_trainers.py @@ -11,19 +11,16 @@ from sklearn.model_selection import ParameterGrid from sqlalchemy.orm import sessionmaker +from triage.util.db import db_retry, save_db_objects +from triage.util.hash import filename_friendly_hash from triage.util.random import generate_python_random_seed -from triage.component.results_schema import Model, FeatureImportance from triage.component.catwalk.exceptions import BaselineFeatureNotInMatrix +from triage.component.results_schema import Model, FeatureImportance +from triage.component.results_schema.utils import retrieve_model_id_from_hash from triage.tracking import built_model, skipped_model, errored_model from .model_grouping import ModelGrouper from .feature_importances import get_feature_importances -from .utils import ( - filename_friendly_hash, - retrieve_model_id_from_hash, - db_retry, - save_db_objects, -) NO_FEATURE_IMPORTANCE = ( "Algorithm does not support a standard way" + " to calculate feature importance." 
diff --git a/src/triage/component/catwalk/predictors.py b/src/triage/component/catwalk/predictors.py index 696c84ce1..05763f687 100644 --- a/src/triage/component/catwalk/predictors.py +++ b/src/triage/component/catwalk/predictors.py @@ -5,9 +5,10 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy import or_ -from .utils import db_retry, retrieve_model_hash_from_id, save_db_objects, sort_predictions_and_labels, AVAILABLE_TIEBREAKERS +from triage.component.catwalk.ranking import sort_predictions_and_labels, AVAILABLE_TIEBREAKERS from triage.component.results_schema import Model -from triage.util.db import scoped_session +from triage.component.results_schema.utils import retrieve_model_hash_from_id +from triage.util.db import scoped_session, db_retry, save_db_objects from triage.util.random import generate_python_random_seed import ohio.ext.pandas import pandas diff --git a/src/triage/util/sorting.py b/src/triage/component/catwalk/ranking.py similarity index 100% rename from src/triage/util/sorting.py rename to src/triage/component/catwalk/ranking.py diff --git a/src/triage/component/results_schema/utils.py b/src/triage/component/results_schema/utils.py index 1888dea2b..e193c891d 100644 --- a/src/triage/component/results_schema/utils.py +++ b/src/triage/component/results_schema/utils.py @@ -1,7 +1,7 @@ import logging from sqlalchemy.orm import sessionmaker from triage.util.db import db_retry -from triage.util.has import filename_friendly_hash +from triage.util.hash import filename_friendly_hash from .schema import ( Experiment, diff --git a/src/triage/experiments/base.py b/src/triage/experiments/base.py index 97abceca5..ae0d14daa 100644 --- a/src/triage/experiments/base.py +++ b/src/triage/experiments/base.py @@ -38,13 +38,12 @@ ModelTrainTester, Subsetter ) -from triage.component.catwalk.utils import ( +from triage.component.results_schema.utils import ( save_experiment_and_get_hash, associate_models_with_experiment, associate_matrices_with_experiment, 
missing_matrix_uuids, missing_model_hashes, - filename_friendly_hash, ) from triage.component.catwalk.storage import ( CSVMatrixStore, @@ -65,6 +64,7 @@ from triage.database_reflection import table_has_data from triage.util.conf import dt_from_str from triage.util.db import get_for_update +from triage.util.hash import filename_friendly_hash from triage.util.introspection import bind_kwargs, classpath diff --git a/src/triage/experiments/multicore.py b/src/triage/experiments/multicore.py index 0371aa326..e8aa7e6cb 100644 --- a/src/triage/experiments/multicore.py +++ b/src/triage/experiments/multicore.py @@ -4,7 +4,7 @@ from pebble import ProcessPool from multiprocessing.reduction import ForkingPickler -from triage.component.catwalk.utils import Batch +from triage.util.iteration import Batch from triage.experiments import ExperimentBase diff --git a/src/triage/experiments/rq.py b/src/triage/experiments/rq.py index 8f4e54776..48702fbeb 100644 --- a/src/triage/experiments/rq.py +++ b/src/triage/experiments/rq.py @@ -1,6 +1,6 @@ import logging import time -from triage.component.catwalk.utils import Batch +from triage.util.iteration import Batch from triage.experiments import ExperimentBase try: diff --git a/src/triage/experiments/validate.py b/src/triage/experiments/validate.py index 764057244..9db1b7431 100644 --- a/src/triage/experiments/validate.py +++ b/src/triage/experiments/validate.py @@ -741,10 +741,10 @@ class PredictionConfigValidator(Validator): def _run(self, prediction_config): rank_tiebreaker = prediction_config.get("rank_tiebreaker", None) # the tiebreaker is optional, so only try and validate if it's there - if rank_tiebreaker and rank_tiebreaker not in catwalk.utils.AVAILABLE_TIEBREAKERS: + if rank_tiebreaker and rank_tiebreaker not in catwalk.ranking.AVAILABLE_TIEBREAKERS: raise ValueError( "Section: prediction - " - f"given tiebreaker must be in {catwalk.utils.AVAILABLE_TIEBREAKERS}" + f"given tiebreaker must be in 
{catwalk.ranking.AVAILABLE_TIEBREAKERS}" )