# Reconstructed from git patch d2b49ecdf2f8d03af35415cc0002d17bd098fad9
# ("[DATALAD] Recorded changes", Christian Schiffer <c.schiffer@fz-juelich.de>,
# Fri, 12 Apr 2024 09:53:57 +0200), which creates
# atlas_server/src/static/training/config_cytonet.py.
#
# NOTE(review): __BRAIN__, __NO_AUX_LABELS__, __LABELS__ and __SECTIONS__ look like
# template placeholders that are substituted before this file is executed - confirm
# against the code that consumes this file.


def make_training_parameters(**kwargs):
    """Build the training configuration for fine-tuning the CytoNet U-Net.

    The configuration trains a multi-scale segmentation model ("gray2" and "gray16"
    inputs, "labels" output) on the sections given by ``__SECTIONS__`` of brain
    ``__BRAIN__`` and creates predictions for a slightly enlarged section range.
    The encoder is initialized from a pretrained weight file and, because
    ``freeze_encoder`` is True, kept frozen during training.

    Args:
        **kwargs: Accepted for interface compatibility; not used by this function.

    Returns:
        The ``config.TrainParams`` instance describing the training run.
    """
    import math
    from atlas.configuration import config, constants
    from atlas.configuration import data_source as ds
    from atlas.data.image_provider import CHANNEL_HISTOLOGY
    from atlas.experiments import get_slices_in_range
    from atlas.torch.optimizers import WeightDecaySelector

    brain = __BRAIN__
    labels = ["bg", ]
    no_aux_labels = __NO_AUX_LABELS__
    if not no_aux_labels:
        labels = labels + ["cor", "wm", ]
    labels = labels + __LABELS__
    sections = __SECTIONS__
    freeze_encoder = True
    weight_file = "/p/fastdata/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt"
    train_slices = [(brain, s) for s in sections]

    # Slightly increase the number of sections to create predictions for:
    # the annotated range is enlarged by 5% in total, half of it on each side.
    relative_increase = 0.05
    from_section = min(sections)
    to_section = max(sections)
    section_range = to_section - from_section
    margin = int(math.ceil(section_range * relative_increase)) // 2
    min_section = max(from_section - margin, 1)  # Never go below section 1
    max_section = to_section + margin

    validation_sections = get_slices_in_range(start=min_section,
                                              stop=max_section,
                                              brain=brain,
                                              pattern=constants.original_path,
                                              step=1)
    # Every section to predict uses the full set of training slices as congruent slices
    congruent_slices = {key: train_slices for key in set(validation_sections)}

    # Settings were configured for a batch size of 12.
    # We adjust the learning rate if the actually used batch size is higher
    base_batch_size = 12
    base_learning_rate = 0.01
    # Also decrease the number of iterations and respective learning rate steps
    base_iterations = 3000
    base_steps = [1000, 1400, 1800, 2200, 2600]

    # Actual batch size and learning rate
    batch_size = 42
    lr_multiplier = batch_size / base_batch_size
    # Multiply -> Increase the learning rate if batch size gets larger
    learning_rate = base_learning_rate * lr_multiplier
    # Divide -> Reduce number of iterations if batch size gets larger
    iterations = int(math.ceil(base_iterations / lr_multiplier))
    steps = [int(math.ceil(it / lr_multiplier)) for it in base_steps]

    # Apply weight decay: the first parameter group gets weight decay, the inverse selection does not
    param_selectors = [WeightDecaySelector(), WeightDecaySelector(inverse=True)]
    weight_decay = [0.0001, 0.0]
    # Use the *scaled* steps here. The unscaled base_steps all lie beyond the reduced
    # number of iterations, so the learning rate would never be decreased.
    learning_rate_params = config.LearningRateParams(
        learning_rate=learning_rate,
        lr_policy="piecewise",
        piecewise_steps=steps,
        # Halve the learning rate at every step
        piecewise_values=[learning_rate * (0.5 ** i) for i in range(1, len(steps) + 1)],
    )
    # Both parameter groups share the same learning rate schedule
    learning_rate_params = config.LearningRateParamsPerParamGroup(learning_rate_params, learning_rate_params)

    # Initialize weights of encoder
    # noinspection PyUnusedLocal
    def _weight_init_func(model, trainer, *args, **kwargs):
        """Load the pretrained weights into ``model`` and optionally freeze the encoder.

        Args:
            model: Model to initialize; the encoder parts are looked up as the
                attributes "down_2", "down_16", "bottom_2" and "bottom_16".
            trainer: Trainer providing ``device`` and ``log``.
            *args: Ignored.
            **kwargs: May contain "device" to override the device weights are mapped to.
        """
        import torch
        import types

        weights = torch.load(weight_file, map_location=kwargs.get("device", trainer.device))["model"]
        # strict=False tolerates model parameters that are absent from the checkpoint
        # (they are reported below), but the checkpoint must not contain keys that
        # the model does not know.
        missing = model.load_state_dict(weights, strict=False)
        assert not missing.unexpected_keys, missing.unexpected_keys
        missing_str = ",".join(missing.missing_keys)
        trainer.log.info(f"Missing keys when loading weights, these weights will __NOT__ be initialized: {missing_str}")

        if freeze_encoder:
            for part_name in ("down_2", "down_16", "bottom_2", "bottom_16"):
                encoder_part = getattr(model, part_name)

                # Lock the encoder: switch to eval mode and make sure later calls to
                # train() (e.g. through model.train()) cannot switch it back.
                encoder_part.eval()

                # Bind the original train method as a default argument. A plain closure
                # over the loop variable would make every patched method call the
                # original train method of the *last* encoder part.
                def _train_locked(self, mode=False, _original_train=encoder_part.train):
                    return _original_train(False)

                encoder_part.train = types.MethodType(_train_locked, encoder_part)

                # Disable gradients of the encoder
                for p in encoder_part.parameters():
                    p.requires_grad = False

            trainer.log.info("Locked encoder")
            trainer.log.info("Disabled gradient computation for encoder")

    train_params = config.TrainParams(
        backend="torch",
        mode="fast_segmentation",
        iterations=iterations,
        test_interval=None,
        snapshot=100,
        data_format="NCHW",
        class_weights="$CLASS_WEIGHTS",
        save_best_weights=False,
        input_output_mapping=ds.InputOutputMapping(
            # Multi scale input
            input=("gray2", "gray16"),
            # Labels for supervised training
            output="labels"
        ),
        num_classes=len(labels),
        distributed_learning_rate_scaling=True,
        optimizer_params=config.OptimizerParams(
            # Note: Learning rate will be scaled by number of GPUs (most often 4)
            learning_rate_params=learning_rate_params,
            param_group_selectors=param_selectors,
            weight_decay=weight_decay,
        ),
        network_params=config.NetworkParams(make_network=make_network, weight_init_func=_weight_init_func),
        iterator_params={
            config.TRAIN_SPLIT: config.IteratorParams(
                batch_size=batch_size,
                data_sources=[
                    ds.ImageDataSource(name="gray2", size=2025, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                    ds.ImageDataSource(name="gray16", size=628, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                    ds.LabelDataSource(name="labels", size=68, spacing=16, labels=labels),
                ],
                image_provider_params=config.ImageProviderParams(
                    opencv_num_threads=4,
                    numba_num_threads=4,
                    transformation_parameters=config.TransformationParams(
                        normalize_orientation=True,
                        pre_registration_base_slice=sections[0],
                        brightness_augmentation=True,
                        brightness_augmentation_range=(-0.2, +0.2),
                        contrast_augmentation=True,
                        contrast_augmentation_range=(0.9, 1.1),
                        gamma_contrast_augmentation=True,
                        gamma_contrast_augmentation_range=(0.8, 1.214),
                        random_rotation=True,
                        random_rotation_range=(-math.pi / 4, +math.pi / 4)
                    ),
                ),
                sampler_params=config.SamplerParams(
                    sampler_mode="weighted",
                    # Sample all labels with equal probability
                    weighted_sampler_probability_weights={label: 1 / len(labels) for label in labels},
                    weighted_sampler_labels=labels,
                    area_provider_params=config.AreaProviderParams(
                        area_mode="distance",
                        distance_area_provider_radius=2500,
                        distance_area_provider_radius_spacing=2,
                        distance_area_provider_radius_per_section=True,
                        congruent_enlargement_factor=1.05,
                        labels=labels,
                        congruent_slices=congruent_slices,
                        slices=train_slices,
                        use_pre_registration_for_congruent_sections=True,
                        congruent_mode="match",
                    ),
                ),
            ),
            config.VAL_SPLIT: config.IteratorParams(
                batch_size=5,
                data_sources=[
                    ds.ImageDataSource(name="gray2", size=8169, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                    ds.ImageDataSource(name="gray16", size=1396, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                    ds.LabelDataSource(name="labels", size=836, spacing=16, labels=labels, create=False),
                ],
                image_provider_params=config.ImageProviderParams(
                    transformation_parameters=config.TransformationParams(
                        normalize_orientation=True,
                    ),
                ),
                sampler_params=config.SamplerParams(
                    sampler_mode="grid",
                    grid_sampler_output_format=config.OutputFormat(size=836, spacing=16),
                    image_provider_params=config.ImageProviderParams(
                        transformation_parameters=config.TransformationParams(
                            normalize_orientation=True,
                        ),
                    ),
                    area_provider_params=config.AreaProviderParams(
                        area_mode="distance",
                        distance_area_provider_radius=2500,
                        distance_area_provider_radius_spacing=2,
                        distance_area_provider_radius_per_section=True,
                        congruent_enlargement_factor=1.05,
                        labels=labels,
                        slices=validation_sections,
                        congruent_slices=congruent_slices,
                        use_pre_registration_for_congruent_sections=True,
                        congruent_mode="match",
                    )
                )
            ),
        }
    )

    return train_params


def make_network(**kwargs):
    """Create the multi-scale U-Net used for training.

    Args:
        **kwargs: Must contain "input_shape" (one shape per input), "input_spacing"
            and "num_classes".

    Returns:
        The constructed ``UNet`` model with two down paths, two bottom paths and
        one up path.
    """
    from atlas.torch.models.segmentation.unet_multiscale import UNet
    from atlas.torch.models.segmentation.unet_parts import UNetDownPathParams, UNetBottomPathParams, UNetUpPathParams

    input_shape = kwargs["input_shape"]
    input_spacing = kwargs["input_spacing"]
    num_classes = kwargs["num_classes"]

    # Index 1 of every input shape is used as the number of channels
    # (presumably shapes are NCHW, matching the data_format of the training config - confirm).
    input_channels = [s[1] for s in input_shape]
    model = UNet(input_channels=input_channels,
                 input_spacings=input_spacing,
                 num_classes=num_classes,
                 down_params=[
                     UNetDownPathParams(filters=[16, 32, 64, 64, 128],
                                        kernel_size=[(5, 3), 3, 3, 3, 3],
                                        stride=[(4, 1), 1, 1, 1, 1]),
                     UNetDownPathParams(filters=[16, 32, 64, 64, 128, ],
                                        dilation=[1, 2, 2, 2, 2])
                 ],
                 bottom_params=[
                     UNetBottomPathParams(filters=[128, ]),
                     UNetBottomPathParams(filters=[128, ], dilation=2)
                 ],
                 up_params=UNetUpPathParams(filters=[128, 64, 64, 32],
                                            upsampling_mode="upsampling"))
    return model