Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
atlasui
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Container Registry
Model registry
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
INM-1
BDA
software
analysis
atlas
atlasui
Commits
d2b49ecd
Commit
d2b49ecd
authored
11 months ago
by
Schiffer, Christian
Browse files
Options
Downloads
Patches
Plain Diff
[DATALAD] Recorded changes
parent
47751fcf
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
atlas_server/src/static/training/config_cytonet.py
+217
-0
217 additions, 0 deletions
atlas_server/src/static/training/config_cytonet.py
with
217 additions
and
0 deletions
atlas_server/src/static/training/config_cytonet.py
0 → 100644
+
217
−
0
View file @
d2b49ecd
def
make_training_parameters
(
**
kwargs
):
import
os
import
math
from
atlas.configuration
import
config
,
constants
from
atlas.configuration
import
data_source
as
ds
from
atlas.data.image_provider
import
CHANNEL_HISTOLOGY
from
atlas.experiments
import
get_slices_in_range
from
atlas.torch.optimizers
import
WeightDecaySelector
brain
=
__BRAIN__
labels
=
[
"
bg
"
,
]
no_aux_labels
=
__NO_AUX_LABELS__
if
not
no_aux_labels
:
labels
=
labels
+
[
"
cor
"
,
"
wm
"
,
]
labels
=
labels
+
__LABELS__
sections
=
__SECTIONS__
freeze_encoder
=
True
weight_file
=
"
/p/fastdata/bigbrains/personal/schiffer1/experiments/2024_cytonet/experiments/atlasui_models/models/cytonet_unet_200k_sigma10.pt
"
train_slices
=
[(
brain
,
s
)
for
s
in
sections
]
# Slightly increase the number of sections to create predictions for
increase
=
0.05
from_section
=
min
(
sections
)
to_section
=
max
(
sections
)
section_range
=
to_section
-
from_section
increase
=
int
(
math
.
ceil
(
section_range
*
increase
))
//
2
min_section
=
max
(
from_section
-
increase
,
1
)
max_section
=
to_section
+
increase
validation_sections
=
get_slices_in_range
(
start
=
min_section
,
stop
=
max_section
,
brain
=
brain
,
pattern
=
constants
.
original_path
,
step
=
1
)
congruent_slices
=
{
key
:
train_slices
for
key
in
set
(
validation_sections
)}
# Settings were configured for a batch size of 12.
# We adjust the learning rate if the actually used batch size is higher
base_batch_size
=
12
base_learning_rate
=
0.01
# Also decrease the number of iterations and respective learning rate steps
base_iterations
=
3000
base_steps
=
[
1000
,
1400
,
1800
,
2200
,
2600
]
# Actual batch size and learning rate
batch_size
=
42
lr_multiplier
=
batch_size
/
base_batch_size
# Multiply -> Increase the learning rate if batch size gets larger
learning_rate
=
base_learning_rate
*
lr_multiplier
# Divide -> Reduce number of iterations if batch size gets larger
iterations
=
int
(
math
.
ceil
(
base_iterations
/
lr_multiplier
))
steps
=
[
int
(
math
.
ceil
(
it
/
lr_multiplier
))
for
it
in
base_steps
]
# Apply weight decay
param_selectors
=
[
WeightDecaySelector
(),
WeightDecaySelector
(
inverse
=
True
)]
weight_decay
=
[
0.0001
,
0.0
]
learning_rate_params
=
config
.
LearningRateParams
(
learning_rate
=
learning_rate
,
lr_policy
=
"
piecewise
"
,
piecewise_steps
=
base_steps
,
piecewise_values
=
[
learning_rate
*
(
0.5
**
i
)
for
i
in
range
(
1
,
6
)])
learning_rate_params
=
config
.
LearningRateParamsPerParamGroup
(
learning_rate_params
,
learning_rate_params
)
# Initialize weights of encoder
# noinspection PyUnusedLocal
def
_weight_init_func
(
model
,
trainer
,
*
args
,
**
kwargs
):
import
torch
import
types
weights
=
torch
.
load
(
weight_file
,
map_location
=
kwargs
.
get
(
"
device
"
,
trainer
.
device
))[
"
model
"
]
# Filter out non-encoder weights
missing
=
model
.
load_state_dict
(
weights
,
strict
=
False
)
assert
not
missing
.
unexpected_keys
,
missing
.
unexpected_keys
missing_str
=
"
,
"
.
join
(
missing
.
missing_keys
)
trainer
.
log
.
info
(
f
"
Missing keys when loading weights, these weights will __NOT__ be initialized:
{
missing_str
}
"
)
if
freeze_encoder
:
# Lock the encoder
for
model_part
in
(
"
down_2
"
,
"
down_16
"
,
"
bottom_2
"
,
"
bottom_16
"
):
model_part
=
getattr
(
model
,
model_part
)
model_part
.
eval
()
train_fun
=
model_part
.
train
model_part
.
train
=
types
.
MethodType
(
lambda
self
,
mode
=
False
:
train_fun
(
False
),
model_part
)
trainer
.
log
.
info
(
"
Locked encoder
"
)
# Disable gradients of the encoder
for
p
in
model_part
.
parameters
():
p
.
requires_grad
=
False
trainer
.
log
.
info
(
"
Disabled gradient computation for encoder
"
)
train_params
=
config
.
TrainParams
(
backend
=
"
torch
"
,
mode
=
"
fast_segmentation
"
,
iterations
=
iterations
,
test_interval
=
None
,
snapshot
=
100
,
data_format
=
"
NCHW
"
,
class_weights
=
"
$CLASS_WEIGHTS
"
,
save_best_weights
=
False
,
input_output_mapping
=
ds
.
InputOutputMapping
(
# Multi scale input
input
=
(
"
gray2
"
,
"
gray16
"
),
# Labels for supervised training
output
=
"
labels
"
),
num_classes
=
len
(
labels
),
distributed_learning_rate_scaling
=
True
,
optimizer_params
=
config
.
OptimizerParams
(
# Note: Learning rate will be scaled by number of GPUs (most often 4)
learning_rate_params
=
learning_rate_params
,
param_group_selectors
=
param_selectors
,
weight_decay
=
weight_decay
,
),
network_params
=
config
.
NetworkParams
(
make_network
=
make_network
,
weight_init_func
=
_weight_init_func
),
iterator_params
=
{
config
.
TRAIN_SPLIT
:
config
.
IteratorParams
(
batch_size
=
batch_size
,
data_sources
=
[
ds
.
ImageDataSource
(
name
=
"
gray2
"
,
size
=
2025
,
spacing
=
2
,
channels
=
CHANNEL_HISTOLOGY
,
data_format
=
"
NCHW
"
),
ds
.
ImageDataSource
(
name
=
"
gray16
"
,
size
=
628
,
spacing
=
16
,
channels
=
CHANNEL_HISTOLOGY
,
data_format
=
"
NCHW
"
),
ds
.
LabelDataSource
(
name
=
"
labels
"
,
size
=
68
,
spacing
=
16
,
labels
=
labels
)],
image_provider_params
=
config
.
ImageProviderParams
(
opencv_num_threads
=
4
,
numba_num_threads
=
4
,
transformation_parameters
=
config
.
TransformationParams
(
normalize_orientation
=
True
,
pre_registration_base_slice
=
sections
[
0
],
brightness_augmentation
=
True
,
brightness_augmentation_range
=
(
-
0.2
,
+
0.2
),
contrast_augmentation
=
True
,
contrast_augmentation_range
=
(
0.9
,
1.1
),
gamma_contrast_augmentation
=
True
,
gamma_contrast_augmentation_range
=
(
0.8
,
1.214
),
random_rotation
=
True
,
random_rotation_range
=
(
-
math
.
pi
/
4
,
+
math
.
pi
/
4
)
),
),
sampler_params
=
config
.
SamplerParams
(
sampler_mode
=
"
weighted
"
,
weighted_sampler_probability_weights
=
{
label
:
1
/
len
(
labels
)
for
label
in
labels
},
weighted_sampler_labels
=
labels
,
area_provider_params
=
config
.
AreaProviderParams
(
area_mode
=
"
distance
"
,
distance_area_provider_radius
=
2500
,
distance_area_provider_radius_spacing
=
2
,
distance_area_provider_radius_per_section
=
True
,
congruent_enlargement_factor
=
1.05
,
labels
=
labels
,
congruent_slices
=
congruent_slices
,
slices
=
train_slices
,
use_pre_registration_for_congruent_sections
=
True
,
congruent_mode
=
"
match
"
,
),
),
),
config
.
VAL_SPLIT
:
config
.
IteratorParams
(
batch_size
=
5
,
data_sources
=
[
ds
.
ImageDataSource
(
name
=
"
gray2
"
,
size
=
8169
,
spacing
=
2
,
channels
=
CHANNEL_HISTOLOGY
,
data_format
=
"
NCHW
"
),
ds
.
ImageDataSource
(
name
=
"
gray16
"
,
size
=
1396
,
spacing
=
16
,
channels
=
CHANNEL_HISTOLOGY
,
data_format
=
"
NCHW
"
),
ds
.
LabelDataSource
(
name
=
"
labels
"
,
size
=
836
,
spacing
=
16
,
labels
=
labels
,
create
=
False
)],
image_provider_params
=
config
.
ImageProviderParams
(
transformation_parameters
=
config
.
TransformationParams
(
normalize_orientation
=
True
,
),
),
sampler_params
=
config
.
SamplerParams
(
sampler_mode
=
"
grid
"
,
grid_sampler_output_format
=
config
.
OutputFormat
(
size
=
836
,
spacing
=
16
),
image_provider_params
=
config
.
ImageProviderParams
(
transformation_parameters
=
config
.
TransformationParams
(
normalize_orientation
=
True
,
),
),
area_provider_params
=
config
.
AreaProviderParams
(
area_mode
=
"
distance
"
,
distance_area_provider_radius
=
2500
,
distance_area_provider_radius_spacing
=
2
,
distance_area_provider_radius_per_section
=
True
,
congruent_enlargement_factor
=
1.05
,
labels
=
labels
,
slices
=
validation_sections
,
congruent_slices
=
congruent_slices
,
use_pre_registration_for_congruent_sections
=
True
,
congruent_mode
=
"
match
"
,
)
)
),
}
)
return
train_params
def
make_network
(
**
kwargs
):
from
atlas.torch.models.segmentation.unet_multiscale
import
UNet
from
atlas.torch.models.segmentation.unet_parts
import
UNetDownPathParams
,
UNetBottomPathParams
,
UNetUpPathParams
input_shape
=
kwargs
[
"
input_shape
"
]
input_spacing
=
kwargs
[
"
input_spacing
"
]
num_classes
=
kwargs
[
"
num_classes
"
]
input_channels
=
[
s
[
1
]
for
s
in
input_shape
]
model
=
UNet
(
input_channels
=
input_channels
,
input_spacings
=
input_spacing
,
num_classes
=
num_classes
,
down_params
=
[
UNetDownPathParams
(
filters
=
[
16
,
32
,
64
,
64
,
128
],
kernel_size
=
[(
5
,
3
),
3
,
3
,
3
,
3
],
stride
=
[(
4
,
1
),
1
,
1
,
1
,
1
]),
UNetDownPathParams
(
filters
=
[
16
,
32
,
64
,
64
,
128
,
],
dilation
=
[
1
,
2
,
2
,
2
,
2
])
],
bottom_params
=
[
UNetBottomPathParams
(
filters
=
[
128
,
]),
UNetBottomPathParams
(
filters
=
[
128
,
],
dilation
=
2
)
],
up_params
=
UNetUpPathParams
(
filters
=
[
128
,
64
,
64
,
32
],
upsampling_mode
=
"
upsampling
"
))
return
model
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment