542 lines
21 KiB
Python
542 lines
21 KiB
Python
from flask import Blueprint, request, jsonify, send_file
|
|
from werkzeug.utils import secure_filename
|
|
import os
|
|
import json
|
|
import subprocess
|
|
from database.database import db
|
|
from models.TrainingProject import TrainingProject
|
|
from models.TrainingProjectDetails import TrainingProjectDetails
|
|
from models.training import Training
|
|
from models.LabelStudioProject import LabelStudioProject
|
|
from models.Images import Image
|
|
from models.Annotation import Annotation
|
|
|
|
# Blueprint that every route handler in this module registers on.
api_bp = Blueprint('api', __name__)

# Module-level update flag (mirrors the Node.js version); it is returned
# as-is by the /update-status endpoint.
update_status = dict(running=False)
|
|
|
|
@api_bp.route('/seed', methods=['GET'])
def seed():
    """Run the Label Studio seeding service and return its result as JSON."""
    # Imported inside the handler rather than at module import time.
    from services.seed_label_studio import seed_label_studio

    return jsonify(seed_label_studio())
|
|
|
|
@api_bp.route('/generate-yolox-json', methods=['POST'])
def generate_yolox_json():
    """Generate YOLOX JSON and exp.py for a project.

    Expects a JSON body containing ``project_id``.

    Responses:
        200: generation finished for every TrainingProjectDetails row.
        400: ``project_id`` is missing (including a missing/invalid JSON body).
        404: the project has no TrainingProjectDetails rows, or the
            TrainingProject itself does not exist.
        500: any other failure; the error text is included in the payload.
    """
    try:
        # silent=True returns None instead of raising when the body is absent
        # or not valid JSON, so the client gets the 400 below, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')

        if not project_id:
            return jsonify({'message': 'Missing project_id in request body'}), 400

        # Find all TrainingProjectDetails for this project
        details_rows = TrainingProjectDetails.query.filter_by(project_id=project_id).all()

        if not details_rows:
            return jsonify({'message': f'No TrainingProjectDetails found for project {project_id}'}), 404

        # Get project name; a missing project used to crash with AttributeError.
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'message': f'Training project {project_id} not found'}), 404
        project_name = training_project.title.replace(' ', '_') if training_project.title else f'project_{project_id}'

        from services.generate_json_yolox import generate_training_json
        from services.generate_yolox_exp import save_yolox_exp

        base_dir = os.path.join(os.path.dirname(__file__), '..', project_name)

        # For each details row, generate coco.jsons and exp.py
        for details in details_rows:
            details_id = details.id
            generate_training_json(details_id)

            # Find all trainings for this details row
            trainings = Training.query.filter_by(project_details_id=details_id).all()
            if not trainings:
                continue

            # Create output directory
            out_dir = os.path.join(base_dir, str(details_id))
            os.makedirs(out_dir, exist_ok=True)

            # Every training of a details row is written to the same exp.py,
            # so the file ends up holding the last training's experiment.
            exp_file_path = os.path.join(out_dir, 'exp.py')
            for training in trainings:
                save_yolox_exp(training.id, exp_file_path)

        return jsonify({'message': f'YOLOX JSON and exp.py generated for project {project_id}'})

    except Exception as err:
        print(f'Error generating YOLOX JSON: {err}')
        return jsonify({'message': 'Failed to generate YOLOX JSON', 'error': str(err)}), 500
|
|
|
|
@api_bp.route('/start-yolox-training', methods=['POST'])
def start_yolox_training():
    """Start YOLOX training as a background process.

    Expects a JSON body with ``project_id`` and ``training_id``.
    ``training_id`` is first tried as a Training primary key and, failing
    that, as a ``project_details_id``.

    Responses:
        200: the training process was spawned (its outcome is not awaited).
        400: ``project_id`` or ``training_id`` is missing from the body.
        404: the project or the training row does not exist.
        500: exp.py is missing on disk, or any other failure.
    """
    try:
        import shlex

        # silent=True returns None instead of raising on a missing/invalid body.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        project_id = data.get('project_id')
        training_id = data.get('training_id')

        if project_id is None or training_id is None:
            return jsonify({'error': 'Missing project_id or training_id in request body'}), 400

        # Get project name; a missing project used to crash with AttributeError.
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'error': f'Training project {project_id} not found'}), 404
        project_name = training_project.title.replace(' ', '_') if training_project.title else f'project_{project_id}'

        # Look up training row, falling back to a project_details_id match
        training_row = Training.query.get(training_id)
        if not training_row:
            training_row = Training.query.filter_by(project_details_id=training_id).first()

        if not training_row:
            return jsonify({'error': f'Training row not found for id or project_details_id {training_id}'}), 404

        project_details_id = training_row.project_details_id

        # Path to exp.py
        out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(project_details_id))
        exp_src = os.path.join(out_dir, 'exp.py')

        if not os.path.exists(exp_src):
            return jsonify({'error': f'exp.py not found at {exp_src}'}), 500

        # YOLOX configuration
        yolox_main_dir = '/home/kitraining/Yolox/YOLOX-main'
        yolox_venv = '/home/kitraining/Yolox/yolox_venv/bin/activate'

        # Determine the pretrained checkpoint (if any) passed via "-o -c <path>"
        transfer_learning = training_row.transfer_learning
        selected_model = training_row.selected_model
        checkpoint = None

        if isinstance(transfer_learning, str) and transfer_learning.lower() == 'coco':
            checkpoint = f'/home/kitraining/Yolox/YOLOX-main/pretrained/{selected_model}.pth'
        elif selected_model and selected_model.lower() == 'coco' and not transfer_learning:
            # NOTE(review): this absolute /pretrained path is kept from the
            # original code; confirm it is the intended location.
            checkpoint = f'/pretrained/{selected_model}.pth'

        train_args = ['python', 'tools/train.py', '-f', exp_src, '-d', '1', '-b', '8', '--fp16']
        if checkpoint is not None:
            train_args += ['-o', '-c', checkpoint]
        train_args.append('--cache')

        # exp_src and the checkpoint derive from database values (project
        # title, selected model), so every argument is shell-quoted before it
        # reaches bash instead of being spliced into the command verbatim.
        script = f'source {shlex.quote(yolox_venv)} && ' + ' '.join(shlex.quote(arg) for arg in train_args)

        print(script)

        # Start training in background; bash is still needed for `source`.
        subprocess.Popen(['bash', '-c', script], cwd=yolox_main_dir)

        return jsonify({'message': 'Training started'})

    except Exception as err:
        return jsonify({'error': 'Failed to start training', 'details': str(err)}), 500
|
|
|
|
@api_bp.route('/training-log', methods=['GET'])
def training_log():
    """Get YOLOX training log.

    Query parameters: ``project_id`` and ``training_id``. The log is read
    from ``<module dir>/../<project name>/<training_id>/training.log``.

    Responses:
        200: ``{'log': <file contents>}``.
        400: a parameter is missing, or ``training_id`` is not a plain
            directory name.
        404: the project or the log file does not exist.
        500: any other failure.
    """
    try:
        project_id = request.args.get('project_id')
        training_id = request.args.get('training_id')

        if not project_id or not training_id:
            return jsonify({'error': 'Missing project_id or training_id in query'}), 400

        # training_id becomes a path component; refuse anything that could
        # step outside the project directory.
        if os.path.basename(training_id) != training_id or training_id in ('.', '..'):
            return jsonify({'error': 'Invalid training_id'}), 400

        # A missing project used to crash with AttributeError (500).
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'error': f'Training project {project_id} not found'}), 404
        project_name = training_project.title.replace(' ', '_') if training_project.title else f'project_{project_id}'

        out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(training_id))
        log_path = os.path.join(out_dir, 'training.log')

        if not os.path.exists(log_path):
            return jsonify({'error': 'Log not found'}), 404

        with open(log_path, 'r') as f:
            log_data = f.read()

        return jsonify({'log': log_data})

    except Exception as err:
        return jsonify({'error': 'Failed to fetch log', 'details': str(err)}), 500
|
|
|
|
@api_bp.route('/training-projects', methods=['POST'])
def create_training_project():
    """Create a new training project.

    Reads ``title``, ``description`` and ``classes`` (a JSON-encoded string,
    defaulting to ``[]``) from the form data, plus an optional
    ``project_image`` file whose bytes and content type are stored on the row.

    Responses:
        200: the project was committed.
        400: ``classes`` is not valid JSON.
        500: any other failure; the session is rolled back.
    """
    try:
        title = request.form.get('title')
        description = request.form.get('description')

        # Malformed client input is a request error, not a server error.
        try:
            classes = json.loads(request.form.get('classes', '[]'))
        except ValueError:
            return jsonify({'message': "Invalid JSON in 'classes' field"}), 400

        project_image = None
        project_image_type = None

        if 'project_image' in request.files:
            file = request.files['project_image']
            project_image = file.read()
            project_image_type = file.content_type

        project = TrainingProject(
            title=title,
            description=description,
            classes=classes,
            project_image=project_image,
            project_image_type=project_image_type,
        )

        db.session.add(project)
        db.session.commit()

        return jsonify({'message': 'Project created!'})

    except Exception as error:
        print(f'Error creating project: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to create project', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-projects', methods=['GET'])
def get_training_projects():
    """Return every TrainingProject, serialized with ``to_dict()``, as a JSON list."""
    try:
        payload = []
        for training_project in TrainingProject.query.all():
            payload.append(training_project.to_dict())
        return jsonify(payload)
    except Exception as error:
        # Any failure is reported to the client as a 500 with the error text.
        failure = {'message': 'Failed to fetch projects', 'error': str(error)}
        return jsonify(failure), 500
|
|
|
|
@api_bp.route('/update-status', methods=['GET'])
def get_update_status():
    """Return the module-level ``update_status`` dict as JSON."""
    return jsonify(update_status)
|
|
|
|
@api_bp.route('/label-studio-projects', methods=['GET'])
def get_label_studio_projects():
    """Return all Label Studio projects, each with per-label annotation counts.

    Every project's ``to_dict()`` output gains an ``annotationCounts`` mapping
    of label -> count; projects without annotations get an empty mapping.
    """
    try:
        from sqlalchemy import func

        ls_projects = LabelStudioProject.query.all()

        # One aggregated query: annotations joined to their image, counted
        # per (project, label) pair.
        annotation_total = func.count(Annotation.annotation_id).label('count')
        grouped_rows = (
            db.session.query(Image.project_id, Annotation.Label, annotation_total)
            .join(Annotation, Image.image_id == Annotation.image_id)
            .group_by(Image.project_id, Annotation.Label)
            .all()
        )

        # project_id -> {label: count}
        per_project = {}
        for owner_id, label_name, total in grouped_rows:
            per_project.setdefault(owner_id, {})[label_name] = total

        payload = []
        for ls_project in ls_projects:
            entry = ls_project.to_dict()
            entry['annotationCounts'] = per_project.get(ls_project.project_id, {})
            payload.append(entry)

        return jsonify(payload)

    except Exception as error:
        failure = {'message': 'Failed to fetch projects', 'error': str(error)}
        return jsonify(failure), 500
|
|
|
|
@api_bp.route('/training-project-details', methods=['POST'])
def create_training_project_details():
    """Create TrainingProjectDetails.

    Expects a JSON body with ``project_id`` and ``annotation_projects``
    (both required) and optional ``class_map`` and ``description``.

    Responses:
        200: the row was committed; its ``to_dict()`` is returned under
            ``details``.
        400: a required field is missing (including a missing/invalid JSON
            body).
        500: any other failure; the session is rolled back.
    """
    try:
        # silent=True returns None instead of raising on a missing/invalid
        # body, so the request is rejected with 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        project_id = data.get('project_id')
        annotation_projects = data.get('annotation_projects')
        class_map = data.get('class_map')
        description = data.get('description')

        # An empty annotation_projects list is accepted; only None is rejected.
        if not project_id or annotation_projects is None:
            return jsonify({'message': 'Missing required fields'}), 400

        details = TrainingProjectDetails(
            project_id=project_id,
            annotation_projects=annotation_projects,
            class_map=class_map,
            description=description,
        )

        db.session.add(details)
        db.session.commit()

        return jsonify({'message': 'TrainingProjectDetails created', 'details': details.to_dict()})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to create TrainingProjectDetails', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/training-project-details', methods=['GET'])
def get_training_project_details():
    """Return every TrainingProjectDetails row, serialized with ``to_dict()``."""
    try:
        serialized = []
        for row in TrainingProjectDetails.query.all():
            serialized.append(row.to_dict())
        return jsonify(serialized)
    except Exception as error:
        failure = {'message': 'Failed to fetch TrainingProjectDetails', 'error': str(error)}
        return jsonify(failure), 500
|
|
|
|
@api_bp.route('/training-project-details', methods=['PUT'])
def update_training_project_details():
    """Update class_map and description in TrainingProjectDetails.

    Expects a JSON body with non-empty ``project_id``, ``class_map`` and
    ``description``. Only the first details row found for the project is
    updated.

    Responses:
        200: the row was updated; its ``to_dict()`` is returned under
            ``details``.
        400: a required field is missing or empty (including a
            missing/invalid JSON body).
        404: the project has no TrainingProjectDetails row.
        500: any other failure; the session is rolled back.
    """
    try:
        # silent=True returns None instead of raising on a missing/invalid
        # body, so the request is rejected with 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        project_id = data.get('project_id')
        class_map = data.get('class_map')
        description = data.get('description')

        if not project_id or not class_map or not description:
            return jsonify({'message': 'Missing required fields'}), 400

        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()

        if not details:
            return jsonify({'message': 'TrainingProjectDetails not found'}), 404

        details.class_map = class_map
        details.description = description
        db.session.commit()

        return jsonify({'message': 'Class map and description updated', 'details': details.to_dict()})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to update class map or description', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/yolox-settings', methods=['POST'])
def yolox_settings():
    """Receive YOLOX settings and save to DB.

    Reads the settings from the form data, normalizes them and hands them to
    ``push_yolox_exp_to_db``:

    * ``select_model`` is renamed to ``selected_model`` and ``act`` to
      ``activation``;
    * the project's first TrainingProjectDetails row is looked up, or
      created empty, and its id stored as ``project_details_id``;
    * numeric fields become floats, boolean fields become bools, and
      comma-separated array fields become lists of floats;
    * remaining strings are stripped, ``transfer_learning`` defaults to
      False and an empty ``seed`` becomes None;
    * an optional ``ckpt_upload`` file is saved under ``../uploads`` and
      its path stored as ``model_upload``.

    Responses:
        200: the settings were saved; the training's ``to_dict()`` is
            returned under ``training``.
        500: validation or any other failure; the session is rolled back.
    """
    try:
        settings = request.form.to_dict()

        print('--- YOLOX settings received ---')
        print('settings:', settings)

        # Map select_model to selected_model if present
        if 'select_model' in settings and 'selected_model' not in settings:
            settings['selected_model'] = settings.pop('select_model')

        # Lookup or create project_details_id
        if not settings.get('project_id') or not settings['project_id'].isdigit():
            raise ValueError('Missing or invalid project_id in request')

        project_id = int(settings['project_id'])
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()

        if not details:
            details = TrainingProjectDetails(
                project_id=project_id,
                annotation_projects=[],
                class_map=None,
                description=None,
            )
            db.session.add(details)
            # Committed right away so that details.id is populated below.
            db.session.commit()

        settings['project_details_id'] = details.id

        # Map 'act' to 'activation'
        if 'act' in settings:
            settings['activation'] = settings.pop('act')

        # Type conversions
        numeric_fields = [
            'max_epoch', 'depth', 'width', 'warmup_epochs', 'warmup_lr',
            'no_aug_epochs', 'min_lr_ratio', 'weight_decay', 'momentum',
            'print_interval', 'eval_interval', 'test_conf', 'nmsthre',
            'multiscale_range', 'degrees', 'translate', 'shear',
            'train', 'valid', 'test',
        ]

        for field in numeric_fields:
            if field in settings:
                settings[field] = float(settings[field])

        # Boolean conversions: only the literal 'true' (any case) is True
        boolean_fields = ['ema', 'enable_mixup', 'save_history_ckpt']
        for field in boolean_fields:
            if field in settings:
                if isinstance(settings[field], str):
                    settings[field] = settings[field].lower() == 'true'
                else:
                    settings[field] = bool(settings[field])

        # Array conversions: "a, b" -> [a, b]; empty items are skipped
        array_fields = ['mosaic_scale', 'mixup_scale', 'scale']
        for field in array_fields:
            if field in settings and isinstance(settings[field], str):
                settings[field] = [float(x.strip()) for x in settings[field].split(',') if x.strip()]

        # Trim string fields (values only; the key set is left untouched)
        for key, value in settings.items():
            if isinstance(value, str):
                settings[key] = value.strip()

        # Default for transfer_learning
        if 'transfer_learning' not in settings:
            settings['transfer_learning'] = False

        # Convert empty seed to None
        if settings.get('seed') == '':
            settings['seed'] = None

        # Validate required fields
        required_fields = [
            'project_details_id', 'exp_name', 'max_epoch', 'depth', 'width',
            'activation', 'train', 'valid', 'test', 'selected_model', 'transfer_learning',
        ]

        for field in required_fields:
            if field not in settings or settings[field] in [None, '']:
                raise ValueError(f'Missing required field: {field}')

        print('Received YOLOX settings:', settings)

        # Handle uploaded model file
        if 'ckpt_upload' in request.files:
            file = request.files['ckpt_upload']
            upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
            os.makedirs(upload_dir, exist_ok=True)
            # The client controls file.filename; secure_filename strips path
            # separators so the upload cannot be written outside upload_dir.
            filename = secure_filename(file.filename or '') or f'uploaded_model_{project_id}.pth'
            file_path = os.path.join(upload_dir, filename)
            file.save(file_path)
            settings['model_upload'] = file_path

        # Save to DB
        from services.push_yolox_exp import push_yolox_exp_to_db
        training = push_yolox_exp_to_db(settings)

        return jsonify({'message': 'YOLOX settings saved to DB', 'training': training.to_dict()})

    except Exception as error:
        print(f'Error in /api/yolox-settings: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to save YOLOX settings', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/yolox-settings/upload', methods=['POST'])
def yolox_settings_upload():
    """Upload binary model file.

    The raw request body is written to ``../uploads/<filename>``, where the
    filename comes from the ``x-upload-filename`` header (default
    ``uploaded_model_<project_id>.pth``). The path is then stored in
    ``model_upload`` on the newest Training (highest id) of the project's
    first TrainingProjectDetails row.

    Responses:
        200: the file was written and the training row updated.
        400: the ``project_id`` query parameter is missing.
        404: the project has no TrainingProjectDetails row or no training.
        500: any other failure; the session is rolled back.
    """
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in query'}), 400

        # Resolve the target training row first so that a 404 does not leave
        # an orphaned file behind in the uploads directory.
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'No TrainingProjectDetails found for project_id'}), 404

        training = Training.query.filter_by(project_details_id=details.id).order_by(Training.id.desc()).first()
        if not training:
            return jsonify({'message': 'No training found for project_id'}), 404

        # Save file to disk
        upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
        os.makedirs(upload_dir, exist_ok=True)

        # Both the header and project_id are client-controlled; secure_filename
        # removes path separators so the write stays inside upload_dir.
        requested_name = request.headers.get('x-upload-filename') or f'uploaded_model_{project_id}.pth'
        filename = secure_filename(requested_name) or 'uploaded_model.pth'
        file_path = os.path.join(upload_dir, filename)

        # Write the raw request body
        with open(file_path, 'wb') as f:
            f.write(request.data)

        # Update latest training row
        training.model_upload = file_path
        db.session.commit()

        return jsonify({
            'message': 'Model file uploaded and saved to disk',
            'filename': filename,
            'trainingId': training.id,
        })

    except Exception as error:
        print(f'Error in /api/yolox-settings/upload: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to upload model file', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/trainings', methods=['GET'])
def get_trainings():
    """List trainings as JSON, optionally restricted by ``project_id``.

    Without a ``project_id`` query parameter every Training is returned.
    With one, only trainings attached to that project's
    TrainingProjectDetails rows are returned (an empty list when the project
    has no details rows).
    """
    try:
        project_id = request.args.get('project_id')

        # No filter requested: return the whole table.
        if not project_id:
            return jsonify([row.to_dict() for row in Training.query.all()])

        project_details = TrainingProjectDetails.query.filter_by(project_id=project_id).all()
        if not project_details:
            return jsonify([])

        # Trainings hang off details rows, so filter through their ids.
        wanted_ids = [detail.id for detail in project_details]
        matching = Training.query.filter(Training.project_details_id.in_(wanted_ids)).all()
        return jsonify([row.to_dict() for row in matching])

    except Exception as error:
        failure = {'message': 'Failed to fetch trainings', 'error': str(error)}
        return jsonify(failure), 500
|
|
|
|
@api_bp.route('/trainings/<int:id>', methods=['DELETE'])
def delete_training(id):
    """Delete the Training with the given primary key.

    Responds 404 when no such training exists, and 500 (after rolling the
    session back) when the delete fails.
    """
    try:
        target = Training.query.get(id)

        # Guard clause: nothing to delete.
        if not target:
            return jsonify({'message': 'Training not found'}), 404

        db.session.delete(target)
        db.session.commit()
        return jsonify({'message': 'Training deleted'})

    except Exception as error:
        db.session.rollback()
        failure = {'message': 'Failed to delete training', 'error': str(error)}
        return jsonify(failure), 500
|
|
|
|
@api_bp.route('/training-projects/<int:id>', methods=['DELETE'])
def delete_training_project(id):
    """Delete a training project and all related entries.

    Removes, in one transaction, the trainings attached to the project's
    TrainingProjectDetails rows, those details rows, and the project itself.

    Responses:
        200: everything was deleted and committed.
        404: the project does not exist; nothing is deleted.
        500: any other failure; the session is rolled back.
    """
    try:
        # Check the project first so that the 404 path never leaves
        # uncommitted bulk deletes pending in the session.
        project = TrainingProject.query.get(id)
        if project is None:
            return jsonify({'message': 'Training project not found'}), 404

        # Find details rows for this project
        details_rows = TrainingProjectDetails.query.filter_by(project_id=id).all()
        details_ids = [d.id for d in details_rows]

        # Delete all trainings linked to these details, then the details
        if details_ids:
            Training.query.filter(Training.project_details_id.in_(details_ids)).delete(synchronize_session=False)
            TrainingProjectDetails.query.filter_by(project_id=id).delete()

        # Delete the project itself
        db.session.delete(project)
        db.session.commit()
        return jsonify({'message': 'Training project and all related entries deleted'})

    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training project', 'error': str(error)}), 500
|
|
|
|
@api_bp.route('/base-config/<model_name>', methods=['GET'])
def get_base_config(model_name):
    """Return the base configuration of the named YOLOX model as JSON.

    Any exception raised while importing or calling ``load_base_config`` is
    reported as a 404 carrying the error text.
    """
    try:
        from services.generate_yolox_exp import load_base_config

        return jsonify(load_base_config(model_name))
    except Exception as error:
        failure = {'message': f'Failed to load base config for {model_name}', 'error': str(error)}
        return jsonify(failure), 404
|