diff --git a/cadetrdm/batch_runner.py b/cadetrdm/batch_runner.py index 97fe258e2cf346fdb27d13aa1ba188648cc398b1..434050c2e2cd58b7ef95c255a1c55292af6ed193 100644 --- a/cadetrdm/batch_runner.py +++ b/cadetrdm/batch_runner.py @@ -87,7 +87,7 @@ class Case: output_log = self.study.output_log for log_entry in output_log: if (self.study.current_commit_hash == log_entry.project_repo_commit_hash - and str(hash(self.options)) == log_entry.options_hash): + and self.options.hash == log_entry.options_hash): return True return False diff --git a/cadetrdm/configuration_options.py b/cadetrdm/configuration_options.py index 196bedcb0e391e6c836236012592caf5250a9997..953d4f7e30bde4b64a974e4bb590dc9b14ce7e31 100644 --- a/cadetrdm/configuration_options.py +++ b/cadetrdm/configuration_options.py @@ -57,7 +57,8 @@ class Options(benedict): def load_json_str(cls, string, **loader_kwargs): return cls.loads(string) - def __hash__(self, excluded_keys=None): + @property + def hash(self): excluded_keys = {"commit_message", "push", "debug"} remaining_keys = set(self.keys()) - excluded_keys remaining_dict = {key: self[key] for key in remaining_keys} @@ -69,7 +70,32 @@ class Options(benedict): indent=None, separators=(',', ':'), ) - return int(hashlib.md5(dump.encode('utf-8')).hexdigest(), 16) + + hash_alphabet = "abcdefghjkmnpqrstvwxyz0123456789" + hash_base = len(hash_alphabet) + + def to_base(number, base): + result = "" + while number: + result += hash_alphabet[number % base] + number //= base + return result[::-1] or "0" + + base_16_hash = hashlib.sha1(dump.encode('utf-8')).hexdigest() + base_10_hash = int(base_16_hash, 16) + base_32_hash = to_base(base_10_hash, hash_base) + + return base_32_hash + + def __eq__(self, other): + if not isinstance(other, Options): + try: + other = Options(other) + except TypeError: + print(f"TypeError when casting {other} to Options()") + return NotImplemented + + return self.hash == other.hash if __name__ == '__main__': diff --git a/cadetrdm/wrapper.py b/cadetrdm/wrapper.py index 5b34c1a0b3849aa385478411f1627ae1b009149f..a803e3ee99f279d8aad2f3473f2098a3c072477e 100644 --- a/cadetrdm/wrapper.py +++ b/cadetrdm/wrapper.py @@ -22,12 +22,12 @@ def tracks_results(func): if key not in options or options[key] is None: raise ValueError(f"Key {key} not found in options. Please supply options.{key}") - if hash(options) != hash(Options.load_json_str(options.dump_json_str())): + if options.hash != Options.load_json_str(options.dump_json_str()).hash: raise ValueError("Options are not serializable. Please only use python natives and numpy ndarrays.") project_repo = ProjectRepo(repo_path) - project_repo.options_hash = hash(options) + project_repo.options_hash = options.hash with project_repo.track_results( options.commit_message, diff --git a/tests/test_configuration_options.py b/tests/test_configuration_options.py index ae0e87a3026e1b612fd8938b206b81ae14da32af..8ce7a7a51a53388061ebebf82f7e31cf4856f74d 100644 --- a/tests/test_configuration_options.py +++ b/tests/test_configuration_options.py @@ -7,19 +7,21 @@ def test_options_hash(): opt = Options() opt["array"] = np.linspace(2, 200) opt["nested_dict"] = {"ba": "foo", "bb": "bar"} - initial_hash = hash(opt) + initial_hash = opt.hash s = opt.dumps() opt_recovered = Options.loads(s) - post_serialization_hash = hash(opt_recovered) + post_serialization_hash = opt_recovered.hash assert initial_hash == post_serialization_hash + assert opt == opt_recovered def test_options_file_io(): opt = Options() opt["array"] = np.linspace(0, 2, 200) opt["nested_dict"] = {"ba": "foo", "bb": "bar"} - initial_hash = hash(opt) + initial_hash = opt.hash opt.dump_json_file("options.json") opt_recovered = Options.load_json_file("options.json") - post_serialization_hash = hash(opt_recovered) + post_serialization_hash = opt_recovered.hash assert initial_hash == post_serialization_hash + assert opt == opt_recovered