Skip to content
Snippets Groups Projects
Commit 1fe8c2fe authored by r.jaepel's avatar r.jaepel
Browse files

Replace hash int with str and add __eq__ method to Options

parent d3787117
No related branches found
No related tags found
No related merge requests found
......@@ -87,7 +87,7 @@ class Case:
output_log = self.study.output_log
for log_entry in output_log:
if (self.study.current_commit_hash == log_entry.project_repo_commit_hash
and str(hash(self.options)) == log_entry.options_hash):
and self.options.hash == log_entry.options_hash):
return True
return False
......
......@@ -57,7 +57,8 @@ class Options(benedict):
def load_json_str(cls, string, **loader_kwargs):
return cls.loads(string)
def __hash__(self, excluded_keys=None):
@property
def hash(self):
excluded_keys = {"commit_message", "push", "debug"}
remaining_keys = set(self.keys()) - excluded_keys
remaining_dict = {key: self[key] for key in remaining_keys}
......@@ -69,7 +70,32 @@ class Options(benedict):
indent=None,
separators=(',', ':'),
)
return int(hashlib.md5(dump.encode('utf-8')).hexdigest(), 16)
hash_alphabet = "abcdefghjkmnpqrstvwxyz0123456789"
hash_base = len(hash_alphabet)
def to_base(number, base):
result = ""
while number:
result += hash_alphabet[number % base]
number //= base
return result[::-1] or "0"
base_16_hash = hashlib.sha1(dump.encode('utf-8')).hexdigest()
base_10_hash = int(base_16_hash, 16)
base_32_hash = to_base(base_10_hash, hash_base)
return base_32_hash
def __eq__(self, other):
if not isinstance(other, Options):
try:
other = Options(other)
except TypeError:
print(f"TypeError when casting {other} to Options()")
return NotImplemented
return self.hash == other.hash
if __name__ == '__main__':
......
......@@ -22,12 +22,12 @@ def tracks_results(func):
if key not in options or options[key] is None:
raise ValueError(f"Key {key} not found in options. Please supply options.{key}")
if hash(options) != hash(Options.load_json_str(options.dump_json_str())):
if options.hash != Options.load_json_str(options.dump_json_str()).hash:
raise ValueError("Options are not serializable. Please only use python natives and numpy ndarrays.")
project_repo = ProjectRepo(repo_path)
project_repo.options_hash = hash(options)
project_repo.options_hash = options.hash
with project_repo.track_results(
options.commit_message,
......
......@@ -7,19 +7,21 @@ def test_options_hash():
opt = Options()
opt["array"] = np.linspace(2, 200)
opt["nested_dict"] = {"ba": "foo", "bb": "bar"}
initial_hash = hash(opt)
initial_hash = opt.hash
s = opt.dumps()
opt_recovered = Options.loads(s)
post_serialization_hash = hash(opt_recovered)
post_serialization_hash = opt_recovered.hash
assert initial_hash == post_serialization_hash
assert opt == opt_recovered
def test_options_file_io():
opt = Options()
opt["array"] = np.linspace(0, 2, 200)
opt["nested_dict"] = {"ba": "foo", "bb": "bar"}
initial_hash = hash(opt)
initial_hash = opt.hash
opt.dump_json_file("options.json")
opt_recovered = Options.load_json_file("options.json")
post_serialization_hash = hash(opt_recovered)
post_serialization_hash = opt_recovered.hash
assert initial_hash == post_serialization_hash
assert opt == opt_recovered
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment