cleanup add training bell
This commit is contained in:
68
backend/services/generate_yolox_exp.py
Normal file → Executable file
68
backend/services/generate_yolox_exp.py
Normal file → Executable file
@@ -220,6 +220,10 @@ def generate_yolox_inference_exp(training_id, options=None, use_base_config=Fals
|
||||
annotations_parent_dir = os.path.join(output_base_path, project_name, training_folder_name)
|
||||
annotations_parent_escaped = annotations_parent_dir.replace('\\', '\\\\')
|
||||
|
||||
# Set output directory for checkpoints - models subdirectory
|
||||
models_dir = os.path.join(annotations_parent_dir, 'models')
|
||||
models_dir_escaped = models_dir.replace('\\', '\\\\')
|
||||
|
||||
# Build exp content
|
||||
exp_content = f'''#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
@@ -235,6 +239,7 @@ class Exp(MyExp):
|
||||
super(Exp, self).__init__()
|
||||
self.data_dir = "{data_dir_escaped}" # Where images are located
|
||||
self.annotations_dir = "{annotations_parent_escaped}" # Where annotation JSONs are located
|
||||
self.output_dir = "{models_dir_escaped}" # Where checkpoints will be saved
|
||||
self.train_ann = "{train_ann}"
|
||||
self.val_ann = "{val_ann}"
|
||||
self.test_ann = "{test_ann}"
|
||||
@@ -252,21 +257,46 @@ class Exp(MyExp):
|
||||
if selected_model:
|
||||
exp_content += f" self.pretrained_ckpt = r'{yolox_base_dir}/pretrained/{selected_model}.pth'\n"
|
||||
|
||||
# Format arrays
|
||||
def format_value(val):
|
||||
# Format arrays and values for Python code generation
|
||||
# Integer-only parameters (sizes, epochs, intervals)
|
||||
integer_params = {
|
||||
'input_size', 'test_size', 'random_size', 'max_epoch', 'warmup_epochs',
|
||||
'no_aug_epochs', 'print_interval', 'eval_interval', 'multiscale_range',
|
||||
'data_num_workers', 'num_classes'
|
||||
}
|
||||
|
||||
def format_value(val, param_name=''):
|
||||
if isinstance(val, (list, tuple)):
|
||||
return '(' + ', '.join(map(str, val)) + ')'
|
||||
# Check if this parameter should have integer values
|
||||
if param_name in integer_params:
|
||||
# Convert all values to integers
|
||||
formatted_items = [str(int(float(item))) if isinstance(item, (int, float)) else str(item) for item in val]
|
||||
else:
|
||||
# Keep as floats or original type
|
||||
formatted_items = []
|
||||
for item in val:
|
||||
if isinstance(item, float):
|
||||
formatted_items.append(str(item))
|
||||
elif isinstance(item, int):
|
||||
formatted_items.append(str(item))
|
||||
else:
|
||||
formatted_items.append(str(item))
|
||||
return '(' + ', '.join(formatted_items) + ')'
|
||||
elif isinstance(val, bool):
|
||||
return str(val)
|
||||
elif isinstance(val, str):
|
||||
return f'"{val}"'
|
||||
elif isinstance(val, int):
|
||||
return str(val)
|
||||
elif isinstance(val, float):
|
||||
return str(val)
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
# Add all config parameters to exp
|
||||
for key, value in config.items():
|
||||
if key not in ['exp_name']: # exp_name is handled separately
|
||||
exp_content += f" self.{key} = {format_value(value)}\n"
|
||||
exp_content += f" self.{key} = {format_value(value, key)}\n"
|
||||
|
||||
# Add get_dataset override using name parameter for image directory
|
||||
exp_content += '''
|
||||
@@ -289,7 +319,7 @@ class Exp(MyExp):
|
||||
|
||||
def get_eval_dataset(self, **kwargs):
|
||||
"""Override eval dataset using name parameter"""
|
||||
from yolox.data import COCODataset
|
||||
from yolox.data import COCODataset, ValTransform
|
||||
|
||||
testdev = kwargs.get("testdev", False)
|
||||
legacy = kwargs.get("legacy", False)
|
||||
@@ -299,8 +329,34 @@ class Exp(MyExp):
|
||||
json_file=self.val_ann if not testdev else self.test_ann,
|
||||
name="",
|
||||
img_size=self.test_size,
|
||||
preproc=None, # No preprocessing for evaluation
|
||||
preproc=ValTransform(legacy=legacy), # Use proper validation transform
|
||||
)
|
||||
|
||||
def get_eval_loader(self, batch_size, is_distributed, **kwargs):
|
||||
"""Standard YOLOX eval loader - matches official implementation"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
valdataset = self.get_eval_dataset(**kwargs)
|
||||
|
||||
if is_distributed:
|
||||
batch_size = batch_size // dist.get_world_size()
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
valdataset, shuffle=False
|
||||
)
|
||||
else:
|
||||
sampler = torch.utils.data.SequentialSampler(valdataset)
|
||||
|
||||
dataloader_kwargs = {
|
||||
"num_workers": self.data_num_workers,
|
||||
"pin_memory": True,
|
||||
"sampler": sampler,
|
||||
}
|
||||
dataloader_kwargs["batch_size"] = batch_size
|
||||
val_loader = DataLoader(valdataset, **dataloader_kwargs)
|
||||
|
||||
return val_loader
|
||||
'''
|
||||
|
||||
# Add exp_name at the end (uses dynamic path)
|
||||
|
||||
Reference in New Issue
Block a user