Files
Abschluss-Projekt/backend/routes/api.py
2025-11-28 12:50:27 +01:00

532 lines
21 KiB
Python

from flask import Blueprint, request, jsonify, send_file
from werkzeug.utils import secure_filename
import os
import json
import subprocess
from database.database import db
from models.TrainingProject import TrainingProject
from models.TrainingProjectDetails import TrainingProjectDetails
from models.training import Training
from models.LabelStudioProject import LabelStudioProject
from models.Images import Image
from models.Annotation import Annotation
api_bp = Blueprint(name='api', import_name=__name__)
# Module-level status flag served by GET /update-status (mirrors the Node.js
# backend). Nothing in this file sets 'running' to True; presumably another
# component mutates it -- verify.
update_status = dict(running=False)
@api_bp.route('/seed', methods=['GET'])
def seed():
    """Run the Label Studio seeding service and return its result as JSON.

    The service function is imported inside the handler rather than at
    module level.
    """
    from services.seed_label_studio import seed_label_studio as run_seed
    return jsonify(run_seed())
@api_bp.route('/generate-yolox-json', methods=['POST'])
def generate_yolox_json():
    """Generate YOLOX COCO JSON files and an exp.py for a training project.

    Expects a JSON body ``{"project_id": ...}``. For every
    TrainingProjectDetails row of that project, ``generate_training_json`` is
    called with the row id. If the row has at least one Training,
    ``save_yolox_exp`` writes ``<backend>/<project_name>/<details_id>/exp.py``.

    Returns:
        200 with a confirmation message on success.
        400 if ``project_id`` is missing or the body is not JSON.
        404 if the project or its details rows do not exist.
        500 on any other error.
    """
    try:
        # silent=True: a missing/invalid JSON body yields None instead of
        # raising, so it is reported as a 400 below rather than a 500.
        data = request.get_json(silent=True) or {}
        project_id = data.get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in request body'}), 400
        details_rows = TrainingProjectDetails.query.filter_by(project_id=project_id).all()
        if not details_rows:
            return jsonify({'message': f'No TrainingProjectDetails found for project {project_id}'}), 404
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'message': f'Training project {project_id} not found'}), 404
        # Must stay identical to the directory name derived in
        # start_yolox_training and training_log, which read these files back.
        project_name = (training_project.title.replace(' ', '_')
                        if training_project.title else f'project_{project_id}')
        from services.generate_json_yolox import generate_training_json
        from services.generate_yolox_exp import save_yolox_exp
        for details in details_rows:
            details_id = details.id
            generate_training_json(details_id)
            trainings = Training.query.filter_by(project_details_id=details_id).all()
            if not trainings:
                continue
            out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(details_id))
            os.makedirs(out_dir, exist_ok=True)
            exp_file_path = os.path.join(out_dir, 'exp.py')
            # NOTE(review): all trainings of one details row write the same
            # exp.py, so the last one in query order wins. Kept because
            # start_yolox_training reads exactly this path.
            for training in trainings:
                save_yolox_exp(training.id, exp_file_path)
        return jsonify({'message': f'YOLOX JSON and exp.py generated for project {project_id}'})
    except Exception as err:
        print(f'Error generating YOLOX JSON: {err}')
        return jsonify({'message': 'Failed to generate YOLOX JSON', 'error': str(err)}), 500
@api_bp.route('/start-yolox-training', methods=['POST'])
def start_yolox_training():
    """Launch a YOLOX training run as a background child process.

    Expects a JSON body with ``project_id`` and ``training_id``.
    ``training_id`` is first looked up as a Training primary key and, failing
    that, as a ``project_details_id``. The exp.py written by
    ``/generate-yolox-json`` for that details row is passed to YOLOX's
    ``tools/train.py``. The process is spawned and not waited on.

    Returns:
        200 once the process has been spawned (the training result is not
        awaited).
        400 if ``project_id`` or ``training_id`` is missing.
        404 if the project or training row is not found.
        500 if exp.py is missing or spawning fails.
    """
    import shlex

    try:
        data = request.get_json(silent=True) or {}
        project_id = data.get('project_id')
        training_id = data.get('training_id')
        if not project_id:
            return jsonify({'error': 'Missing project_id in request body'}), 400
        if training_id is None:
            return jsonify({'error': 'Missing training_id in request body'}), 400
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'error': f'Training project {project_id} not found'}), 404
        project_name = (training_project.title.replace(' ', '_')
                        if training_project.title else f'project_{project_id}')
        # Accept either a Training id or a project_details_id.
        training_row = Training.query.get(training_id)
        if not training_row:
            training_row = Training.query.filter_by(project_details_id=training_id).first()
        if not training_row:
            return jsonify({'error': f'Training row not found for id or project_details_id {training_id}'}), 404
        project_details_id = training_row.project_details_id
        out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, str(project_details_id))
        exp_src = os.path.join(out_dir, 'exp.py')
        if not os.path.exists(exp_src):
            return jsonify({'error': f'exp.py not found at {exp_src}'}), 500
        # YOLOX installation (hard-coded host paths).
        yolox_main_dir = '/home/kitraining/Yolox/YOLOX-main'
        yolox_venv = '/home/kitraining/Yolox/yolox_venv/bin/activate'
        transfer = training_row.transfer_learning
        model = training_row.selected_model
        train_args = ['-f', exp_src, '-d', '1', '-b', '8', '--fp16']
        # '-o' precedes '-c' exactly as in the previous command string. In
        # upstream YOLOX tools/train.py '-o' is --occupy -- TODO confirm intent.
        if isinstance(transfer, str) and transfer.lower() == 'coco':
            train_args += ['-o', '-c', f'/home/kitraining/Yolox/YOLOX-main/pretrained/{model}.pth']
        elif model and model.lower() == 'coco' and not transfer:
            # NOTE(review): absolute '/pretrained/...' unlike the branch above;
            # possibly meant to be relative to yolox_main_dir. Kept unchanged.
            train_args += ['-o', '-c', f'/pretrained/{model}.pth']
        train_args.append('--cache')
        # The venv path and every training argument reach bash as positional
        # parameters ($0 and "$@") instead of being spliced into the script
        # text. Project titles or model names containing quotes, spaces or ';'
        # therefore cannot inject shell commands.
        cmd = ['bash', '-c', 'source "$0" && python tools/train.py "$@"', yolox_venv, *train_args]
        print(' '.join(shlex.quote(part) for part in cmd))
        # NOTE(review): the Popen handle is discarded, so the child is never
        # waited on (it stays a zombie after exit until the server process ends).
        subprocess.Popen(cmd, cwd=yolox_main_dir)
        return jsonify({'message': 'Training started'})
    except Exception as err:
        return jsonify({'error': 'Failed to start training', 'details': str(err)}), 500
@api_bp.route('/training-log', methods=['GET'])
def training_log():
    """Return the contents of ``training.log`` for one training directory.

    Query parameters:
        project_id: TrainingProject id, used to derive the project directory.
        training_id: name of the numeric sub-directory below the project
            directory. The other endpoints in this file name these
            directories after TrainingProjectDetails ids.

    Returns:
        ``{"log": "<file contents>"}`` on success.
        400 on missing or invalid parameters.
        404 if the project or the log file does not exist.
        500 on any other error.
    """
    try:
        project_id = request.args.get('project_id')
        training_id = request.args.get('training_id')
        if not project_id:
            return jsonify({'error': 'Missing project_id'}), 400
        # training_id becomes a path component; restrict it to digits so it
        # cannot climb out of the project directory (e.g. "../..").
        if not training_id or not training_id.isdigit():
            return jsonify({'error': 'Missing or invalid training_id'}), 400
        training_project = TrainingProject.query.get(project_id)
        if training_project is None:
            return jsonify({'error': f'Training project {project_id} not found'}), 404
        project_name = (training_project.title.replace(' ', '_')
                        if training_project.title else f'project_{project_id}')
        out_dir = os.path.join(os.path.dirname(__file__), '..', project_name, training_id)
        log_path = os.path.join(out_dir, 'training.log')
        # NOTE(review): nothing in this file writes training.log (the training
        # subprocess's output is not redirected) -- confirm which component creates it.
        if not os.path.exists(log_path):
            return jsonify({'error': 'Log not found'}), 404
        # Explicit encoding, and undecodable bytes are replaced rather than
        # raising, so a partially written log can still be returned.
        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            log_data = f.read()
        return jsonify({'log': log_data})
    except Exception as err:
        return jsonify({'error': 'Failed to fetch log', 'details': str(err)}), 500
@api_bp.route('/training-projects', methods=['POST'])
def create_training_project():
    """Create a TrainingProject from a multipart form.

    Form fields:
        title, description: stored as given (either may be absent).
        classes: JSON-encoded value (default ``[]``).
        project_image (file, optional): raw bytes and content type are stored.

    Returns:
        200 on success.
        400 if ``classes`` is not valid JSON.
        500 (after rollback) on any other error.
    """
    try:
        title = request.form.get('title')
        description = request.form.get('description')
        try:
            classes = json.loads(request.form.get('classes', '[]'))
        except json.JSONDecodeError as error:
            # Malformed client input is a 400, not a server error.
            return jsonify({'message': 'Invalid JSON in classes field', 'error': str(error)}), 400
        project_image = None
        project_image_type = None
        upload = request.files.get('project_image')
        # HTML forms typically send the part with an empty filename when no
        # file was chosen; treat that as "no image" instead of storing b''.
        if upload is not None and upload.filename:
            project_image = upload.read()
            project_image_type = upload.content_type
        project = TrainingProject(
            title=title,
            description=description,
            classes=classes,
            project_image=project_image,
            project_image_type=project_image_type,
        )
        db.session.add(project)
        db.session.commit()
        return jsonify({'message': 'Project created!'})
    except Exception as error:
        print(f'Error creating project: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to create project', 'error': str(error)}), 500
@api_bp.route('/training-projects', methods=['GET'])
def get_training_projects():
    """List every TrainingProject, each serialized with ``to_dict()``.

    Any exception is reported as a 500 with the error text.
    """
    try:
        return jsonify([p.to_dict() for p in TrainingProject.query.all()])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500
@api_bp.route('/update-status', methods=['GET'])
def get_update_status():
    """Serialize the module-level ``update_status`` dict as JSON."""
    current_status = update_status
    return jsonify(current_status)
@api_bp.route('/label-studio-projects', methods=['GET'])
def get_label_studio_projects():
    """List Label Studio projects together with per-label annotation counts.

    One grouped query joins Image and Annotation on ``image_id`` and counts
    annotations per (project_id, Label). Each project's ``to_dict()`` output
    gains an ``annotationCounts`` mapping of label to count, empty when the
    project has no annotations.
    """
    try:
        from sqlalchemy import func

        projects = LabelStudioProject.query.all()

        grouped_rows = (
            db.session.query(
                Image.project_id,
                Annotation.Label,
                func.count(Annotation.annotation_id).label('count'),
            )
            .join(Annotation, Image.image_id == Annotation.image_id)
            .group_by(Image.project_id, Annotation.Label)
            .all()
        )

        # project_id -> {label: count}
        label_counts = {}
        for proj_id, label_name, total in grouped_rows:
            label_counts.setdefault(proj_id, {})[label_name] = total

        payload = [
            {**project.to_dict(), 'annotationCounts': label_counts.get(project.project_id, {})}
            for project in projects
        ]
        return jsonify(payload)
    except Exception as error:
        return jsonify({'message': 'Failed to fetch projects', 'error': str(error)}), 500
@api_bp.route('/training-project-details', methods=['POST'])
def create_training_project_details():
    """Create a TrainingProjectDetails row from a JSON body.

    Required keys are ``project_id`` and ``annotation_projects``; the latter
    may be empty but not absent. ``class_map`` and ``description`` are
    optional.

    Returns:
        200 with the created row.
        400 on missing fields or a non-JSON body.
        500 (after rollback) on any other error.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes {} and is reported
        # as a 400 below instead of crashing into the 500 handler.
        data = request.get_json(silent=True) or {}
        project_id = data.get('project_id')
        annotation_projects = data.get('annotation_projects')
        class_map = data.get('class_map')
        description = data.get('description')
        if not project_id or annotation_projects is None:
            return jsonify({'message': 'Missing required fields'}), 400
        details = TrainingProjectDetails(
            project_id=project_id,
            annotation_projects=annotation_projects,
            class_map=class_map,
            description=description,
        )
        db.session.add(details)
        db.session.commit()
        return jsonify({'message': 'TrainingProjectDetails created', 'details': details.to_dict()})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to create TrainingProjectDetails', 'error': str(error)}), 500
@api_bp.route('/training-project-details', methods=['GET'])
def get_training_project_details():
    """List every TrainingProjectDetails row serialized with ``to_dict()``."""
    try:
        rows = TrainingProjectDetails.query.all()
        serialized = list(map(lambda row: row.to_dict(), rows))
        return jsonify(serialized)
    except Exception as error:
        return jsonify({'message': 'Failed to fetch TrainingProjectDetails', 'error': str(error)}), 500
@api_bp.route('/training-project-details', methods=['PUT'])
def update_training_project_details():
    """Update ``class_map`` and ``description`` of a project's details row.

    Expects a JSON body with non-empty ``project_id``, ``class_map`` and
    ``description``. Only the first TrainingProjectDetails row matching
    ``project_id`` is updated.

    Returns:
        200 with the updated row.
        400 on missing fields or a non-JSON body.
        404 if no row exists.
        500 (after rollback) on any other error.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes {} and is reported
        # as a 400 below instead of crashing into the 500 handler.
        data = request.get_json(silent=True) or {}
        project_id = data.get('project_id')
        class_map = data.get('class_map')
        description = data.get('description')
        if not project_id or not class_map or not description:
            return jsonify({'message': 'Missing required fields'}), 400
        # NOTE(review): other endpoints allow several details rows per project;
        # only the first one is updated here -- confirm that is intended.
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'TrainingProjectDetails not found'}), 404
        details.class_map = class_map
        details.description = description
        db.session.commit()
        return jsonify({'message': 'Class map and description updated', 'details': details.to_dict()})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to update class map or description', 'error': str(error)}), 500
@api_bp.route('/yolox-settings', methods=['POST'])
def yolox_settings():
    """Normalize YOLOX settings from a multipart form and save them to the DB.

    Steps:
      * rename ``select_model`` to ``selected_model`` and ``act`` to ``activation``;
      * find, or create and commit, the TrainingProjectDetails row for
        ``project_id``;
      * convert numeric, boolean and comma-separated array fields, then strip
        the remaining strings;
      * default ``transfer_learning`` to False and map an empty ``seed`` to None;
      * validate required fields, optionally save an uploaded ``ckpt_upload``
        file, and hand the dict to ``push_yolox_exp_to_db``.

    Returns:
        200 with the saved training's ``to_dict()``.
        500 (after rollback) on any error, including validation errors
        raised as ValueError.
    """
    try:
        settings = request.form.to_dict()
        print('--- YOLOX settings received ---')
        print('settings:', settings)
        # Map select_model to selected_model if present
        if 'select_model' in settings and 'selected_model' not in settings:
            settings['selected_model'] = settings['select_model']
            del settings['select_model']
        # Lookup or create project_details_id
        if not settings.get('project_id') or not settings['project_id'].isdigit():
            raise ValueError('Missing or invalid project_id in request')
        project_id = int(settings['project_id'])
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            details = TrainingProjectDetails(
                project_id=project_id,
                annotation_projects=[],
                class_map=None,
                description=None,
            )
            db.session.add(details)
            db.session.commit()
        settings['project_details_id'] = details.id
        # Map 'act' to 'activation'
        if 'act' in settings:
            settings['activation'] = settings['act']
            del settings['act']
        # Type conversions (form values arrive as strings)
        numeric_fields = [
            'max_epoch', 'depth', 'width', 'warmup_epochs', 'warmup_lr',
            'no_aug_epochs', 'min_lr_ratio', 'weight_decay', 'momentum',
            'print_interval', 'eval_interval', 'test_conf', 'nmsthre',
            'multiscale_range', 'degrees', 'translate', 'shear',
            'train', 'valid', 'test',
        ]
        for field in numeric_fields:
            if field in settings:
                settings[field] = float(settings[field])
        # Boolean conversions: only the string 'true' (any case) is True
        boolean_fields = ['ema', 'enable_mixup', 'save_history_ckpt']
        for field in boolean_fields:
            if field in settings:
                if isinstance(settings[field], str):
                    settings[field] = settings[field].lower() == 'true'
                else:
                    settings[field] = bool(settings[field])
        # Array conversions: "0.1, 2" -> [0.1, 2.0]
        array_fields = ['mosaic_scale', 'mixup_scale', 'scale']
        for field in array_fields:
            if field in settings and isinstance(settings[field], str):
                settings[field] = [float(x.strip()) for x in settings[field].split(',') if x.strip()]
        # Trim string fields
        for key in settings:
            if isinstance(settings[key], str):
                settings[key] = settings[key].strip()
        # Default for transfer_learning
        if 'transfer_learning' not in settings:
            settings['transfer_learning'] = False
        # Convert empty seed to None
        if 'seed' in settings and (settings['seed'] == '' or settings['seed'] is None):
            settings['seed'] = None
        # Validate required fields
        required_fields = [
            'project_details_id', 'exp_name', 'max_epoch', 'depth', 'width',
            'activation', 'train', 'valid', 'test', 'selected_model', 'transfer_learning',
        ]
        for field in required_fields:
            if field not in settings or settings[field] in [None, '']:
                raise ValueError(f'Missing required field: {field}')
        print('Received YOLOX settings:', settings)
        # Handle uploaded model file
        if 'ckpt_upload' in request.files:
            file = request.files['ckpt_upload']
            upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
            os.makedirs(upload_dir, exist_ok=True)
            # The client-supplied filename is untrusted: secure_filename strips
            # directory components and unsafe characters so the file cannot be
            # written outside upload_dir (e.g. "../../routes/api.py").
            filename = secure_filename(file.filename or '') or f'uploaded_model_{project_id}.pth'
            file_path = os.path.join(upload_dir, filename)
            file.save(file_path)
            settings['model_upload'] = file_path
        # Save to DB
        from services.push_yolox_exp import push_yolox_exp_to_db
        training = push_yolox_exp_to_db(settings)
        return jsonify({'message': 'YOLOX settings saved to DB', 'training': training.to_dict()})
    except Exception as error:
        print(f'Error in /api/yolox-settings: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to save YOLOX settings', 'error': str(error)}), 500
@api_bp.route('/yolox-settings/upload', methods=['POST'])
def yolox_settings_upload():
    """Store a raw model-file upload and attach it to the project's latest training.

    Inputs:
      * the request body is the raw file content;
      * ``project_id`` comes from the query string;
      * the file name comes from the ``x-upload-filename`` header (default
        ``uploaded_model_<project_id>.pth``), sanitized with ``secure_filename``.

    The file is saved under ``<backend>/uploads``. Its path is written to
    ``model_upload`` of the highest-id Training linked to the project's first
    TrainingProjectDetails row.

    Returns:
        200 with the stored filename and training id.
        400 if ``project_id`` is missing.
        404 if no details or training row exists.
        500 (after rollback) on any other error.
    """
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            return jsonify({'message': 'Missing project_id in query'}), 400
        # Resolve the target training before touching the disk, so a 404 does
        # not leave an orphaned file in the uploads directory.
        details = TrainingProjectDetails.query.filter_by(project_id=project_id).first()
        if not details:
            return jsonify({'message': 'No TrainingProjectDetails found for project_id'}), 404
        training = Training.query.filter_by(project_details_id=details.id).order_by(Training.id.desc()).first()
        if not training:
            return jsonify({'message': 'No training found for project_id'}), 404
        upload_dir = os.path.join(os.path.dirname(__file__), '..', 'uploads')
        os.makedirs(upload_dir, exist_ok=True)
        default_name = f'uploaded_model_{project_id}.pth'
        # The header is client-controlled: strip directory components and unsafe
        # characters so it cannot write outside upload_dir (e.g. "../../app.py").
        filename = secure_filename(request.headers.get('x-upload-filename', default_name)) or default_name
        file_path = os.path.join(upload_dir, filename)
        with open(file_path, 'wb') as f:
            f.write(request.data)
        training.model_upload = file_path
        db.session.commit()
        return jsonify({
            'message': 'Model file uploaded and saved to disk',
            'filename': filename,
            'trainingId': training.id,
        })
    except Exception as error:
        print(f'Error in /api/yolox-settings/upload: {error}')
        db.session.rollback()
        return jsonify({'message': 'Failed to upload model file', 'error': str(error)}), 500
@api_bp.route('/trainings', methods=['GET'])
def get_trainings():
    """List trainings, optionally restricted to one project.

    With a ``project_id`` query parameter, only trainings linked to one of
    that project's TrainingProjectDetails rows are returned (``[]`` when the
    project has none). Without it, every training is returned.
    """
    try:
        project_id = request.args.get('project_id')
        if not project_id:
            trainings = Training.query.all()
        else:
            details_ids = [
                row.id
                for row in TrainingProjectDetails.query.filter_by(project_id=project_id).all()
            ]
            if not details_ids:
                return jsonify([])
            trainings = Training.query.filter(Training.project_details_id.in_(details_ids)).all()
        return jsonify([training.to_dict() for training in trainings])
    except Exception as error:
        return jsonify({'message': 'Failed to fetch trainings', 'error': str(error)}), 500
@api_bp.route('/trainings/<int:id>', methods=['DELETE'])
def delete_training(id):
    """Delete the Training with primary key ``id``.

    Returns 404 when it does not exist, and 500 (after rollback) on a
    database error.
    """
    try:
        training = Training.query.get(id)
        if not training:
            return jsonify({'message': 'Training not found'}), 404
        db.session.delete(training)
        db.session.commit()
        return jsonify({'message': 'Training deleted'})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training', 'error': str(error)}), 500
@api_bp.route('/training-projects/<int:id>', methods=['DELETE'])
def delete_training_project(id):
    """Delete a TrainingProject together with its details rows and trainings.

    Trainings linked to the project's TrainingProjectDetails rows are removed
    first, then the details rows, then the project, all in one commit.
    Files on disk are not touched here.

    Returns:
        200 on success.
        404 if the project does not exist (nothing is deleted).
        500 (after rollback) on a database error.
    """
    try:
        # Check existence first so no bulk DELETE is issued for a missing
        # project. Previously those deletes were only undone by the implicit
        # session teardown.
        project = TrainingProject.query.get(id)
        if not project:
            return jsonify({'message': 'Training project not found'}), 404
        details_ids = [row.id for row in TrainingProjectDetails.query.filter_by(project_id=id).all()]
        if details_ids:
            Training.query.filter(Training.project_details_id.in_(details_ids)).delete(synchronize_session=False)
        TrainingProjectDetails.query.filter_by(project_id=id).delete()
        db.session.delete(project)
        db.session.commit()
        return jsonify({'message': 'Training project and all related entries deleted'})
    except Exception as error:
        db.session.rollback()
        return jsonify({'message': 'Failed to delete training project', 'error': str(error)}), 500