Skip to content
Snippets Groups Projects
Commit 7f4c1797 authored by Schiffer, Christian's avatar Schiffer, Christian
Browse files

Implemented the option to do distributed reading of the config file

parent 6eb2b5c6
No related branches found
No related tags found
No related merge requests found
......@@ -154,7 +154,7 @@ def train(
# Load configuration, so it can be accessed by all modules
log.info("Loading configuration from run_dir at {}".format(run_dir))
config = experiments.load_config(run_dir)
config = experiments.load_config(run_dir, distributed_comm=parallel.WORLD_COMM if distribute else None)
train_params = config.make_training_parameters()
diagnostics.add_config(training_parameters=train_params)
......
......@@ -16,18 +16,21 @@ from atlaslib.json import JSONEncoder as JSONEncoder_
# -------------------------------------------------------------------------------------------
def load_config(config):
def load_config(config, distributed_comm=None):
"""
Loads a configuration from a given directory.
Args:
config (str): Path to configuration file or directory containing a file called "config.py".
distributed_comm (MPI communicator): If given, load the configuration on the master process of this communicator and distribute the content. This can reduce load on the file system.
Returns:
Loaded configuration file, which can be accessed like a module.
"""
import sys
import hashlib
from importlib.machinery import SourceFileLoader
from importlib.util import spec_from_loader, module_from_spec
import os
if os.path.isdir(config):
......@@ -36,8 +39,25 @@ def load_config(config):
# Include unique identifier in the module name, so we can multiple configs from different files.
config_path = os.path.realpath(config)
identifier = hashlib.sha256(config_path.encode("utf-8")).hexdigest()
source_file_loader = SourceFileLoader(f"config_{identifier}", config)
return source_file_loader.load_module()
identifier = f"config_{identifier}"
if distributed_comm:
# Read on first process, distribut string, load from string
if distributed_comm.rank == 0:
with open(config, "r") as f:
config_string = f.read()
else:
config_string = None
config_string = distributed_comm.bcast(config_string, root=0)
spec = spec_from_loader(identifier, loader=None)
config = module_from_spec(spec)
exec(config_string, config.__dict__)
sys.modules[identifier] = config
else:
# simply load from file
source_file_loader = SourceFileLoader(identifier, config)
config = source_file_loader.load_module()
return config
def copy_config(src_dir, dst_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment