Skip to content
Snippets Groups Projects
Commit d2b49ecd authored by Schiffer, Christian's avatar Schiffer, Christian
Browse files

[DATALAD] Recorded changes

parent 47751fcf
No related branches found
No related tags found
No related merge requests found
def make_training_parameters(**kwargs):
    """Build the ``config.TrainParams`` used to fine-tune a multi-scale UNet
    for cytoarchitectonic segmentation on a set of histological sections.

    The encoder is initialized from a pre-trained checkpoint (``weight_file``)
    and frozen; learning rate, iteration count and LR schedule steps are
    rescaled from a reference batch size of 12 to the actually used batch
    size.

    Placeholders ``__BRAIN__``, ``__NO_AUX_LABELS__``, ``__LABELS__`` and
    ``__SECTIONS__`` are substituted by the surrounding experiment templating
    machinery before this file is executed.

    Returns:
        config.TrainParams: fully populated training configuration.
    """
    import math
    from atlas.configuration import config, constants
    from atlas.configuration import data_source as ds
    from atlas.data.image_provider import CHANNEL_HISTOLOGY
    from atlas.experiments import get_slices_in_range
    from atlas.torch.optimizers import WeightDecaySelector

    brain = __BRAIN__
    labels = ["bg", ]
    no_aux_labels = __NO_AUX_LABELS__
    if not no_aux_labels:
        # Auxiliary labels: cortex and white matter.
        labels = labels + ["cor", "wm", ]
    labels = labels + __LABELS__
    sections = __SECTIONS__
    freeze_encoder = True
    weight_file = "/p/fastdata/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt"
    train_slices = [(brain, s) for s in sections]

    # Slightly increase the number of sections to create predictions for
    increase = 0.05
    from_section = min(sections)
    to_section = max(sections)
    section_range = to_section - from_section
    # Half of the (rounded-up) 5% enlargement is added on each side of the range.
    increase = int(math.ceil(section_range * increase)) // 2
    min_section = max(from_section - increase, 1)
    max_section = to_section + increase
    validation_sections = get_slices_in_range(start=min_section,
                                              stop=max_section,
                                              brain=brain,
                                              pattern=constants.original_path,
                                              step=1)
    # Every validation section is matched against all training sections.
    congruent_slices = {key: train_slices for key in set(validation_sections)}

    # Settings were configured for a batch size of 12.
    # We adjust the learning rate if the actually used batch size is higher
    base_batch_size = 12
    base_learning_rate = 0.01
    # Also decrease the number of iterations and respective learning rate steps
    base_iterations = 3000
    base_steps = [1000, 1400, 1800, 2200, 2600]
    # Actual batch size and learning rate
    batch_size = 42
    lr_multiplier = batch_size / base_batch_size
    # Multiply -> Increase the learning rate if batch size gets larger
    learning_rate = base_learning_rate * lr_multiplier
    # Divide -> Reduce number of iterations if batch size gets larger
    iterations = int(math.ceil(base_iterations / lr_multiplier))
    steps = [int(math.ceil(it / lr_multiplier)) for it in base_steps]

    # Apply weight decay to one parameter group, none to the other
    # (selectors split parameters into "decayed" and "not decayed").
    param_selectors = [WeightDecaySelector(), WeightDecaySelector(inverse=True)]
    weight_decay = [0.0001, 0.0]
    # BUGFIX: use the batch-size-scaled "steps" (previously the unscaled
    # "base_steps" were passed, leaving the computed "steps" unused).
    # LR is halved at each step; one value per step boundary.
    learning_rate_params = config.LearningRateParams(learning_rate=learning_rate,
                                                     lr_policy="piecewise",
                                                     piecewise_steps=steps,
                                                     piecewise_values=[learning_rate * (0.5 ** i) for i in range(1, len(steps) + 1)])
    # Same schedule for both parameter groups.
    learning_rate_params = config.LearningRateParamsPerParamGroup(learning_rate_params, learning_rate_params)

    # Initialize weights of encoder
    # noinspection PyUnusedLocal
    def _weight_init_func(model, trainer, *args, **kwargs):
        """Load pre-trained encoder weights and (optionally) freeze the encoder."""
        import torch
        import types
        weights = torch.load(weight_file, map_location=kwargs.get("device", trainer.device))["model"]
        # Filter out non-encoder weights: loading with strict=False tolerates
        # keys missing from the checkpoint, but unexpected checkpoint keys
        # indicate a model/checkpoint mismatch and are treated as an error.
        missing = model.load_state_dict(weights, strict=False)
        assert not missing.unexpected_keys, missing.unexpected_keys
        missing_str = ",".join(missing.missing_keys)
        trainer.log.info(f"Missing keys when loading weights, these weights will __NOT__ be initialized: {missing_str}")
        if freeze_encoder:
            # Lock the encoder
            for model_part in ("down_2", "down_16", "bottom_2", "bottom_16"):
                model_part = getattr(model, model_part)
                model_part.eval()
                train_fun = model_part.train
                # Monkey-patch .train() so later trainer.train() calls cannot
                # switch the frozen parts back to training mode.
                model_part.train = types.MethodType(lambda self, mode=False: train_fun(False), model_part)
                trainer.log.info("Locked encoder")
                # Disable gradients of the encoder
                for p in model_part.parameters():
                    p.requires_grad = False
                trainer.log.info("Disabled gradient computation for encoder")

    train_params = config.TrainParams(
        backend="torch",
        mode="fast_segmentation",
        iterations=iterations,
        test_interval=None,
        snapshot=100,
        data_format="NCHW",
        class_weights="$CLASS_WEIGHTS",
        save_best_weights=False,
        input_output_mapping=ds.InputOutputMapping(
            # Multi scale input
            input=("gray2", "gray16"),
            # Labels for supervised training
            output="labels"
        ),
        num_classes=len(labels),
        distributed_learning_rate_scaling=True,
        optimizer_params=config.OptimizerParams(
            # Note: Learning rate will be scaled by number of GPUs (most often 4)
            learning_rate_params=learning_rate_params,
            param_group_selectors=param_selectors,
            weight_decay=weight_decay,
        ),
        network_params=config.NetworkParams(make_network=make_network, weight_init_func=_weight_init_func),
        iterator_params={
            config.TRAIN_SPLIT: config.IteratorParams(
                batch_size=batch_size,
                data_sources=[ds.ImageDataSource(name="gray2", size=2025, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                              ds.ImageDataSource(name="gray16", size=628, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                              ds.LabelDataSource(name="labels", size=68, spacing=16, labels=labels)],
                image_provider_params=config.ImageProviderParams(
                    opencv_num_threads=4,
                    numba_num_threads=4,
                    transformation_parameters=config.TransformationParams(
                        normalize_orientation=True,
                        pre_registration_base_slice=sections[0],
                        brightness_augmentation=True,
                        brightness_augmentation_range=(-0.2, +0.2),
                        contrast_augmentation=True,
                        contrast_augmentation_range=(0.9, 1.1),
                        gamma_contrast_augmentation=True,
                        gamma_contrast_augmentation_range=(0.8, 1.214),
                        random_rotation=True,
                        random_rotation_range=(-math.pi / 4, +math.pi / 4)
                    ),
                ),
                sampler_params=config.SamplerParams(
                    sampler_mode="weighted",
                    # Uniform sampling probability across all labels.
                    weighted_sampler_probability_weights={label: 1 / len(labels) for label in labels},
                    weighted_sampler_labels=labels,
                    area_provider_params=config.AreaProviderParams(
                        area_mode="distance",
                        distance_area_provider_radius=2500,
                        distance_area_provider_radius_spacing=2,
                        distance_area_provider_radius_per_section=True,
                        congruent_enlargement_factor=1.05,
                        labels=labels,
                        congruent_slices=congruent_slices,
                        slices=train_slices,
                        use_pre_registration_for_congruent_sections=True,
                        congruent_mode="match",
                    ),
                ),
            ),
            config.VAL_SPLIT: config.IteratorParams(
                batch_size=5,
                data_sources=[ds.ImageDataSource(name="gray2", size=8169, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                              ds.ImageDataSource(name="gray16", size=1396, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
                              ds.LabelDataSource(name="labels", size=836, spacing=16, labels=labels, create=False)],
                image_provider_params=config.ImageProviderParams(
                    transformation_parameters=config.TransformationParams(
                        normalize_orientation=True,
                    ),
                ),
                sampler_params=config.SamplerParams(
                    sampler_mode="grid",
                    grid_sampler_output_format=config.OutputFormat(size=836, spacing=16),
                    image_provider_params=config.ImageProviderParams(
                        transformation_parameters=config.TransformationParams(
                            normalize_orientation=True,
                        ),
                    ),
                    area_provider_params=config.AreaProviderParams(
                        area_mode="distance",
                        distance_area_provider_radius=2500,
                        distance_area_provider_radius_spacing=2,
                        distance_area_provider_radius_per_section=True,
                        congruent_enlargement_factor=1.05,
                        labels=labels,
                        slices=validation_sections,
                        congruent_slices=congruent_slices,
                        use_pre_registration_for_congruent_sections=True,
                        congruent_mode="match",
                    )
                )
            ),
        }
    )
    return train_params
def make_network(**kwargs):
    """Construct the two-branch multi-scale UNet.

    Expects ``input_shape`` (per-branch NCHW shapes), ``input_spacing``
    (per-branch pixel spacings) and ``num_classes`` in ``kwargs``.

    Returns:
        UNet: the assembled segmentation model.
    """
    from atlas.torch.models.segmentation.unet_multiscale import UNet
    from atlas.torch.models.segmentation.unet_parts import UNetDownPathParams, UNetBottomPathParams, UNetUpPathParams

    shapes = kwargs["input_shape"]
    spacings = kwargs["input_spacing"]
    n_classes = kwargs["num_classes"]
    # Channel count is the second entry of each per-branch NCHW shape.
    channels = [shape[1] for shape in shapes]

    # Branch 0: strided first layer; branch 1: dilated convolutions.
    down_paths = [
        UNetDownPathParams(filters=[16, 32, 64, 64, 128],
                           kernel_size=[(5, 3), 3, 3, 3, 3],
                           stride=[(4, 1), 1, 1, 1, 1]),
        UNetDownPathParams(filters=[16, 32, 64, 64, 128],
                           dilation=[1, 2, 2, 2, 2]),
    ]
    bottom_paths = [
        UNetBottomPathParams(filters=[128]),
        UNetBottomPathParams(filters=[128], dilation=2),
    ]
    up_path = UNetUpPathParams(filters=[128, 64, 64, 32],
                               upsampling_mode="upsampling")

    return UNet(input_channels=channels,
                input_spacings=spacings,
                num_classes=n_classes,
                down_params=down_paths,
                bottom_params=bottom_paths,
                up_params=up_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment