Commit e0b356de authored by Schiffer, Christian

Fixed with new paths

parent 4acffb94
#!/usr/bin/env python3
WORK_DIR_ROOT = "/p/fastdata/bigbrains/personal/schiffer1/atlasui/work_dir/"
WORK_DIR_ROOT = "/p/data1/bigbrains/personal/schiffer1/atlasui/work_dir/"
@@ -8,10 +8,15 @@ def make_training_parameters(**kwargs):
from atlas.torch.optimizers import WeightDecaySelector
brain = __BRAIN__
labels = ["bg", ]
labels = [
"bg",
]
no_aux_labels = __NO_AUX_LABELS__
if not no_aux_labels:
labels = labels + ["cor", "wm", ]
labels = labels + [
"cor",
"wm",
]
labels = labels + __LABELS__
sections = __SECTIONS__
train_slices = [(brain, s) for s in sections]
@@ -25,11 +30,13 @@ def make_training_parameters(**kwargs):
min_section = max(from_section - increase, 1)
max_section = to_section + increase
validation_sections = get_slices_in_range(start=min_section,
validation_sections = get_slices_in_range(
start=min_section,
stop=max_section,
brain=brain,
pattern=constants.original_path,
step=1)
step=1,
)
congruent_slices = {key: train_slices for key in set(validation_sections)}
# Settings were configured for a batch size of 12.
@@ -52,11 +59,15 @@ def make_training_parameters(**kwargs):
# Apply weight decay
param_selectors = [WeightDecaySelector(), WeightDecaySelector(inverse=True)]
weight_decay = [0.0001, 0.0]
learning_rate_params = config.LearningRateParams(learning_rate=learning_rate,
learning_rate_params = config.LearningRateParams(
learning_rate=learning_rate,
lr_policy="piecewise",
piecewise_steps=base_steps,
piecewise_values=[learning_rate*(0.5 ** i) for i in range(1, 6)])
learning_rate_params = config.LearningRateParamsPerParamGroup(learning_rate_params, learning_rate_params)
piecewise_values=[learning_rate * (0.5**i) for i in range(1, 6)],
)
learning_rate_params = config.LearningRateParamsPerParamGroup(
learning_rate_params, learning_rate_params
)
train_params = config.TrainParams(
backend="torch",
@@ -71,7 +82,7 @@ def make_training_parameters(**kwargs):
# Multi scale input
input=("gray2", "gray16"),
# Labels for supervised training
output="labels"
output="labels",
),
num_classes=len(labels),
distributed_learning_rate_scaling=True,
@@ -83,16 +94,32 @@ def make_training_parameters(**kwargs):
),
network_params=config.NetworkParams(
make_network=make_network,
finetune="/p/project/cjinm16/schiffer1/experiments/models/self_supervised/self_supervised_pretrained/model.hdf5",
finetune_mapping = "/p/project/cjinm16/schiffer1/experiments/models/keras2torch/self_supervised_pretrained_to_multi_scale_unet.json",
finetune="/p/project0/cjinm16/schiffer1/experiments/models/self_supervised/self_supervised_pretrained/model.hdf5",
finetune_mapping="/p/project0/cjinm16/schiffer1/experiments/models/keras2torch/self_supervised_pretrained_to_multi_scale_unet.json",
finetune_file_key=None,
),
iterator_params={
config.TRAIN_SPLIT: config.IteratorParams(
batch_size=batch_size,
data_sources=[ds.ImageDataSource(name="gray2", size=2025, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.ImageDataSource(name="gray16", size=628, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.LabelDataSource(name="labels", size=68, spacing=16, labels=labels)],
data_sources=[
ds.ImageDataSource(
name="gray2",
size=2025,
spacing=2,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.ImageDataSource(
name="gray16",
size=628,
spacing=16,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.LabelDataSource(
name="labels", size=68, spacing=16, labels=labels
),
],
image_provider_params=config.ImageProviderParams(
opencv_num_threads=4,
numba_num_threads=4,
@@ -106,12 +133,14 @@ def make_training_parameters(**kwargs):
gamma_contrast_augmentation=True,
gamma_contrast_augmentation_range=(0.8, 1.214),
random_rotation=True,
random_rotation_range=(-math.pi / 4, +math.pi / 4)
random_rotation_range=(-math.pi / 4, +math.pi / 4),
),
),
sampler_params=config.SamplerParams(
sampler_mode="weighted",
weighted_sampler_probability_weights={label: 1 / len(labels) for label in labels},
weighted_sampler_probability_weights={
label: 1 / len(labels) for label in labels
},
weighted_sampler_labels=labels,
area_provider_params=config.AreaProviderParams(
area_mode="distance",
@@ -129,9 +158,25 @@ def make_training_parameters(**kwargs):
),
config.VAL_SPLIT: config.IteratorParams(
batch_size=5,
data_sources=[ds.ImageDataSource(name="gray2", size=8169, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.ImageDataSource(name="gray16", size=1396, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.LabelDataSource(name="labels", size=836, spacing=16, labels=labels, create=False)],
data_sources=[
ds.ImageDataSource(
name="gray2",
size=8169,
spacing=2,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.ImageDataSource(
name="gray16",
size=1396,
spacing=16,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.LabelDataSource(
name="labels", size=836, spacing=16, labels=labels, create=False
),
],
image_provider_params=config.ImageProviderParams(
transformation_parameters=config.TransformationParams(
normalize_orientation=True,
@@ -139,7 +184,9 @@ def make_training_parameters(**kwargs):
),
sampler_params=config.SamplerParams(
sampler_mode="grid",
grid_sampler_output_format=config.OutputFormat(size=836, spacing=16),
grid_sampler_output_format=config.OutputFormat(
size=836, spacing=16
),
image_provider_params=config.ImageProviderParams(
transformation_parameters=config.TransformationParams(
normalize_orientation=True,
@@ -156,10 +203,10 @@ def make_training_parameters(**kwargs):
congruent_slices=congruent_slices,
use_pre_registration_for_congruent_sections=True,
congruent_mode="match",
)
)
),
}
),
),
},
)
return train_params
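
The double-underscore tokens (__BRAIN__, __SECTIONS__, __LABELS__, __NO_AUX_LABELS__) are not valid Python at runtime; they look like template placeholders that the surrounding tooling substitutes with concrete values before this file is executed. A purely hypothetical sketch of such a substitution step; the file names and values below are assumptions, not taken from this repository:

# Hypothetical illustration of filling in the template placeholders above.
from pathlib import Path

def render_template(template_path, output_path, substitutions):
    text = Path(template_path).read_text()
    for placeholder, value in substitutions.items():
        # Values are Python source fragments pasted verbatim into the script.
        text = text.replace(placeholder, value)
    Path(output_path).write_text(text)

render_template(
    "training_template.py",   # assumed template file name
    "training_config.py",     # assumed output file name
    {
        "__BRAIN__": '"B20"',
        "__SECTIONS__": "[901, 902, 903]",
        "__LABELS__": '["area1", "area2"]',
        "__NO_AUX_LABELS__": "False",
    },
)
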
@@ -167,28 +214,53 @@ def make_training_parameters(**kwargs):
def make_network(**kwargs):
from atlas.torch.models.segmentation.unet_multiscale import UNet
from atlas.torch.models.segmentation.unet_parts import UNetDownPathParams, UNetBottomPathParams, UNetUpPathParams
from atlas.torch.models.segmentation.unet_parts import (
UNetDownPathParams,
UNetBottomPathParams,
UNetUpPathParams,
)
input_shape = kwargs["input_shape"]
input_spacing = kwargs["input_spacing"]
num_classes = kwargs["num_classes"]
input_channels = [s[1] for s in input_shape]
model = UNet(input_channels=input_channels,
model = UNet(
input_channels=input_channels,
input_spacings=input_spacing,
num_classes=num_classes,
down_params=[
UNetDownPathParams(filters=[16, 32, 64, 64, 128],
UNetDownPathParams(
filters=[16, 32, 64, 64, 128],
kernel_size=[(5, 3), 3, 3, 3, 3],
stride=[(4, 1), 1, 1, 1, 1]),
UNetDownPathParams(filters=[16, 32, 64, 64, 128, ],
dilation=[1, 2, 2, 2, 2])
stride=[(4, 1), 1, 1, 1, 1],
),
UNetDownPathParams(
filters=[
16,
32,
64,
64,
128,
],
dilation=[1, 2, 2, 2, 2],
),
],
bottom_params=[
UNetBottomPathParams(filters=[128, ]),
UNetBottomPathParams(filters=[128, ], dilation=2)
UNetBottomPathParams(
filters=[
128,
]
),
UNetBottomPathParams(
filters=[
128,
],
dilation=2,
),
],
up_params=UNetUpPathParams(filters=[128, 64, 64, 32],
upsampling_mode="upsampling"))
up_params=UNetUpPathParams(
filters=[128, 64, 64, 32], upsampling_mode="upsampling"
),
)
return model
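
make_network derives the per-branch channel count from kwargs["input_shape"] by taking index 1 of each shape tuple, which matches the NCHW layout declared in the data sources. A hedged usage sketch, with batch size, channel count, and class count chosen for illustration only (the real values are supplied by the training framework):

# Illustrative call; the actual shapes and spacings come from the trainer.
model = make_network(
    input_shape=[(12, 1, 2025, 2025), (12, 1, 628, 628)],  # NCHW per input branch (assumed values)
    input_spacing=[2, 16],
    num_classes=5,  # assumed: len(labels) for a hypothetical label set
)
# input_channels = [s[1] for s in input_shape] -> [1, 1], i.e. one channel per branch
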
@@ -8,14 +8,19 @@ def make_training_parameters(**kwargs):
from atlas.torch.optimizers import WeightDecaySelector
brain = __BRAIN__
labels = ["bg", ]
labels = [
"bg",
]
no_aux_labels = __NO_AUX_LABELS__
if not no_aux_labels:
labels = labels + ["cor", "wm", ]
labels = labels + [
"cor",
"wm",
]
labels = labels + __LABELS__
sections = __SECTIONS__
freeze_encoder = True
weight_file = "/p/fastdata/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt"
weight_file = "/p/data1/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt"
train_slices = [(brain, s) for s in sections]
# Slightly increase the number of sections to create predictions for
@@ -27,11 +32,13 @@ def make_training_parameters(**kwargs):
min_section = max(from_section - increase, 1)
max_section = to_section + increase
validation_sections = get_slices_in_range(start=min_section,
validation_sections = get_slices_in_range(
start=min_section,
stop=max_section,
brain=brain,
pattern=constants.original_path,
step=1)
step=1,
)
congruent_slices = {key: train_slices for key in set(validation_sections)}
# Settings were configured for a batch size of 12.
@@ -54,23 +61,32 @@ def make_training_parameters(**kwargs):
# Apply weight decay
param_selectors = [WeightDecaySelector(), WeightDecaySelector(inverse=True)]
weight_decay = [0.0001, 0.0]
learning_rate_params = config.LearningRateParams(learning_rate=learning_rate,
learning_rate_params = config.LearningRateParams(
learning_rate=learning_rate,
lr_policy="piecewise",
piecewise_steps=base_steps,
piecewise_values=[learning_rate*(0.5 ** i) for i in range(1, 6)])
learning_rate_params = config.LearningRateParamsPerParamGroup(learning_rate_params, learning_rate_params)
piecewise_values=[learning_rate * (0.5**i) for i in range(1, 6)],
)
learning_rate_params = config.LearningRateParamsPerParamGroup(
learning_rate_params, learning_rate_params
)
# Initialize weights of encoder
# noinspection PyUnusedLocal
def _weight_init_func(model, trainer, *args, **kwargs):
import torch
import types
weights = torch.load(weight_file, map_location=kwargs.get("device", trainer.device))["model"]
weights = torch.load(
weight_file, map_location=kwargs.get("device", trainer.device)
)["model"]
# Filter out non-encoder weights
missing = model.load_state_dict(weights, strict=False)
assert not missing.unexpected_keys, missing.unexpected_keys
missing_str = ",".join(missing.missing_keys)
trainer.log.info(f"Missing keys when loading weights, these weights will __NOT__ be initialized: {missing_str}")
trainer.log.info(
f"Missing keys when loading weights, these weights will __NOT__ be initialized: {missing_str}"
)
if freeze_encoder:
# Lock the encoder
@@ -78,7 +94,9 @@ def make_training_parameters(**kwargs):
model_part = getattr(model, model_part)
model_part.eval()
train_fun = model_part.train
model_part.train = types.MethodType(lambda self, mode=False: train_fun(False), model_part)
model_part.train = types.MethodType(
lambda self, mode=False: train_fun(False), model_part
)
trainer.log.info("Locked encoder")
# Disable gradients of the encoder
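
The weight-initialization hook above loads pretrained weights with strict=False, asserts that no unexpected keys were present, and then pins the encoder in evaluation mode by replacing its .train method, so a later model.train() call cannot switch the frozen part back to training behaviour (relevant for batch-norm and dropout layers); encoder gradients are disabled separately. A self-contained sketch of the same pattern on a toy module, with illustrative names that are not taken from the atlas code:

# Standalone sketch of the freezing pattern used above (illustrative names only).
import types
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
        self.head = nn.Conv2d(8, 2, 1)

    def forward(self, x):
        return self.head(self.encoder(x))

model = ToyModel()

# Keep the encoder in eval mode permanently: model.train(True) can no longer
# flip its BatchNorm layers back to training behaviour.
encoder = model.encoder
encoder.eval()
_orig_train = encoder.train
encoder.train = types.MethodType(lambda self, mode=True: _orig_train(False), encoder)

# Also stop gradient updates for the frozen part.
for p in encoder.parameters():
    p.requires_grad = False

model.train()  # the head switches to train mode; the encoder stays in eval mode
assert not encoder.training and model.head.training
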
@@ -99,7 +117,7 @@ def make_training_parameters(**kwargs):
# Multi scale input
input=("gray2", "gray16"),
# Labels for supervised training
output="labels"
output="labels",
),
num_classes=len(labels),
distributed_learning_rate_scaling=True,
@@ -109,13 +127,31 @@ def make_training_parameters(**kwargs):
param_group_selectors=param_selectors,
weight_decay=weight_decay,
),
network_params=config.NetworkParams(make_network=make_network, weight_init_func=_weight_init_func),
network_params=config.NetworkParams(
make_network=make_network, weight_init_func=_weight_init_func
),
iterator_params={
config.TRAIN_SPLIT: config.IteratorParams(
batch_size=batch_size,
data_sources=[ds.ImageDataSource(name="gray2", size=2025, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.ImageDataSource(name="gray16", size=628, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.LabelDataSource(name="labels", size=68, spacing=16, labels=labels)],
data_sources=[
ds.ImageDataSource(
name="gray2",
size=2025,
spacing=2,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.ImageDataSource(
name="gray16",
size=628,
spacing=16,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.LabelDataSource(
name="labels", size=68, spacing=16, labels=labels
),
],
image_provider_params=config.ImageProviderParams(
opencv_num_threads=4,
numba_num_threads=4,
@@ -129,12 +165,14 @@ def make_training_parameters(**kwargs):
gamma_contrast_augmentation=True,
gamma_contrast_augmentation_range=(0.8, 1.214),
random_rotation=True,
random_rotation_range=(-math.pi / 4, +math.pi / 4)
random_rotation_range=(-math.pi / 4, +math.pi / 4),
),
),
sampler_params=config.SamplerParams(
sampler_mode="weighted",
weighted_sampler_probability_weights={label: 1 / len(labels) for label in labels},
weighted_sampler_probability_weights={
label: 1 / len(labels) for label in labels
},
weighted_sampler_labels=labels,
area_provider_params=config.AreaProviderParams(
area_mode="distance",
@@ -152,9 +190,25 @@ def make_training_parameters(**kwargs):
),
config.VAL_SPLIT: config.IteratorParams(
batch_size=5,
data_sources=[ds.ImageDataSource(name="gray2", size=8169, spacing=2, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.ImageDataSource(name="gray16", size=1396, spacing=16, channels=CHANNEL_HISTOLOGY, data_format="NCHW"),
ds.LabelDataSource(name="labels", size=836, spacing=16, labels=labels, create=False)],
data_sources=[
ds.ImageDataSource(
name="gray2",
size=8169,
spacing=2,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.ImageDataSource(
name="gray16",
size=1396,
spacing=16,
channels=CHANNEL_HISTOLOGY,
data_format="NCHW",
),
ds.LabelDataSource(
name="labels", size=836, spacing=16, labels=labels, create=False
),
],
image_provider_params=config.ImageProviderParams(
transformation_parameters=config.TransformationParams(
normalize_orientation=True,
@@ -162,7 +216,9 @@ def make_training_parameters(**kwargs):
),
sampler_params=config.SamplerParams(
sampler_mode="grid",
grid_sampler_output_format=config.OutputFormat(size=836, spacing=16),
grid_sampler_output_format=config.OutputFormat(
size=836, spacing=16
),
image_provider_params=config.ImageProviderParams(
transformation_parameters=config.TransformationParams(
normalize_orientation=True,
@@ -179,10 +235,10 @@ def make_training_parameters(**kwargs):
congruent_slices=congruent_slices,
use_pre_registration_for_congruent_sections=True,
congruent_mode="match",
)
)
),
}
),
),
},
)
return train_params
@@ -190,28 +246,53 @@ def make_training_parameters(**kwargs):
def make_network(**kwargs):
from atlas.torch.models.segmentation.unet_multiscale import UNet
from atlas.torch.models.segmentation.unet_parts import UNetDownPathParams, UNetBottomPathParams, UNetUpPathParams
from atlas.torch.models.segmentation.unet_parts import (
UNetDownPathParams,
UNetBottomPathParams,
UNetUpPathParams,
)
input_shape = kwargs["input_shape"]
input_spacing = kwargs["input_spacing"]
num_classes = kwargs["num_classes"]
input_channels = [s[1] for s in input_shape]
model = UNet(input_channels=input_channels,
model = UNet(
input_channels=input_channels,
input_spacings=input_spacing,
num_classes=num_classes,
down_params=[
UNetDownPathParams(filters=[16, 32, 64, 64, 128],
UNetDownPathParams(
filters=[16, 32, 64, 64, 128],
kernel_size=[(5, 3), 3, 3, 3, 3],
stride=[(4, 1), 1, 1, 1, 1]),
UNetDownPathParams(filters=[16, 32, 64, 64, 128, ],
dilation=[1, 2, 2, 2, 2])
stride=[(4, 1), 1, 1, 1, 1],
),
UNetDownPathParams(
filters=[
16,
32,
64,
64,
128,
],
dilation=[1, 2, 2, 2, 2],
),
],
bottom_params=[
UNetBottomPathParams(filters=[128, ]),
UNetBottomPathParams(filters=[128, ], dilation=2)
UNetBottomPathParams(
filters=[
128,
]
),
UNetBottomPathParams(
filters=[
128,
],
dilation=2,
),
],
up_params=UNetUpPathParams(filters=[128, 64, 64, 32],
upsampling_mode="upsampling"))
up_params=UNetUpPathParams(
filters=[128, 64, 64, 32], upsampling_mode="upsampling"
),
)
return model