93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
from models.training import Training
|
|
from models.TrainingProjectDetails import TrainingProjectDetails
|
|
from database.database import db
|
|
|
|
# Frontend field names that differ from the Training column names.
_YOLOX_ALIAS_MAP = {
    'act': 'activation',
    'nmsthre': 'nms_thre',
    'select_model': 'selected_model',
}

# Settings that arrive as form strings ('on'/'off', 'true'/'false', ...)
# and must be stored as booleans.
_YOLOX_BOOL_FIELDS = ('save_history_ckpt', 'ema', 'enable_mixup')

# Settings that may arrive as comma-separated strings and are stored as
# (JSON) arrays.
_YOLOX_LIST_FIELDS = ('input_size', 'test_size', 'mosaic_scale', 'mixup_scale')

# Lower-cased strings that count as True; every other string is False.
_TRUTHY_STRINGS = ('1', 'true', 'on')


def _parse_bool(value):
    """Interpret a form value as a boolean.

    Strings are True only when they case-insensitively equal '1', 'true'
    or 'on' (so 'off', 'false', '0' and '' are False). Any non-string
    value is passed through ``bool()``.
    """
    if isinstance(value, str):
        return value.lower() in _TRUTHY_STRINGS
    return bool(value)


def _parse_number_list(text):
    """Split a comma-separated string into floats.

    Blank items are dropped. If exactly one item remains, that single
    value is returned instead of a list. If any item is not a valid
    float, the stripped string items are returned unconverted.
    """
    parts = [p.strip() for p in text.split(',') if p.strip()]
    try:
        arr = [float(p) for p in parts]
    except ValueError:
        arr = parts
    return arr[0] if len(arr) == 1 else arr


def _to_int(value):
    """Convert *value* to int, also accepting decimal strings like '3.0'.

    ``int()`` is tried first so integer strings keep full precision.
    ``int(float(...))`` would silently round values above 2**53. Only
    inputs that ``int()`` rejects go through ``float()``, which
    truncates toward zero.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return int(float(value))


def _coerce_for_column(value, col_type):
    """Coerce *value* for a column whose SQLAlchemy type class is *col_type*.

    *col_type* is the type's class name (e.g. 'Integer', 'Float'). The
    name is matched by substring, so 'BigInteger' is handled as an
    integer. None or '' becomes None for numeric columns. Unrecognised
    types (including JSON) are returned unchanged.

    Raises TypeError, ValueError or OverflowError when a numeric
    conversion is impossible.
    """
    if 'Integer' in col_type:
        return None if value is None or value == '' else _to_int(value)
    if 'Float' in col_type:
        return None if value is None or value == '' else float(value)
    if 'Boolean' in col_type:
        return _parse_bool(value)
    if 'LargeBinary' in col_type and isinstance(value, str):
        # The string itself is stored as UTF-8 bytes. It is NOT treated
        # as a file path, and no file is read.
        try:
            return value.encode('utf-8')
        except UnicodeEncodeError:
            # e.g. lone surrogates; store nothing rather than a str.
            return None
    return value


def _normalize_settings(settings):
    """Return a copy of *settings* with aliases renamed and form values parsed."""
    normalized = dict(settings)

    # Rename an alias only when the canonical key is absent, so an
    # explicit canonical value always wins.
    for alias, column in _YOLOX_ALIAS_MAP.items():
        if alias in normalized and column not in normalized:
            normalized[column] = normalized.pop(alias)

    for field in _YOLOX_BOOL_FIELDS:
        if field in normalized:
            normalized[field] = _parse_bool(normalized[field])

    for key in _YOLOX_LIST_FIELDS:
        if isinstance(normalized.get(key), str):
            normalized[key] = _parse_number_list(normalized[key])

    return normalized


def push_yolox_exp_to_db(settings):
    """Save YOLOX settings to database.

    The settings are normalized first: frontend aliases are renamed, form
    booleans and comma-separated lists are parsed, and the matching
    TrainingProjectDetails row is linked via ``project_details_id``.
    Keys that are not columns of ``Training`` are discarded. Values are
    coerced to their column types where possible; a value that cannot be
    coerced is passed through unchanged.

    Args:
        settings: Mapping of setting names to values. It must contain
            'project_id'. The mapping is copied, not mutated.

    Returns:
        The newly created and committed ``Training`` instance.

    Raises:
        ValueError: 'project_id' is missing or empty.
        LookupError: no TrainingProjectDetails row exists for project_id.
        Exception: any error raised by the commit is re-raised after the
            session has been rolled back.
    """
    normalized = _normalize_settings(settings)

    # Ensure we have a TrainingProjectDetails row for project_id.
    # Only None/'' count as missing, so a legitimate id of 0 is accepted.
    project_id = normalized.get('project_id')
    if project_id is None or project_id == '':
        raise ValueError('Missing project_id in settings')
    details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
    if details is None:
        raise LookupError(f'TrainingProjectDetails not found for project_id {project_id}')
    normalized['project_details_id'] = details.id

    # Keep only keys that are real columns on the Training model.
    valid_cols = {c.name: c for c in Training.__table__.columns}
    filtered = {}
    for key, value in normalized.items():
        column = valid_cols.get(key)
        if column is None:
            continue
        try:
            filtered[key] = _coerce_for_column(value, type(column.type).__name__)
        except (TypeError, ValueError, OverflowError):
            # Best effort: if conversion fails, assign the raw value.
            filtered[key] = value

    training = Training(**filtered)
    db.session.add(training)
    try:
        db.session.commit()
    except Exception:
        # A failed commit leaves the session in an invalid transaction.
        # Roll back so the shared session stays usable, then propagate.
        db.session.rollback()
        raise

    return training
|