From d2b49ecdf2f8d03af35415cc0002d17bd098fad9 Mon Sep 17 00:00:00 2001
From: Christian Schiffer <c.schiffer@fz-juelich.de>
Date: Fri, 12 Apr 2024 09:53:57 +0200
Subject: [PATCH] [DATALAD] Recorded changes

---
 .../src/static/training/config_cytonet.py     | 217 ++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 atlas_server/src/static/training/config_cytonet.py

diff --git a/atlas_server/src/static/training/config_cytonet.py b/atlas_server/src/static/training/config_cytonet.py
new file mode 100644
index 0000000..ddf7843
--- /dev/null
+++ b/atlas_server/src/static/training/config_cytonet.py
@@ -0,0 +1,217 @@
+def make_training_parameters(**kwargs):
+    import math
+    from atlas.configuration import config, constants
+    from atlas.configuration import data_source as ds
+    from atlas.data.image_provider import CHANNEL_HISTOLOGY
+    from atlas.experiments import get_slices_in_range
+    from atlas.torch.optimizers import WeightDecaySelector
+
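+    # Note: __BRAIN__, __NO_AUX_LABELS__, __LABELS__ and __SECTIONS__ (and
+    # "$CLASS_WEIGHTS" further below) appear to be placeholders that are
+    # substituted with concrete values before this config is used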
+    brain = __BRAIN__
+    labels = ["bg", ]
+    no_aux_labels = __NO_AUX_LABELS__
+    if not no_aux_labels:
+        labels = labels + ["cor", "wm", ]
+    labels = labels + __LABELS__
+    sections = __SECTIONS__
+    freeze_encoder = True
+    weight_file = "/p/fastdata/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt"
+    train_slices = [(brain, s) for s in sections]
+
+    # Slightly extend the range of sections to create predictions for
+    increase_fraction = 0.05
+    from_section = min(sections)
+    to_section = max(sections)
+    section_range = to_section - from_section
+    increase = int(math.ceil(section_range * increase_fraction)) // 2
+    min_section = max(from_section - increase, 1)
+    max_section = to_section + increase
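+    # Worked example (illustrative numbers): for sections 1000..1100 the range
+    # is 100, so increase = ceil(100 * 0.05) // 2 = 2 and predictions cover
+    # sections 998..1102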
+
+    validation_sections = get_slices_in_range(start=min_section,
+            stop=max_section,
+            brain=brain,
+            pattern=constants.original_path,
+            step=1)
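+    # Mark every section in the prediction range as congruent with the annotated
+    # training sections, so sampling areas can be transferred to them via the
+    # pre-registration (see use_pre_registration_for_congruent_sections below)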
+    congruent_slices = {key: train_slices for key in set(validation_sections)}
+
+    # Settings were configured for a batch size of 12. We adjust the learning
+    # rate when the batch size actually used is higher, and decrease the number
+    # of iterations and the corresponding learning rate steps accordingly.
+    base_batch_size = 12
+    base_learning_rate = 0.01
+    base_iterations = 3000
+    base_steps = [1000, 1400, 1800, 2200, 2600]
+
+    # Actual batch size and learning rate
+    batch_size = 42
+    lr_multiplier = batch_size / base_batch_size
+    # Multiply: increase the learning rate linearly with the batch size
+    learning_rate = base_learning_rate * lr_multiplier
+    # Divide: reduce the number of iterations (and the learning rate steps
+    # below) by the same factor, since each iteration consumes more samples
+    iterations = int(math.ceil(base_iterations / lr_multiplier))
+    steps = [int(math.ceil(it / lr_multiplier)) for it in base_steps]
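+    # With these settings: lr_multiplier = 42 / 12 = 3.5, learning_rate =
+    # 0.01 * 3.5 = 0.035, iterations = ceil(3000 / 3.5) = 858 and
+    # steps = [286, 400, 515, 629, 743]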
+
+    # Apply weight decay only to the parameters selected by WeightDecaySelector();
+    # the inverse selector forms a second parameter group without weight decay
+    param_selectors = [WeightDecaySelector(), WeightDecaySelector(inverse=True)]
+    weight_decay = [0.0001, 0.0]
+    learning_rate_params = config.LearningRateParams(learning_rate=learning_rate,
+                                                     lr_policy="piecewise",
+                                                     # Use the scaled steps so the schedule matches the reduced iteration count
+                                                     piecewise_steps=steps,
+                                                     piecewise_values=[learning_rate * (0.5 ** i) for i in range(1, 6)])
+    learning_rate_params = config.LearningRateParamsPerParamGroup(learning_rate_params, learning_rate_params)
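+    # The schedule halves the learning rate at each step: starting from 0.035,
+    # the piecewise values are [0.0175, 0.00875, 0.004375, 0.0021875, 0.00109375].
+    # The same schedule is applied to both parameter groups.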
+
+    # Initialize weights of encoder
+    # noinspection PyUnusedLocal
+    def _weight_init_func(model, trainer, *args, **kwargs):
+        import torch
+        import types
+        weights = torch.load(weight_file, map_location=kwargs.get("device", trainer.device))["model"]
+        # Load non-strictly: model parameters absent from the checkpoint (here:
+        # the parts not covered by the pretrained encoder) are reported as
+        # missing_keys, while checkpoint keys that do not exist in the model
+        # (unexpected_keys) abort the run
+        missing = model.load_state_dict(weights, strict=False)
+        assert not missing.unexpected_keys, missing.unexpected_keys
+        missing_str = ",".join(missing.missing_keys)
+        trainer.log.info(f"Missing keys when loading weights, these weights will __NOT__ be initialized from the checkpoint: {missing_str}")
+
+        if freeze_encoder:
+            # Lock the encoder: keep the pretrained parts in eval mode even when
+            # the rest of the model is switched to training mode
+            for part_name in ("down_2", "down_16", "bottom_2", "bottom_16"):
+                model_part = getattr(model, part_name)
+                model_part.eval()
+                train_fun = model_part.train
+                # Bind train_fun as a default argument: a plain closure would be
+                # late-binding and make every override call the train function of
+                # the last part in the loop
+                model_part.train = types.MethodType(lambda self, mode=False, _train=train_fun: _train(False), model_part)
+                trainer.log.info(f"Locked encoder part {part_name}")
+
+                # Disable gradients of this encoder part
+                for p in model_part.parameters():
+                    p.requires_grad = False
+                trainer.log.info(f"Disabled gradient computation for {part_name}")
+
+    train_params = config.TrainParams(
+            backend="torch",
+            mode="fast_segmentation",
+            iterations=iterations,
+            test_interval=None,
+            snapshot=100,
+            data_format="NCHW",
+            class_weights="$CLASS_WEIGHTS",
+            save_best_weights=False,
+            input_output_mapping=ds.InputOutputMapping(
+                # Multi scale input
+                input=("gray2", "gray16"),
+                # Labels for supervised training
+                output="labels"
+                ),
+            num_classes=len(labels),
+            distributed_learning_rate_scaling=True,
+            optimizer_params=config.OptimizerParams(
+                # Note: Learning rate will be scaled by number of GPUs (most often 4)
+                learning_rate_params=learning_rate_params,
+                param_group_selectors=param_selectors,
+                weight_decay=weight_decay,
+            ),
+            network_params=config.NetworkParams(make_network=make_network, weight_init_func=_weight_init_func),
+            iterator_params={
+                config.TRAIN_SPLIT: config.IteratorParams(
+                    batch_size=batch_size,
+                    data_sources=[ds.ImageDataSource(name="gray2", size=2025, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
+                        ds.ImageDataSource(name="gray16", size=628, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
+                        ds.LabelDataSource(name="labels", size=68, spacing=16, labels=labels)],
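+                    # Assuming size * spacing reflects physical extent, "gray2"
+                    # covers 2025 * 2 = 4050 units and "gray16" 628 * 16 = 10048
+                    # units, so the coarse input provides ~2.5x more spatial
+                    # context, while predictions are made for a 68 * 16 = 1088
+                    # unit central patch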
+                    image_provider_params=config.ImageProviderParams(
+                        opencv_num_threads=4,
+                        numba_num_threads=4,
+                        transformation_parameters=config.TransformationParams(
+                            normalize_orientation=True,
+                            pre_registration_base_slice=sections[0],
+                            brightness_augmentation=True,
+                            brightness_augmentation_range=(-0.2, +0.2),
+                            contrast_augmentation=True,
+                            contrast_augmentation_range=(0.9, 1.1),
+                            gamma_contrast_augmentation=True,
+                            gamma_contrast_augmentation_range=(0.8, 1.214),
+                            random_rotation=True,
+                            random_rotation_range=(-math.pi / 4, +math.pi / 4)
+                            ),
+                        ),
+                    sampler_params=config.SamplerParams(
+                        sampler_mode="weighted",
+                        weighted_sampler_probability_weights={label: 1 / len(labels) for label in labels},
+                        weighted_sampler_labels=labels,
+                        area_provider_params=config.AreaProviderParams(
+                            area_mode="distance",
+                            distance_area_provider_radius=2500,
+                            distance_area_provider_radius_spacing=2,
+                            distance_area_provider_radius_per_section=True,
+                            congruent_enlargement_factor=1.05,
+                            labels=labels,
+                            congruent_slices=congruent_slices,
+                            slices=train_slices,
+                            use_pre_registration_for_congruent_sections=True,
+                            congruent_mode="match",
+                            ),
+                        ),
+                    ),
+                config.VAL_SPLIT: config.IteratorParams(
+                    batch_size=5,
+                    data_sources=[ds.ImageDataSource(name="gray2", size=8169, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
+                        ds.ImageDataSource(name="gray16", size=1396, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
+                        ds.LabelDataSource(name="labels", size=836, spacing=16, labels=labels, create=False)],
+                    image_provider_params=config.ImageProviderParams(
+                        transformation_parameters=config.TransformationParams(
+                            normalize_orientation=True,
+                            ),
+                        ),
+                    sampler_params=config.SamplerParams(
+                        sampler_mode="grid",
+                        grid_sampler_output_format=config.OutputFormat(size=836, spacing=16),
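+                        # Grid sampling covers each validation section with a
+                        # regular grid of 836-pixel output patches at spacing 16,
+                        # instead of the weighted random sampling used for training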
+                        image_provider_params=config.ImageProviderParams(
+                            transformation_parameters=config.TransformationParams(
+                                normalize_orientation=True,
+                                ),
+                            ),
+                        area_provider_params=config.AreaProviderParams(
+                            area_mode="distance",
+                            distance_area_provider_radius=2500,
+                            distance_area_provider_radius_spacing=2,
+                            distance_area_provider_radius_per_section=True,
+                            congruent_enlargement_factor=1.05,
+                            labels=labels,
+                            slices=validation_sections,
+                            congruent_slices=congruent_slices,
+                            use_pre_registration_for_congruent_sections=True,
+                            congruent_mode="match",
+                            )
+                        )
+                    ),
+                }
+    )
+
+    return train_params
+
+
+def make_network(**kwargs):
+    from atlas.torch.models.segmentation.unet_multiscale import UNet
+    from atlas.torch.models.segmentation.unet_parts import UNetDownPathParams, UNetBottomPathParams, UNetUpPathParams
+
+    input_shape = kwargs["input_shape"]
+    input_spacing = kwargs["input_spacing"]
+    num_classes = kwargs["num_classes"]
+
+    input_channels = [s[1] for s in input_shape]
+    model = UNet(input_channels=input_channels,
+            input_spacings=input_spacing,
+            num_classes=num_classes,
+            down_params=[
+                UNetDownPathParams(filters=[16, 32, 64, 64, 128],
+                    kernel_size=[(5, 3), 3, 3, 3, 3],
+                    stride=[(4, 1), 1, 1, 1, 1]),
+                UNetDownPathParams(filters=[16, 32, 64, 64, 128, ],
+                    dilation=[1, 2, 2, 2, 2])
+                ],
+            bottom_params=[
+                UNetBottomPathParams(filters=[128, ]),
+                UNetBottomPathParams(filters=[128, ], dilation=2)
+                ],
+            up_params=UNetUpPathParams(filters=[128, 64, 64, 32],
+                upsampling_mode="upsampling"))
+    return model
+
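+# Illustrative call (hypothetical values): the training framework is expected
+# to pass one NCHW shape and spacing per input data source, e.g.
+#   make_network(input_shape=[(12, 1, 2025, 2025), (12, 1, 628, 628)],
+#                input_spacing=(2, 16), num_classes=4)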
-- 
GitLab