180 lines
7.3 KiB
Python
180 lines
7.3 KiB
Python
import json
|
|
import os
|
|
import math
|
|
from models.TrainingProject import TrainingProject
|
|
from models.TrainingProjectDetails import TrainingProjectDetails
|
|
from models.Images import Image
|
|
from models.Annotation import Annotation
|
|
|
|
def _seeded_random(seed):
    """Return a deterministic pseudo-random float in [0, 1) derived from *seed*.

    Uses the sin-based hash trick so that dataset splits are reproducible
    and bit-identical to the original implementation for a given seed.
    """
    x = math.sin(seed) * 10000
    return x - math.floor(x)


def _shuffle(items, seed):
    """Fisher-Yates shuffle of *items* in place, driven by _seeded_random.

    Deterministic for a fixed seed, so repeated runs produce the same split.
    """
    for i in range(len(items) - 1, 0, -1):
        j = int(_seeded_random(seed + i) * (i + 1))
        items[i], items[j] = items[j], items[i]


def _clean_file_path(file_name):
    """Normalize a stored image path to a plain relative file name.

    Strips URL-encoded spaces, the Label-Studio style
    '/data/local-files/?d=' prefix, and known local directory prefixes.
    Returns the input unchanged when it is falsy (None / empty string).
    """
    if not file_name:
        # Guard: the original '%20' membership test would raise on None.
        return file_name
    file_name = file_name.replace('%20', ' ')
    if file_name.startswith('/data/local-files/?d='):
        file_name = file_name.replace('/data/local-files/?d=', '')
    # Harmless no-op when the substring is absent.
    file_name = file_name.replace('/home/kitraining/home/kitraining/', '')
    if file_name.startswith('home/kitraining/To_Annotate/'):
        file_name = file_name.replace('home/kitraining/To_Annotate/', '')
    return file_name


def _build_coco_json(images, annotations, categories):
    """Assemble a COCO-format dict from its three component lists."""
    return {
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }


def generate_training_json(training_id):
    """Generate COCO JSON files for the train / validation / test splits.

    Collects images and annotations for every source project referenced by
    the training configuration, remaps class labels, deterministically
    shuffles and splits the images, and writes one COCO JSON per split.
    Finally writes an inference exp.py for the project (best-effort).

    Args:
        training_id: Primary key of the TrainingProjectDetails row
            (note: this is the project_details_id, not the project id).

    Raises:
        Exception: If no TrainingProjectDetails row exists for *training_id*.
    """
    training_project_details = TrainingProjectDetails.query.get(training_id)
    if not training_project_details:
        raise Exception(f'No TrainingProjectDetails found for project_details_id {training_id}')

    details_obj = training_project_details.to_dict()

    # Parent project supplies the human-readable name for the output folder.
    training_project = TrainingProject.query.get(details_obj['project_id'])

    # Split percentages; defaults apply when not configured. The test split
    # is the remainder, so test_percent is never read.
    train_percent = details_obj.get('train_percent', 85)
    valid_percent = details_obj.get('valid_percent', 10)

    coco_images = []
    coco_annotations = []
    coco_categories = []
    category_map = {}  # mapped class name -> COCO category id
    image_id = 0
    annotation_id = 0

    # class_map entries look like: [source_project_id, [[original, mapped], ...]]
    # (assumed from the indexing below — TODO confirm against the producer).
    for cls in details_obj['class_map']:
        source_project_id = cls[0]
        list_asg = cls[1]

        # First-wins translation table (original label -> mapped label),
        # replacing the per-annotation linear scan of the original code.
        label_map = {}
        for asg in list_asg:
            original, mapped = asg[0], asg[1]
            if original not in label_map:
                label_map[original] = mapped
            # Register each new non-empty mapped class as a COCO category.
            if mapped and mapped not in category_map:
                category_map[mapped] = len(coco_categories)
                coco_categories.append({'id': category_map[mapped], 'name': mapped, 'supercategory': ''})

        images = Image.query.filter_by(project_id=source_project_id).all()
        for image in images:
            image_id += 1
            file_name = _clean_file_path(image.image_path)

            coco_images.append({
                'id': image_id,
                'file_name': file_name,
                'width': image.width or 0,
                'height': image.height or 0,
            })

            annotations = Annotation.query.filter_by(image_id=image.image_id).all()
            for annotation in annotations:
                # Translate the label; unmapped labels pass through unchanged.
                mapped_class = label_map.get(annotation.Label, annotation.Label)

                # Skip annotations whose class was dropped from the mapping
                # (empty mapped name, or name never registered as a category).
                if not mapped_class or mapped_class not in category_map:
                    continue

                annotation_id += 1
                area = 0
                if annotation.width and annotation.height:
                    area = annotation.width * annotation.height

                coco_annotations.append({
                    'id': annotation_id,
                    'image_id': image_id,
                    'category_id': category_map[mapped_class],
                    'bbox': [annotation.x, annotation.y, annotation.width, annotation.height],
                    'area': area,
                    'iscrowd': 0,
                })

    # Deterministic shuffle so the same seed always yields the same split.
    split_seed = details_obj.get('seed', 42)
    split_seed = int(split_seed) if split_seed is not None else 42
    _shuffle(coco_images, split_seed)

    total_images = len(coco_images)
    train_count = int(total_images * train_percent / 100)
    valid_count = int(total_images * valid_percent / 100)
    # The test split absorbs the rounding remainder (everything after
    # train + valid), so no explicit test count is needed.

    train_images = coco_images[:train_count]
    valid_images = coco_images[train_count:train_count + valid_count]
    test_images = coco_images[train_count + valid_count:]

    train_image_ids = {img['id'] for img in train_images}
    valid_image_ids = {img['id'] for img in valid_images}
    test_image_ids = {img['id'] for img in test_images}

    train_annotations = [ann for ann in coco_annotations if ann['image_id'] in train_image_ids]
    valid_annotations = [ann for ann in coco_annotations if ann['image_id'] in valid_image_ids]
    test_annotations = [ann for ann in coco_annotations if ann['image_id'] in test_image_ids]

    train_json = _build_coco_json(train_images, train_annotations, coco_categories)
    valid_json = _build_coco_json(valid_images, valid_annotations, coco_categories)
    test_json = _build_coco_json(test_images, test_annotations, coco_categories)

    # Output folder name comes from the project title, falling back to the id.
    project_name = training_project.title.replace(' ', '_') if training_project and training_project.title else f'project_{details_obj["project_id"]}'
    annotations_dir = '/home/kitraining/To_Annotate/annotations'
    os.makedirs(annotations_dir, exist_ok=True)

    # Write one COCO JSON per split.
    for split_name, payload in (('train', train_json), ('valid', valid_json), ('test', test_json)):
        split_path = f'{annotations_dir}/coco_project_{training_id}_{split_name}.json'
        with open(split_path, 'w') as f:
            json.dump(payload, f, indent=2)

    print(f'COCO JSON splits written to {annotations_dir} for trainingId {training_id}')

    # Also generate the inference exp.py for this training run.
    # Local import kept to avoid a circular import at module load time —
    # TODO confirm that is why the original imported here.
    from services.generate_yolox_exp import generate_yolox_inference_exp
    project_folder = os.path.join(os.path.dirname(__file__), '..', project_name, str(training_id))
    os.makedirs(project_folder, exist_ok=True)

    inference_exp_path = os.path.join(project_folder, 'exp_infer.py')
    try:
        exp_content = generate_yolox_inference_exp(training_id)
        with open(inference_exp_path, 'w') as f:
            f.write(exp_content)
        print(f'Inference exp.py written to {inference_exp_path}')
    except Exception as err:
        # Best-effort: a failed exp generation must not abort the JSON export.
        print(f'Failed to generate inference exp.py: {err}')