853 lines
35 KiB
Python
853 lines
35 KiB
Python
from flask import Blueprint, request, jsonify, send_file
|
|
from werkzeug.utils import secure_filename
|
|
import os
|
|
import json
|
|
import subprocess
|
|
from database.database import db
|
|
from models.TrainingProject import TrainingProject
|
|
from models.TrainingProjectDetails import TrainingProjectDetails
|
|
from models.training import Training
|
|
from models.LabelStudioProject import LabelStudioProject
|
|
from models.Images import Image
|
|
from models.Annotation import Annotation
|
|
|
|
# Every route in this module is registered on this blueprint, named 'api'.
api_bp = Blueprint(name='api', import_name=__name__)

# Flag served as-is by GET /update-status (modelled on the Node.js backend).
# Nothing visible in this module mutates it — presumably updated elsewhere; TODO confirm.
update_status = dict(running=False)
|
|
|
|
@api_bp.route('/seed', methods=['GET'])
def seed():
    """Run the Label Studio seeding service and return its result as JSON."""
    # Imported at call time so the service module only loads when seeding is requested.
    from services.seed_label_studio import seed_label_studio as run_seed

    return jsonify(run_seed())
|
|
|
|
@api_bp.route('/generate-yolox-json', methods=['POST'])
def generate_yolox_json():
    """Generate YOLOX COCO JSON files and an exp.py for every details row of a project.

    Expects a JSON body ``{"project_id": ...}``.

    Returns:
        200 with a confirmation message; 400 when ``project_id`` is missing or the
        body is not a JSON object; 404 when the project has no
        TrainingProjectDetails rows; 500 on any other failure.
    """
    try:
        # silent=True: a missing/invalid JSON body yields None instead of raising,
        # so the request ends in the 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')

        if not project_id:
            return jsonify({'message': 'Missing project_id in request body'}), 400

        # Find all TrainingProjectDetails for this project
        details_rows = TrainingProjectDetails.query.filter_by(project_id=project_id).all()

        if not details_rows:
            return jsonify({'message': f'No TrainingProjectDetails found for project {project_id}'}), 404

        # The TrainingProject row may be missing even though details rows exist;
        # fall back to a generic folder name instead of crashing on None.title.
        training_project = TrainingProject.query.get(project_id)
        if training_project and training_project.title:
            project_name = training_project.title.replace(' ', '_')
        else:
            project_name = f'project_{project_id}'

        from services.generate_json_yolox import generate_training_json
        from services.generate_yolox_exp import save_yolox_exp

        # For each details row, generate the COCO JSONs and exp.py
        for details in details_rows:
            details_id = details.id
            generate_training_json(details_id)

            trainings = Training.query.filter_by(project_details_id=details_id).all()
            if not trainings:
                continue

            out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(details_id))
            os.makedirs(out_dir, exist_ok=True)

            # NOTE(review): every training of this details row writes the same
            # exp.py, so only the last training's config survives. Path kept as-is
            # in case other code expects it; confirm whether per-training files are wanted.
            exp_file_path = os.path.join(out_dir, 'exp.py')
            for training in trainings:
                save_yolox_exp(training.id, exp_file_path)

        return jsonify({'message': f'YOLOX JSON and exp.py generated for project {project_id}'})

    except Exception as err:
        print(f'Error generating YOLOX JSON: {err}')
        return jsonify({'message': 'Failed to generate YOLOX JSON', 'error': str(err)}), 500
|
|
|
|
def _pretrained_checkpoint_path(training, yolox_main_dir):
    """Return the checkpoint path to pass to train.py via ``-c``, or None when none applies."""
    transfer = training.transfer_learning
    model = training.selected_model
    # transfer_learning == 'coco' (case-insensitive): start from the selected model's weights.
    if isinstance(transfer, str) and transfer.lower() == 'coco':
        return f'{yolox_main_dir}/pretrained/{model}.pth'
    # Also handled: selected_model == 'coco' while transfer learning is off/empty.
    if model and model.lower() == 'coco' and not transfer:
        return f'{yolox_main_dir}/pretrained/{model}.pth'
    return None


def _build_yolox_train_command(venv_setting, train_argv, is_windows):
    """Build a command that activates the YOLOX venv and runs tools/train.py.

    Every element of *train_argv* is quoted for the target shell. This stops values
    derived from user input (project titles, experiment names) from injecting
    extra commands. The result is meant for ``subprocess.Popen`` with ``shell=False``.

    Raises:
        ValueError: on Windows, if a path or argument contains a double quote (it
            cannot be quoted safely for cmd.exe and is not a valid Windows path).
    """
    import shlex

    if is_windows:
        # A setting that is not a .bat file is treated as the venv root directory.
        if venv_setting.endswith('.bat'):
            venv_activate = venv_setting
        else:
            venv_activate = os.path.join(venv_setting, 'Scripts', 'activate.bat')
        if any('"' in part for part in [venv_activate, *train_argv]):
            raise ValueError('Training paths/arguments must not contain double quotes on Windows')
        # Double quotes make cmd.exe treat &, |, <, > and spaces literally
        # (note: %VAR% is still expanded inside quotes).
        quoted = ' '.join(f'"{arg}"' for arg in train_argv)
        # cmd /c strips the outermost pair of quotes, leaving:
        #   "<activate>" && python tools\train.py "<arg>" ...
        return f'cmd /c ""{venv_activate}" && python tools\\train.py {quoted}"'

    quoted = ' '.join(shlex.quote(arg) for arg in train_argv)
    script = f'source {shlex.quote(venv_setting)} && python tools/train.py {quoted}'
    # An argv list runs bash directly, without the unquoted /bin/sh layer shell=True adds.
    return ['bash', '-c', script]


@api_bp.route('/start-yolox-training', methods=['POST'])
def start_yolox_training():
    """Generate the COCO JSONs and exp.py for a training, then launch YOLOX training.

    Expects a JSON body ``{"project_id": ..., "training_id": ...}``. The training
    process is spawned in the background and is not awaited.

    Returns:
        200 with the exp.py path once the process is spawned; 400 when an id is
        missing or the body is not a JSON object; 404 when the training does not
        exist; 500 on any other error.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes the 400 below, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')
        training_id = data.get('training_id')

        if not project_id or not training_id:
            return jsonify({'message': 'Missing project_id or training_id'}), 400

        training = Training.query.get(training_id)
        if not training:
            return jsonify({'message': f'Training {training_id} not found'}), 404

        details_id = training.project_details_id

        # Step 1: Generate COCO JSON files
        from services.generate_json_yolox import generate_training_json
        print(f'Generating COCO JSON for training {training_id}...')
        generate_training_json(details_id)

        # Step 2: Generate exp.py
        from services.generate_yolox_exp import save_yolox_exp
        from services.settings_service import get_setting

        training_project = TrainingProject.query.get(project_id)
        if training_project and training_project.title:
            project_name = training_project.title.replace(' ', '_')
        else:
            project_name = f'project_{project_id}'

        # Use training name + id for the folder so a project can hold several trainings.
        training_folder_name = f"{training.exp_name or training.training_name or 'training'}_{training_id}"
        training_folder_name = training_folder_name.replace(' ', '_')

        output_base_path = get_setting('yolox_output_path', './backend')
        out_dir = os.path.join(output_base_path, project_name, training_folder_name)
        os.makedirs(out_dir, exist_ok=True)

        exp_file_path = os.path.join(out_dir, 'exp.py')
        print(f'Generating exp.py at {exp_file_path}...')
        save_yolox_exp(training_id, exp_file_path)

        # Step 3: Start training
        print(f'Starting YOLOX training for training {training_id}...')

        yolox_main_dir = get_setting('yolox_path', '/home/kitraining/Yolox/YOLOX-main')
        yolox_venv = get_setting('yolox_venv_path', '/home/kitraining/Yolox/yolox_venv/bin/activate')
        # train.py runs with cwd set to the YOLOX root, so paths handed to it must be
        # absolute: the default output path ('./backend') is relative to *this*
        # process's working directory, not to the YOLOX directory.
        yolox_root = os.path.abspath(yolox_main_dir)

        train_argv = ['-f', os.path.abspath(exp_file_path), '-d', '1', '-b', '8', '--fp16', '--cache']
        checkpoint = _pretrained_checkpoint_path(training, yolox_root)
        if checkpoint:
            train_argv += ['-c', checkpoint, '-o']

        import platform
        is_windows = platform.system() == 'Windows'
        command = _build_yolox_train_command(yolox_venv, train_argv, is_windows)
        print(f'Training command: {command}')

        # Fire-and-forget: the process keeps running after this request returns.
        subprocess.Popen(command, cwd=yolox_root)

        return jsonify({
            'message': f'JSONs and exp.py generated, training started for training {training_id}',
            'exp_path': exp_file_path
        })

    except Exception as err:
        print(f'Error starting YOLOX training: {err}')
        import traceback
        traceback.print_exc()
        return jsonify({'message': 'Failed to start training', 'error': str(err)}), 500
|
|
|
|
@api_bp.route('/training-log', methods=['GET'])
def training_log():
    """Return the contents of a training's ``training.log`` as JSON.

    Query parameters:
        project_id: TrainingProject id, used to derive the project folder name.
        training_id: integer training id, used as the sub-folder name.

    Returns:
        200 ``{"log": ...}``; 400 when a parameter is missing or training_id is not
        an integer; 404 when the log file does not exist; 500 otherwise.
    """
    try:
        project_id = request.args.get('project_id')
        # type=int rejects values such as '../..' that would otherwise be joined
        # into the filesystem path below.
        training_id = request.args.get('training_id', type=int)

        if not project_id or training_id is None:
            return jsonify({'error': 'Missing or invalid project_id or training_id'}), 400

        training_project = TrainingProject.query.get(project_id)
        if training_project and training_project.title:
            project_name = training_project.title.replace(' ', '_')
        else:
            project_name = f'project_{project_id}'

        # NOTE(review): start_yolox_training writes its outputs under the
        # 'yolox_output_path' setting in '<exp_name>_<id>' folders, not here;
        # confirm which directory actually receives training.log.
        out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(training_id))
        log_path = os.path.join(out_dir, 'training.log')

        if not os.path.isfile(log_path):
            return jsonify({'error': 'Log not found'}), 404

        # errors='replace' keeps stray non-UTF-8 bytes from turning into a 500.
        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            log_data = f.read()

        return jsonify({'log': log_data})

    except Exception as err:
        return jsonify({'error': 'Failed to fetch log', 'details': str(err)}), 500
|
|
|
|
@api_bp.route('/training-projects', methods=['POST'])
def create_training_project():
    """Create a TrainingProject and its ordered ProjectClass rows from multipart form data.

    Form fields:
        title, description: stored on the project as given.
        classes: JSON array of class names, stored in order as ProjectClass rows.
            An absent or empty field means no classes.
        project_image (file, optional): raw bytes and content type stored on the
            project; an empty file part is treated as "no image".

    Returns:
        200 on success; 400 when ``classes`` is not a valid JSON array;
        500 (after rolling back the session) on any other error.
    """
    try:
        from models.ProjectClass import ProjectClass

        title = request.form.get('title')
        description = request.form.get('description')

        # `or '[]'` also covers an empty-string field, which json.loads rejects.
        try:
            classes = json.loads(request.form.get('classes') or '[]')
        except ValueError:
            return jsonify({'message': 'classes must be a JSON array'}), 400
        if not isinstance(classes, list):
            # A JSON string/object would otherwise be iterated char-by-char / key-by-key.
            return jsonify({'message': 'classes must be a JSON array'}), 400

        project_image = None
        project_image_type = None

        uploaded = request.files.get('project_image')
        if uploaded is not None:
            image_bytes = uploaded.read()
            if image_bytes:
                project_image = image_bytes
                project_image_type = uploaded.content_type

        # Create project without classes field
        project = TrainingProject(
            title=title,
            description=description,
            project_image=project_image,
            project_image_type=project_image_type
        )

        db.session.add(project)
        db.session.flush()  # populates project.project_id before commit

        for index, class_name in enumerate(classes):
            db.session.add(ProjectClass(
                project_id=project.project_id,
                class_name=class_name,
                display_order=index
            ))

        db.session.commit()

        return jsonify({'message': 'Project created!'})

    except Exception as error:
        print(f'Error creating project: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to create project', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-projects', methods=['GET'])
def get_training_projects():
    """List every TrainingProject, each serialized through its ``to_dict()``.

    Returns a JSON array, or a 500 payload carrying the error text on failure.
    """
    try:
        return jsonify([row.to_dict() for row in TrainingProject.query.all()])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/update-status', methods=['GET'])
def get_update_status():
    """Report the module-level ``update_status`` dict as JSON."""
    current = update_status
    return jsonify(current)
|
|
|
|
@api_bp.route('/label-studio-projects', methods=['GET'])
def get_label_studio_projects():
    """List Label Studio projects, each with per-label annotation counts.

    Each project's ``to_dict()`` output gains an ``annotationCounts`` mapping of
    label -> number of annotations. The mapping is empty when the project has none.
    All counts come from a single grouped query.
    """
    try:
        from sqlalchemy import func

        ls_projects = LabelStudioProject.query.all()

        # One grouped query yielding (project_id, label, count) rows.
        grouped_rows = (
            db.session.query(
                Image.project_id,
                Annotation.Label,
                func.count(Annotation.annotation_id).label('count'),
            )
            .join(Annotation, Image.image_id == Annotation.image_id)
            .group_by(Image.project_id, Annotation.Label)
            .all()
        )

        per_project = {}
        for pid, label, total in grouped_rows:
            per_project.setdefault(pid, {})[label] = total

        payload = []
        for ls_project in ls_projects:
            entry = ls_project.to_dict()
            entry['annotationCounts'] = per_project.get(ls_project.project_id, {})
            payload.append(entry)

        return jsonify(payload)

    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-project-details', methods=['POST'])
def create_training_project_details():
    """Create a TrainingProjectDetails row with its annotation-project and class mappings.

    Expected JSON body:
        project_id (required): owning TrainingProject id.
        annotation_projects (required, list): Label Studio project ids to map.
        class_map (optional, dict): ``{source_class: target_class}``. All of its
            entries are attached to the first annotation project (0 if the list is empty).
        description (optional): stored as ``description_text``.

    Returns:
        200 with the created details; 400 for missing or ill-typed fields;
        500 (after rollback) on any other error.
    """
    try:
        from models.AnnotationProjectMapping import AnnotationProjectMapping
        from models.ClassMapping import ClassMapping

        # silent=True: a missing/invalid JSON body becomes a 400, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')
        annotation_projects = data.get('annotation_projects')  # Array of project IDs
        class_map = data.get('class_map')  # Dict: {source: target}
        description = data.get('description')

        if not project_id or annotation_projects is None:
            return jsonify({'message': 'Missing required fields'}), 400
        if not isinstance(annotation_projects, list):
            return jsonify({'message': 'annotation_projects must be a list'}), 400
        if class_map and not isinstance(class_map, dict):
            return jsonify({'message': 'class_map must be an object'}), 400

        details = TrainingProjectDetails(
            project_id=project_id,
            description_text=description
        )
        db.session.add(details)
        db.session.flush()  # populates details.id for the mapping rows

        for ls_project_id in annotation_projects:
            db.session.add(AnnotationProjectMapping(
                project_details_id=details.id,
                label_studio_project_id=ls_project_id
            ))

        if class_map:
            # No per-project mappings exist yet at creation time, so every entry is
            # attached to the first annotation project; the UI replaces these later.
            default_ls_project = annotation_projects[0] if annotation_projects else 0
            for source_class, target_class in class_map.items():
                db.session.add(ClassMapping(
                    project_details_id=details.id,
                    label_studio_project_id=default_ls_project,
                    source_class=source_class,
                    target_class=target_class
                ))

        db.session.commit()

        return jsonify({'message': 'TrainingProjectDetails created', 'details': details.to_dict()})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to create TrainingProjectDetails', 'error': str(error)}), 500
|
|
|
|
def _details_fallback_dict(details_row):
    """Build the minimal dict returned when ``to_dict()`` fails for a details row."""
    return {
        'id': details_row.id,
        'project_id': details_row.project_id,
        'description': details_row.description_text,
        'annotation_projects': [],
        'class_map': {}
    }


@api_bp.route('/training-project-details', methods=['GET'])
def get_training_project_details():
    """List all TrainingProjectDetails rows as JSON.

    A row whose ``to_dict()`` raises is logged and replaced by a minimal dict
    (id, project_id, description, empty annotation_projects and class_map), so one
    bad row does not fail the whole listing.
    """
    try:
        serialized = []
        for row in TrainingProjectDetails.query.all():
            try:
                entry = row.to_dict()
            except Exception as e:
                print(f'Error serializing detail {row.id}: {e}')
                entry = _details_fallback_dict(row)
            serialized.append(entry)
        return jsonify(serialized)

    except Exception as error:
        print(f'Error fetching training project details: {error}')
        return jsonify({'message': 'Failed to fetch TrainingProjectDetails', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-project-details', methods=['PUT'])
def update_training_project_details():
    """Replace the class mappings and description of a project's TrainingProjectDetails.

    Expected JSON body:
        project_id: TrainingProject id; the first matching details row is updated.
        class_map: ``[[labelStudioProjectId, [[source, target], ...]], ...]``. All
            existing ClassMapping rows for the details row are deleted and replaced.
        description: ``[[projectId, text], ...]``. Non-empty texts are joined with
            blank lines into ``description_text``.

    Returns:
        200 with the updated details; 400 when a field is missing or the body is not
        a JSON object; 404 when no details row exists; 500 (after rollback) otherwise.
    """
    try:
        from models.ClassMapping import ClassMapping

        # silent=True: a missing/invalid JSON body becomes the 400 below, not a 500.
        data = request.get_json(silent=True)
        print(f'[DEBUG] Received PUT data: {data}')
        if not isinstance(data, dict):
            data = {}

        project_id = data.get('project_id')
        class_map_data = data.get('class_map')  # [[labelStudioProjectId, [[class, target], ...]], ...]
        description_data = data.get('description')  # [[projectId, desc], ...]

        print(f'[DEBUG] project_id: {project_id}')
        print(f'[DEBUG] class_map_data: {class_map_data}')
        print(f'[DEBUG] description_data: {description_data}')

        if not project_id or class_map_data is None or description_data is None:
            return jsonify({'message': 'Missing required fields'}), 400

        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()

        if not details:
            return jsonify({'message': 'TrainingProjectDetails not found'}), 404

        # Combine all non-empty descriptions into one text.
        combined_description = '\n\n'.join([desc[1] for desc in description_data if len(desc) > 1 and desc[1]])
        details.description_text = combined_description

        # Replace existing class mappings wholesale.
        ClassMapping.query.filter_by(project_details_id=details.id).delete()

        for project_mapping in class_map_data:
            if len(project_mapping) < 2:
                continue
            label_studio_project_id = project_mapping[0]
            class_mappings = project_mapping[1]  # [[class1, target1], [class2, target2], ...]

            for class_pair in class_mappings:
                if len(class_pair) < 2:
                    continue
                db.session.add(ClassMapping(
                    project_details_id=details.id,
                    label_studio_project_id=label_studio_project_id,
                    source_class=class_pair[0],
                    target_class=class_pair[1]
                ))

        db.session.commit()

        return jsonify({'message': 'Class map and description updated', 'details': details.to_dict()})

    except Exception as error:
        db.session.rollback()
        print(f'[ERROR] Failed to update training project details: {error}')
        import traceback
        traceback.print_exc()
        return jsonify({'message': 'Failed to update class map or description', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/yolox-settings', methods=['POST'])
def yolox_settings():
    """Normalize YOLOX training settings from a multipart form and save them to the DB.

    Normalization: ``select_model`` -> ``selected_model`` and ``act`` -> ``activation``.
    Numeric fields become floats and boolean fields are parsed from 'true'/'false'.
    Comma-separated scale fields become lists of floats, and strings are trimmed.
    A TrainingProjectDetails row is created for ``project_id`` if none exists.
    An optional ``ckpt_upload`` file is saved in an ``uploads`` directory one level
    above this module, and its path is stored as ``model_upload``.

    Returns:
        200 with the saved training; 500 (after rollback) on any error, including
        validation failures raised as ValueError.
    """
    try:
        settings = request.form.to_dict()

        print('--- YOLOX settings received ---')
        print('settings:', settings)

        # Map select_model to selected_model if present
        if 'select_model' in settings and 'selected_model' not in settings:
            settings['selected_model'] = settings['select_model']
            del settings['select_model']

        # Lookup or create project_details_id
        if not settings.get('project_id') or not settings['project_id'].isdigit():
            raise ValueError('Missing or invalid project_id in request')

        project_id = int(settings['project_id'])
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()

        if not details:
            details = TrainingProjectDetails(
                project_id=project_id,
                description_text=None
            )
            db.session.add(details)
            db.session.flush()  # populates details.id without committing

        settings['project_details_id'] = details.id

        # Map 'act' to 'activation'
        if 'act' in settings:
            settings['activation'] = settings['act']
            del settings['act']

        # Type conversions
        numeric_fields = [
            'max_epoch', 'depth', 'width', 'warmup_epochs', 'warmup_lr',
            'no_aug_epochs', 'min_lr_ratio', 'weight_decay', 'momentum',
            'print_interval', 'eval_interval', 'test_conf', 'nmsthre',
            'multiscale_range', 'degrees', 'translate', 'shear',
            'train', 'valid', 'test'
        ]
        for field in numeric_fields:
            if field in settings:
                settings[field] = float(settings[field])

        # Boolean conversions
        boolean_fields = ['ema', 'enable_mixup', 'save_history_ckpt']
        for field in boolean_fields:
            if field in settings:
                if isinstance(settings[field], str):
                    settings[field] = settings[field].lower() == 'true'
                else:
                    settings[field] = bool(settings[field])

        # Array conversions
        array_fields = ['mosaic_scale', 'mixup_scale', 'scale']
        for field in array_fields:
            if field in settings and isinstance(settings[field], str):
                settings[field] = [float(x.strip()) for x in settings[field].split(',') if x.strip()]

        # Trim string fields
        for key in settings:
            if isinstance(settings[key], str):
                settings[key] = settings[key].strip()

        # Default for transfer_learning
        if 'transfer_learning' not in settings:
            settings['transfer_learning'] = False

        # Convert empty seed to None
        if 'seed' in settings and (settings['seed'] == '' or settings['seed'] is None):
            settings['seed'] = None

        # Validate required fields
        required_fields = [
            'project_details_id', 'exp_name', 'max_epoch', 'depth', 'width',
            'activation', 'train', 'valid', 'test', 'selected_model', 'transfer_learning'
        ]
        for field in required_fields:
            if field not in settings or settings[field] in [None, '']:
                raise ValueError(f'Missing required field: {field}')

        print('Received YOLOX settings:', settings)

        # Handle uploaded model file
        if 'ckpt_upload' in request.files:
            file = request.files['ckpt_upload']
            upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
            os.makedirs(upload_dir, exist_ok=True)
            # The client controls file.filename; secure_filename strips directory
            # components such as '../../' so the file cannot land outside upload_dir.
            filename = secure_filename(file.filename or '') or f'uploaded_model_{project_id}.pth'
            file_path = os.path.join(upload_dir, filename)
            file.save(file_path)
            settings['model_upload'] = file_path

        # Save to DB
        from services.push_yolox_exp import push_yolox_exp_to_db
        training = push_yolox_exp_to_db(settings)

        return jsonify({'message': 'YOLOX settings saved to DB', 'training': training.to_dict()})

    except Exception as error:
        print(f'Error in /api/yolox-settings: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to save YOLOX settings', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/yolox-settings/upload', methods=['POST'])
def yolox_settings_upload():
    """Store a raw model file from the request body and attach it to the latest training.

    Query parameter ``project_id`` selects the project. The optional
    ``x-upload-filename`` header names the file and is sanitized. The file is saved
    in an ``uploads`` directory one level above this module, and its path is written
    to ``model_upload`` of the project's most recent Training row.

    Returns:
        200 with the stored filename and training id; 400 without project_id;
        404 when the project has no details row or no training; 500 (after
        rollback) on any other error.
    """
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in query'}), 400

        # Validate the target first so a failed request leaves no orphan file on disk.
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'No TrainingProjectDetails found for project_id'}), 404

        training = Training.query.filter_by(project_details_id=details.id).order_by(Training.id.desc()).first()
        if not training:
            return jsonify({'message': 'No training found for project_id'}), 404

        upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
        os.makedirs(upload_dir, exist_ok=True)

        # Both the header and project_id are client-controlled; secure_filename
        # removes path separators/'..' so the write stays inside upload_dir.
        requested_name = request.headers.get('x-upload-filename') or f'uploaded_model_{project_id}.pth'
        filename = secure_filename(requested_name) or 'uploaded_model.pth'
        file_path = os.path.join(upload_dir, filename)

        with open(file_path, 'wb') as f:
            f.write(request.data)

        training.model_upload = file_path
        db.session.commit()

        return jsonify({
            'message': 'Model file uploaded and saved to disk',
            'filename': filename,
            'trainingId': training.id
        })

    except Exception as error:
        print(f'Error in /api/yolox-settings/upload: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to upload model file', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/trainings', methods=['GET'])
def get_trainings():
    """List trainings, optionally restricted to one project.

    With a ``project_id`` query parameter, only trainings attached to that project's
    TrainingProjectDetails rows are returned (an empty list if it has none).
    Without it, every training is returned.
    """
    try:
        project_id = request.args.get('project_id')

        if not project_id:
            return jsonify([t.to_dict() for t in Training.query.all()])

        detail_ids = [row.id for row in TrainingProjectDetails.query.filter_by(project_id=project_id).all()]
        if not detail_ids:
            return jsonify([])

        matching = Training.query.filter(Training.project_details_id.in_(detail_ids)).all()
        return jsonify([t.to_dict() for t in matching])

    except Exception as error:
        return jsonify({'message': 'Failed to fetch trainings', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/trainings/<int:id>', methods=['DELETE'])
def delete_training(id):
    """Delete the Training with primary key ``id``.

    Returns 404 when it does not exist. Database errors roll back the session
    and return 500.
    """
    try:
        training = Training.query.get(id)
        if not training:
            return jsonify({'message': 'Training not found'}), 404

        db.session.delete(training)
        db.session.commit()
        return jsonify({'message': 'Training deleted'})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-projects/<int:id>', methods=['DELETE'])
def delete_training_project(id):
    """Delete a TrainingProject together with its details rows and trainings.

    Returns:
        200 after deleting; 404 when the project does not exist (nothing is
        deleted); 500 (after rollback) on database errors.
    """
    try:
        # Check existence first so no DELETE statements are issued for an unknown id.
        project = TrainingProject.query.get(id)
        if not project:
            return jsonify({'message': 'Training project not found'}), 404

        details_ids = [d.id for d in TrainingProjectDetails.query.filter_by(project_id=id).all()]

        # NOTE(review): bulk Query.delete() bypasses ORM cascades. Rows in
        # AnnotationProjectMapping / ClassMapping / ProjectClass that reference these
        # records are only removed if the DB schema cascades — confirm.
        if details_ids:
            Training.query.filter(Training.project_details_id.in_(details_ids)).delete(synchronize_session=False)
            TrainingProjectDetails.query.filter_by(project_id=id).delete()

        db.session.delete(project)
        db.session.commit()
        return jsonify({'message': 'Training project and all related entries deleted'})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training project', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/base-config/<model_name>', methods=['GET'])
def get_base_config(model_name):
    """Return the base YOLOX configuration for ``model_name``.

    Any failure while loading (or serializing) the config is reported as a 404.
    """
    try:
        from services.generate_yolox_exp import load_base_config as load_config
        return jsonify(load_config(model_name))
    except Exception as error:
        return jsonify({'message': f'Failed to load base config for {model_name}', 'error': str(error)}), 404
|
|
|
|
# Settings endpoints

@api_bp.route('/settings', methods=['GET'])
def get_settings():
    """Return every stored setting, in the detailed form built by the settings service."""
    from services.settings_service import get_all_settings_detailed as load_all_settings

    return jsonify(load_all_settings())
|
|
|
|
@api_bp.route('/settings/<key>', methods=['GET'])
def get_setting(key):
    """Return the Settings row stored under ``key`` as JSON.

    This view's name shadows ``services.settings_service.get_setting`` at module
    level, so code in this module that needs the service must import it locally.

    Returns:
        200 with the setting; 404 when no row has this key; 500 on database errors.
    """
    from models.Settings import Settings

    try:
        setting = Settings.query.filter_by(key=key).first()
        if not setting:
            return jsonify({'message': f'Setting {key} not found'}), 404
        return jsonify(setting.to_dict())
    except Exception as error:
        return jsonify({'message': f'Failed to fetch setting {key}', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/settings', methods=['POST'])
def update_settings():
    """Update several settings from a JSON object of ``{key: value}`` pairs.

    Returns:
        200 once every pair has been passed to ``set_setting``; 400 when the body is
        not a JSON object; 500 on any other error.
    """
    try:
        # silent=True: a missing/invalid body yields None and is rejected below,
        # instead of failing with an AttributeError on .items().
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'message': 'Request body must be a JSON object of key/value pairs'}), 400

        from services.settings_service import set_setting

        for key, value in data.items():
            set_setting(key, value)

        return jsonify({'message': 'Settings updated successfully'})
    except Exception as error:
        return jsonify({'message': 'Failed to update settings', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/settings/<key>', methods=['PUT'])
def update_setting(key):
    """Set the value (and optional description) of a single setting.

    Expected JSON body: ``{"value": ..., "description": ...}``. Absent fields
    are passed to ``set_setting`` as None.

    Returns:
        200 with the stored setting; 400 when the body is not a JSON object;
        500 on any other error.
    """
    try:
        # silent=True: a missing/invalid body yields None and is rejected below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'message': 'Request body must be a JSON object'}), 400

        value = data.get('value')
        description = data.get('description')

        from services.settings_service import set_setting
        setting = set_setting(key, value, description)

        return jsonify(setting.to_dict())
    except Exception as error:
        return jsonify({'message': f'Failed to update setting {key}', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/settings/test/labelstudio', methods=['POST'])
def test_labelstudio_connection():
    """Check that a Label Studio API URL and token can list projects.

    Expected JSON body: ``{"api_url": ..., "api_token": ...}``. Sends
    ``GET <api_url>/projects/`` with token authorization and a 5-second timeout.

    Returns:
        200 with the project count on success; 400 for missing input, timeouts,
        connection errors or non-2xx responses; 500 on unexpected errors.
    """
    # Imported before the try block: the except clauses below reference
    # requests.exceptions, which would raise UnboundLocalError if this import had
    # not run yet when an earlier statement in the try failed.
    try:
        import requests
    except ImportError:
        return jsonify({'success': False, 'message': 'The requests package is not installed'}), 500

    try:
        # silent=True: a missing/invalid JSON body becomes the 400 below, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        api_url = data.get('api_url')
        api_token = data.get('api_token')

        if not api_url or not api_token:
            return jsonify({'success': False, 'message': 'Missing api_url or api_token'}), 400

        # Drop a trailing slash so the request does not go to '...//projects/'.
        base_url = str(api_url).rstrip('/')
        response = requests.get(
            f'{base_url}/projects/',
            headers={'Authorization': f'Token {api_token}'},
            timeout=5
        )

        if response.ok:
            projects = response.json()
            # Paginated responses wrap the list in 'results'; a bare list is counted directly.
            if isinstance(projects, dict):
                project_list = projects.get('results', projects)
            else:
                project_list = projects
            return jsonify({
                'success': True,
                'message': f'Connection successful! Found {len(project_list)} projects.'
            })

        return jsonify({
            'success': False,
            'message': f'Connection failed: {response.status_code} {response.reason}'
        }), 400

    except requests.exceptions.Timeout:
        return jsonify({'success': False, 'message': 'Connection timeout'}), 400
    except requests.exceptions.ConnectionError:
        return jsonify({'success': False, 'message': 'Cannot connect to Label Studio'}), 400
    except Exception as error:
        return jsonify({'success': False, 'message': str(error)}), 500
|
|
|
|
def _check_venv_path(venv_path):
    """Validate a venv directory or activation-script path.

    Returns:
        ``(True, detail_message)`` when an activation script is found, otherwise
        ``(False, error_message)``.
    """
    normalized = os.path.normpath(venv_path)

    if os.path.isfile(normalized):
        # Direct path to an activation script (bin/activate, activate.bat, Activate.ps1, ...)
        script_name = os.path.basename(normalized)
        if 'activate' in script_name.lower():
            return True, f'Activation script found: {script_name}'
        return False, 'Venv path points to a file but not an activation script'

    if os.path.isdir(normalized):
        candidates = [
            os.path.join(normalized, 'bin', 'activate'),  # Linux/Mac
            os.path.join(normalized, 'Scripts', 'activate.bat'),  # Windows CMD
            os.path.join(normalized, 'Scripts', 'Activate.ps1'),  # Windows PowerShell
            os.path.join(normalized, 'Scripts', 'activate'),  # Windows Git Bash
        ]
        found = [os.path.basename(path) for path in candidates if os.path.exists(path)]
        if found:
            return True, f'Venv directory valid. Found: {", ".join(found)}'
        return False, 'Venv directory does not contain activation scripts'

    return False, 'Venv path does not exist'


@api_bp.route('/settings/test/yolox', methods=['POST'])
def test_yolox_path():
    """Check that a YOLOX checkout (and optionally a venv path) looks valid.

    Expected JSON body: ``{"yolox_path": ..., "yolox_venv_path": ...}``. The YOLOX
    path must exist and contain at least two of ``yolox``, ``exps`` and ``tools``.
    A venv path, when given, must be an activation script or a directory that
    contains one.

    Returns:
        200 with what was found; 400 for missing or invalid paths (or a body that is
        not a JSON object); 500 on unexpected errors.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes the 400 below, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        yolox_path = data.get('yolox_path')
        yolox_venv_path = data.get('yolox_venv_path')

        if not yolox_path:
            return jsonify({'success': False, 'message': 'Missing yolox_path'}), 400

        if not os.path.exists(yolox_path):
            return jsonify({'success': False, 'message': 'YOLOX path does not exist'}), 400

        # Check for key YOLOX files/directories
        required_items = ['yolox', 'exps', 'tools']
        found_items = [item for item in required_items if os.path.exists(os.path.join(yolox_path, item))]
        missing_items = [item for item in required_items if item not in found_items]

        if len(found_items) < 2:  # At least 2 out of 3 key items required
            return jsonify({
                'success': False,
                'message': f'Invalid YOLOX path. Missing: {", ".join(missing_items)}',
                'found': found_items,
                'missing': missing_items
            }), 400

        venv_message = ''
        if yolox_venv_path:
            venv_ok, venv_info = _check_venv_path(yolox_venv_path)
            if not venv_ok:
                return jsonify({'success': False, 'message': venv_info}), 400
            venv_message = ' ' + venv_info

        return jsonify({
            'success': True,
            'message': f'Valid YOLOX installation found. Found: {", ".join(found_items)}.{venv_message}',
            'found': found_items,
            'missing': missing_items
        })

    except Exception as error:
        return jsonify({'success': False, 'message': str(error)}), 500
|