import json
import os
import shlex
import subprocess

from flask import Blueprint, request, jsonify, send_file
from werkzeug.utils import secure_filename

from database.database import db
from models.TrainingProject import TrainingProject
from models.TrainingProjectDetails import TrainingProjectDetails
from models.training import Training
from models.LabelStudioProject import LabelStudioProject
from models.Images import Image
from models.Annotation import Annotation

api_bp = Blueprint('api', __name__)

# Global update status (similar to Node.js version)
update_status = {"running": False}

# Parent directory of this module; per-project folders and 'uploads' live here.
_BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
_UPLOAD_DIR = os.path.join(_BASE_DIR, 'uploads')

# YOLOX installation used to launch training jobs.
YOLOX_MAIN_DIR = '/home/kitraining/Yolox/YOLOX-main'
YOLOX_VENV_ACTIVATE = '/home/kitraining/Yolox/yolox_venv/bin/activate'

# Form fields of /yolox-settings converted from strings before saving.
_YOLOX_NUMERIC_FIELDS = (
    'max_epoch', 'depth', 'width', 'warmup_epochs', 'warmup_lr',
    'no_aug_epochs', 'min_lr_ratio', 'weight_decay', 'momentum',
    'print_interval', 'eval_interval', 'test_conf', 'nmsthre',
    'multiscale_range', 'degrees', 'translate', 'shear',
    'train', 'valid', 'test',
)
_YOLOX_BOOLEAN_FIELDS = ('ema', 'enable_mixup', 'save_history_ckpt')
_YOLOX_ARRAY_FIELDS = ('mosaic_scale', 'mixup_scale', 'scale')
_YOLOX_REQUIRED_FIELDS = (
    'project_details_id', 'exp_name', 'max_epoch', 'depth', 'width',
    'activation', 'train', 'valid', 'test', 'selected_model',
    'transfer_learning',
)


def _json_body():
    """Return the request's JSON object, or an empty dict if absent or malformed.

    ``silent=True`` keeps a missing/non-JSON body from raising, so callers
    report a 400 for missing fields instead of a generic 500.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _project_folder_name(project_id):
    """Return the folder name of a training project, or None if it does not exist.

    The name is the project title with spaces replaced by underscores, or
    ``project_<id>`` when the title is empty.
    """
    training_project = TrainingProject.query.get(project_id)
    if training_project is None:
        return None
    if training_project.title:
        return training_project.title.replace(' ', '_')
    return f'project_{project_id}'


def _project_output_dir(project_name, sub_dir):
    """Return ``<base dir>/<project_name>/<sub_dir>`` as an absolute path.

    Raises:
        ValueError: if the path would escape the base directory (e.g. a title
            or id containing ``..`` or an absolute path).
    """
    path = os.path.normpath(os.path.join(_BASE_DIR, project_name, str(sub_dir)))
    if path == _BASE_DIR or os.path.commonpath([_BASE_DIR, path]) != _BASE_DIR:
        raise ValueError(f'Refusing to use path outside base directory: {path}')
    return path


def _upload_filename(raw_name, project_id):
    """Sanitise a client-supplied filename, defaulting to uploaded_model_<id>.pth."""
    safe_name = secure_filename(raw_name) if raw_name else ''
    return safe_name or f'uploaded_model_{project_id}.pth'


def _pretrained_checkpoint(training_row):
    """Return the checkpoint passed to train.py via ``-c``, or None for no checkpoint."""
    transfer = training_row.transfer_learning
    model = training_row.selected_model
    if isinstance(transfer, str) and transfer.lower() == 'coco':
        return f'{YOLOX_MAIN_DIR}/pretrained/{model}.pth'
    if model and model.lower() == 'coco' and not transfer:
        # NOTE(review): resolves to /pretrained/coco.pth, outside YOLOX_MAIN_DIR;
        # kept as-is from the original logic - confirm this is intended.
        return f'/pretrained/{model}.pth'
    return None


@api_bp.route('/seed', methods=['GET'])
def seed():
    """Trigger seeding from Label Studio"""
    from services.seed_label_studio import seed_label_studio
    result = seed_label_studio()
    return jsonify(result)


@api_bp.route('/generate-yolox-json', methods=['POST'])
def generate_yolox_json():
    """Generate YOLOX JSON and exp.py for a project.

    JSON body: ``project_id``. For each TrainingProjectDetails row of the
    project, calls ``generate_training_json``; if the row has trainings,
    writes ``exp.py`` to ``<base dir>/<project folder>/<details id>/``.
    """
    try:
        project_id = _json_body().get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in request body'}), 400

        details_rows = TrainingProjectDetails.query.filter_by(project_id=project_id).all()
        if not details_rows:
            return jsonify({'message': f'No TrainingProjectDetails found for project {project_id}'}), 404

        project_name = _project_folder_name(project_id)
        if project_name is None:
            return jsonify({'message': f'Training project {project_id} not found'}), 404

        from services.generate_json_yolox import generate_training_json
        from services.generate_yolox_exp import save_yolox_exp

        for details in details_rows:
            generate_training_json(details.id)

            trainings = Training.query.filter_by(project_details_id=details.id).all()
            if not trainings:
                continue

            out_dir = _project_output_dir(project_name, details.id)
            os.makedirs(out_dir, exist_ok=True)
            exp_file_path = os.path.join(out_dir, 'exp.py')
            # NOTE: every training of this details row writes the same exp.py,
            # so the file ends up reflecting the last training in the list.
            for training in trainings:
                save_yolox_exp(training.id, exp_file_path)

        return jsonify({'message': f'YOLOX JSON and exp.py generated for project {project_id}'})
    except Exception as err:
        print(f'Error generating YOLOX JSON: {err}')
        return jsonify({'message': 'Failed to generate YOLOX JSON', 'error': str(err)}), 500


@api_bp.route('/start-yolox-training', methods=['POST'])
def start_yolox_training():
    """Start YOLOX training in the background.

    JSON body: ``project_id`` and ``training_id``. ``training_id`` is tried
    first as a Training primary key, then as a ``project_details_id``. The
    training process is launched and not waited for.
    """
    try:
        data = _json_body()
        project_id = data.get('project_id')
        training_id = data.get('training_id')
        if not project_id or not training_id:
            return jsonify({'error': 'Missing project_id or training_id in request body'}), 400

        project_name = _project_folder_name(project_id)
        if project_name is None:
            return jsonify({'error': f'Training project {project_id} not found'}), 404

        training_row = Training.query.get(training_id)
        if not training_row:
            training_row = Training.query.filter_by(project_details_id=training_id).first()
        if not training_row:
            return jsonify({'error': f'Training row not found for id or project_details_id {training_id}'}), 404

        out_dir = _project_output_dir(project_name, training_row.project_details_id)
        exp_src = os.path.join(out_dir, 'exp.py')
        if not os.path.exists(exp_src):
            return jsonify({'error': f'exp.py not found at {exp_src}'}), 500

        train_args = ['python', 'tools/train.py', '-f', exp_src,
                      '-d', '1', '-b', '8', '--fp16']
        checkpoint = _pretrained_checkpoint(training_row)
        if checkpoint:
            train_args += ['-o', '-c', checkpoint]
        train_args.append('--cache')

        # Every argument is quoted: exp_src and the model name derive from user
        # input, and bash is needed only to source the venv before training.
        shell_cmd = (f'source {shlex.quote(YOLOX_VENV_ACTIVATE)} && '
                     + ' '.join(shlex.quote(arg) for arg in train_args))
        print(shell_cmd)

        subprocess.Popen(['bash', '-c', shell_cmd], cwd=YOLOX_MAIN_DIR)
        return jsonify({'message': 'Training started'})
    except Exception as err:
        return jsonify({'error': 'Failed to start training', 'details': str(err)}), 500


@api_bp.route('/training-log', methods=['GET'])
def training_log():
    """Get YOLOX training log from ``<project folder>/<training_id>/training.log``."""
    try:
        project_id = request.args.get('project_id')
        training_id = request.args.get('training_id')
        if not project_id or not training_id:
            return jsonify({'error': 'Missing project_id or training_id'}), 400

        project_name = _project_folder_name(project_id)
        if project_name is None:
            return jsonify({'error': 'Training project not found'}), 404

        log_path = os.path.join(_project_output_dir(project_name, training_id), 'training.log')
        if not os.path.exists(log_path):
            return jsonify({'error': 'Log not found'}), 404

        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            log_data = f.read()
        return jsonify({'log': log_data})
    except Exception as err:
        return jsonify({'error': 'Failed to fetch log', 'details': str(err)}), 500


@api_bp.route('/training-projects', methods=['POST'])
def create_training_project():
    """Create a new training project from multipart form data.

    Form fields: ``title``, ``description``, ``classes`` (JSON string) and an
    optional ``project_image`` file stored as raw bytes plus content type.
    """
    try:
        try:
            classes = json.loads(request.form.get('classes', '[]'))
        except json.JSONDecodeError as error:
            return jsonify({'message': 'Invalid JSON in classes field', 'error': str(error)}), 400

        project_image = None
        project_image_type = None
        if 'project_image' in request.files:
            file = request.files['project_image']
            project_image = file.read()
            project_image_type = file.content_type

        project = TrainingProject(
            title=request.form.get('title'),
            description=request.form.get('description'),
            classes=classes,
            project_image=project_image,
            project_image_type=project_image_type,
        )
        db.session.add(project)
        db.session.commit()
        return jsonify({'message': 'Project created!'})
    except Exception as error:
        print(f'Error creating project: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to create project', 'error': str(error)}), 500


@api_bp.route('/training-projects', methods=['GET'])
def get_training_projects():
    """Get all training projects"""
    try:
        projects = TrainingProject.query.all()
        return jsonify([project.to_dict() for project in projects])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500


@api_bp.route('/update-status', methods=['GET'])
def get_update_status():
    """Get update status"""
    return jsonify(update_status)


@api_bp.route('/label-studio-projects', methods=['GET'])
def get_label_studio_projects():
    """Get all Label Studio projects with per-label annotation counts."""
    try:
        from sqlalchemy import func

        label_studio_projects = LabelStudioProject.query.all()

        # One aggregate query: (project_id, label, count) for all projects.
        annotation_counts_query = db.session.query(
            Image.project_id,
            Annotation.Label,
            func.count(Annotation.annotation_id).label('count')
        ).join(
            Annotation, Image.image_id == Annotation.image_id
        ).group_by(
            Image.project_id, Annotation.Label
        ).all()

        counts_by_project = {}
        for project_id, label, count in annotation_counts_query:
            counts_by_project.setdefault(project_id, {})[label] = count

        projects_with_counts = []
        for project in label_studio_projects:
            project_dict = project.to_dict()
            project_dict['annotationCounts'] = counts_by_project.get(project.project_id, {})
            projects_with_counts.append(project_dict)

        return jsonify(projects_with_counts)
    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500


@api_bp.route('/training-project-details', methods=['POST'])
def create_training_project_details():
    """Create TrainingProjectDetails from a JSON body."""
    try:
        data = _json_body()
        project_id = data.get('project_id')
        annotation_projects = data.get('annotation_projects')
        if not project_id or annotation_projects is None:
            return jsonify({'message': 'Missing required fields'}), 400

        details = TrainingProjectDetails(
            project_id=project_id,
            annotation_projects=annotation_projects,
            class_map=data.get('class_map'),
            description=data.get('description'),
        )
        db.session.add(details)
        db.session.commit()
        return jsonify({'message': 'TrainingProjectDetails created', 'details': details.to_dict()})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to create TrainingProjectDetails', 'error': str(error)}), 500


@api_bp.route('/training-project-details', methods=['GET'])
def get_training_project_details():
    """Get all TrainingProjectDetails"""
    try:
        details = TrainingProjectDetails.query.all()
        return jsonify([d.to_dict() for d in details])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch TrainingProjectDetails', 'error': str(error)}), 500


@api_bp.route('/training-project-details', methods=['PUT'])
def update_training_project_details():
    """Update class_map and description of a project's first TrainingProjectDetails row."""
    try:
        data = _json_body()
        project_id = data.get('project_id')
        class_map = data.get('class_map')
        description = data.get('description')
        if not project_id or not class_map or not description:
            return jsonify({'message': 'Missing required fields'}), 400

        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'TrainingProjectDetails not found'}), 404

        details.class_map = class_map
        details.description = description
        db.session.commit()
        return jsonify({'message': 'Class map and description updated', 'details': details.to_dict()})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to update class map or description', 'error': str(error)}), 500


@api_bp.route('/yolox-settings', methods=['POST'])
def yolox_settings():
    """Receive YOLOX settings (multipart form) and save them to the DB.

    Renames ``select_model``->``selected_model`` and ``act``->``activation``,
    converts numeric/boolean/list fields, creates a TrainingProjectDetails row
    for ``project_id`` if none exists, stores an optional ``ckpt_upload`` file
    under a sanitised name, then calls ``push_yolox_exp_to_db``.

    Raises (caught, returned as 500):
        ValueError: invalid ``project_id`` or a missing required field.
    """
    try:
        settings = request.form.to_dict()
        print('--- YOLOX settings received ---')
        print('settings:', settings)

        if 'select_model' in settings and 'selected_model' not in settings:
            settings['selected_model'] = settings.pop('select_model')

        if not settings.get('project_id') or not settings['project_id'].isdigit():
            raise ValueError('Missing or invalid project_id in request')
        project_id = int(settings['project_id'])

        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            details = TrainingProjectDetails(
                project_id=project_id,
                annotation_projects=[],
                class_map=None,
                description=None,
            )
            db.session.add(details)
            db.session.commit()
        settings['project_details_id'] = details.id

        if 'act' in settings:
            settings['activation'] = settings.pop('act')

        for field in _YOLOX_NUMERIC_FIELDS:
            if field in settings:
                settings[field] = float(settings[field])

        for field in _YOLOX_BOOLEAN_FIELDS:
            if field in settings:
                value = settings[field]
                settings[field] = value.lower() == 'true' if isinstance(value, str) else bool(value)

        for field in _YOLOX_ARRAY_FIELDS:
            if field in settings and isinstance(settings[field], str):
                settings[field] = [float(x.strip()) for x in settings[field].split(',') if x.strip()]

        # Trimming runs before the seed check so a whitespace-only seed becomes None.
        for key, value in settings.items():
            if isinstance(value, str):
                settings[key] = value.strip()

        settings.setdefault('transfer_learning', False)

        if 'seed' in settings and (settings['seed'] == '' or settings['seed'] is None):
            settings['seed'] = None

        for field in _YOLOX_REQUIRED_FIELDS:
            if field not in settings or settings[field] in [None, '']:
                raise ValueError(f'Missing required field: {field}')

        print('Received YOLOX settings:', settings)

        if 'ckpt_upload' in request.files:
            file = request.files['ckpt_upload']
            os.makedirs(_UPLOAD_DIR, exist_ok=True)
            # Client filenames are untrusted: sanitise to prevent path traversal.
            file_path = os.path.join(_UPLOAD_DIR, _upload_filename(file.filename, project_id))
            file.save(file_path)
            settings['model_upload'] = file_path

        from services.push_yolox_exp import push_yolox_exp_to_db
        training = push_yolox_exp_to_db(settings)

        return jsonify({'message': 'YOLOX settings saved to DB', 'training': training.to_dict()})
    except Exception as error:
        print(f'Error in /api/yolox-settings: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to save YOLOX settings', 'error': str(error)}), 500


@api_bp.route('/yolox-settings/upload', methods=['POST'])
def yolox_settings_upload():
    """Upload a binary model file sent as the raw request body.

    Query: ``project_id``. Optional header ``x-upload-filename`` (sanitised).
    The saved path is recorded on the newest Training of the project's first
    TrainingProjectDetails row.
    """
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in query'}), 400

        # Resolve the target training first so a 404 leaves no orphan file.
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'No TrainingProjectDetails found for project_id'}), 404

        training = Training.query.filter_by(project_details_id=details.id).order_by(Training.id.desc()).first()
        if not training:
            return jsonify({'message': 'No training found for project_id'}), 404

        os.makedirs(_UPLOAD_DIR, exist_ok=True)
        filename = _upload_filename(request.headers.get('x-upload-filename'), project_id)
        file_path = os.path.join(_UPLOAD_DIR, filename)
        with open(file_path, 'wb') as f:
            f.write(request.data)

        training.model_upload = file_path
        db.session.commit()

        return jsonify({
            'message': 'Model file uploaded and saved to disk',
            'filename': filename,
            'trainingId': training.id
        })
    except Exception as error:
        print(f'Error in /api/yolox-settings/upload: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to upload model file', 'error': str(error)}), 500


@api_bp.route('/trainings', methods=['GET'])
def get_trainings():
    """Get all trainings (optionally filtered by project_id)"""
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            trainings = Training.query.all()
            return jsonify([t.to_dict() for t in trainings])

        details_ids = [d.id for d in TrainingProjectDetails.query.filter_by(project_id=project_id).all()]
        if not details_ids:
            return jsonify([])

        trainings = Training.query.filter(Training.project_details_id.in_(details_ids)).all()
        return jsonify([t.to_dict() for t in trainings])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch trainings', 'error': str(error)}), 500


@api_bp.route('/trainings/<int:id>', methods=['DELETE'])
def delete_training(id):
    """Delete a training by id"""
    try:
        training = Training.query.get(id)
        if not training:
            return jsonify({'message': 'Training not found'}), 404
        db.session.delete(training)
        db.session.commit()
        return jsonify({'message': 'Training deleted'})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training', 'error': str(error)}), 500


@api_bp.route('/training-projects/<int:id>', methods=['DELETE'])
def delete_training_project(id):
    """Delete a training project, its details rows and their trainings."""
    try:
        project = TrainingProject.query.get(id)
        if not project:
            return jsonify({'message': 'Training project not found'}), 404

        details_ids = [d.id for d in TrainingProjectDetails.query.filter_by(project_id=id).all()]
        if details_ids:
            Training.query.filter(Training.project_details_id.in_(details_ids)).delete(synchronize_session=False)
        TrainingProjectDetails.query.filter_by(project_id=id).delete()

        db.session.delete(project)
        db.session.commit()
        return jsonify({'message': 'Training project and all related entries deleted'})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training project', 'error': str(error)}), 500


@api_bp.route('/base-config/<model_name>', methods=['GET'])
def get_base_config(model_name):
    """Get base configuration for a specific YOLOX model"""
    try:
        from services.generate_yolox_exp import load_base_config
        config = load_base_config(model_name)
        return jsonify(config)
    except Exception as error:
        return jsonify({'message': f'Failed to load base config for {model_name}', 'error': str(error)}), 404