// Helpers for storing YOLOX experiment settings as a Training row and generating exp.py from a stored row.
const Training = require('../models/training.js');
const fs = require('fs');
const path = require('path');
async function pushYoloxExpToDb(settings) {
|
|
// Normalize boolean and array fields for DB
|
|
const normalized = { ...settings };
|
|
// Map 'act' from frontend to 'activation' for DB
|
|
if (normalized.act !== undefined) {
|
|
normalized.activation = normalized.act;
|
|
delete normalized.act;
|
|
}
|
|
// Convert 'on'/'off' to boolean for save_history_ckpt
|
|
if (typeof normalized.save_history_ckpt === 'string') {
|
|
normalized.save_history_ckpt = normalized.save_history_ckpt === 'on' ? true : false;
|
|
}
|
|
// Convert comma-separated strings to arrays for input_size, test_size, mosaic_scale, mixup_scale
|
|
['input_size', 'test_size', 'mosaic_scale', 'mixup_scale'].forEach(key => {
|
|
if (typeof normalized[key] === 'string') {
|
|
const arr = normalized[key].split(',').map(v => parseFloat(v.trim()));
|
|
normalized[key] = arr.length === 1 ? arr[0] : arr;
|
|
}
|
|
});
|
|
// Find TrainingProjectDetails for this project
|
|
const TrainingProjectDetails = require('../models/TrainingProjectDetails.js');
|
|
const details = await TrainingProjectDetails.findOne({ where: { project_id: normalized.project_id } });
|
|
if (!details) throw new Error('TrainingProjectDetails not found for project_id ' + normalized.project_id);
|
|
normalized.project_details_id = details.id;
|
|
// Create DB row
|
|
const training = await Training.create(normalized);
|
|
return training;
|
|
}
/** Returns `value` as a finite number (numeric strings are accepted), or null. */
function parseFiniteNumber(value) {
  const candidate = typeof value === 'string' && value.trim() !== '' ? Number(value) : value;
  return typeof candidate === 'number' && Number.isFinite(candidate) ? candidate : null;
}

/** Renders a Python numeric literal, using `fallback` when `value` is missing or not numeric. */
function toPyNumber(value, fallback) {
  const parsed = parseFiniteNumber(value);
  return String(parsed === null ? fallback : parsed);
}

/** Renders a Python `True`/`False`, using `fallback` when `value` is null or undefined. */
function toPyBool(value, fallback) {
  let flag = value;
  if (value === null || value === undefined) {
    flag = fallback;
  } else if (typeof value === 'string') {
    flag = !['', '0', 'false', 'off', 'no'].includes(value.trim().toLowerCase());
  }
  return flag ? 'True' : 'False';
}

/** Renders a double-quoted Python string literal with quotes and backslashes escaped. */
function toPyString(value) {
  // JSON string escapes (\\, \", \n, \uXXXX) are also valid Python string escapes.
  return JSON.stringify(String(value));
}

/** Renders a Python tuple from an array of numbers, or `(fallback)` when the array is unusable. */
function toPyTuple(value, fallback) {
  if (Array.isArray(value) && value.length > 0) {
    const numbers = value.map(parseFiniteNumber);
    if (!numbers.includes(null)) {
      // A one-element Python tuple needs a trailing comma.
      return numbers.length === 1 ? `(${numbers[0]},)` : `(${numbers.join(', ')})`;
    }
  }
  return `(${fallback})`;
}

/**
 * Generates a YOLOX `exp.py` file from a stored Training row.
 *
 * Every value is rendered as a Python literal: numbers fall back to the
 * built-in default when the stored value is null/undefined/non-numeric
 * (a stored 0 is kept), booleans become `True`/`False`, and strings are
 * escaped so quotes or backslashes cannot break the generated source.
 *
 * The file is written to `<__dirname>/../../project_<project_id>/<trainingId>/exp.py`,
 * or to `<__dirname>/../../exp_files/exp.py` when the row has no `project_id`.
 *
 * @param {number|string} trainingId - Primary key passed to `Training.findByPk`.
 * @returns {Promise<string>} Path of the written `exp.py`.
 * @throws {Error} `'Training not found'` when no row matches `trainingId`.
 */
async function generateYoloxExpFromDb(trainingId) {
  // Fetch training row from DB
  const training = await Training.findByPk(trainingId);
  if (!training) throw new Error('Training not found');

  const num = (key, fallback) => toPyNumber(training[key], fallback);
  const bool = (key, fallback) => toPyBool(training[key], fallback);
  const tuple = (key, fallback) => toPyTuple(training[key], fallback);

  // Template for exp.py, one entry per output line so Python indentation is explicit.
  const lines = [
    '#!/usr/bin/env python3',
    '# Copyright (c) Megvii Inc. All rights reserved.',
    '',
    'import os',
    'import random',
    '',
    'import torch',
    'import torch.distributed as dist',
    'import torch.nn as nn',
    '',
    'from .base_exp import BaseExp',
    '',
    '__all__ = ["Exp", "check_exp_value"]',
    '',
    'class Exp(BaseExp):',
    '    def __init__(self):',
    '        super().__init__()',
    '',
    '        # ---------------- model config ---------------- #',
    // 0 is not meaningful for these three, so it still falls back to the default.
    `        self.num_classes = ${toPyNumber(training.num_classes || null, 80)}`,
    `        self.depth = ${toPyNumber(training.depth || null, 1.0)}`,
    `        self.width = ${toPyNumber(training.width || null, 1.0)}`,
    `        self.act = ${toPyString(training.activation || training.act || 'silu')}`,
    '',
    '        # ---------------- dataloader config ---------------- #',
    `        self.data_num_workers = ${num('data_num_workers', 4)}`,
    `        self.input_size = ${tuple('input_size', '640, 640')}`,
    // A stored 0 must be kept here rather than replaced by the default.
    `        self.multiscale_range = ${num('multiscale_range', 5)}`,
    `        self.data_dir = ${training.data_dir ? toPyString(training.data_dir) : 'None'}`,
    `        self.train_ann = ${toPyString(training.train_ann || 'instances_train2017.json')}`,
    `        self.val_ann = ${toPyString(training.val_ann || 'instances_val2017.json')}`,
    `        self.test_ann = ${toPyString(training.test_ann || 'instances_test2017.json')}`,
    '',
    '        # --------------- transform config ----------------- #',
    `        self.mosaic_prob = ${num('mosaic_prob', 1.0)}`,
    `        self.mixup_prob = ${num('mixup_prob', 1.0)}`,
    `        self.hsv_prob = ${num('hsv_prob', 1.0)}`,
    `        self.flip_prob = ${num('flip_prob', 0.5)}`,
    `        self.degrees = ${num('degrees', 10.0)}`,
    `        self.translate = ${num('translate', 0.1)}`,
    `        self.mosaic_scale = ${tuple('mosaic_scale', '0.1, 2')}`,
    `        self.enable_mixup = ${bool('enable_mixup', true)}`,
    `        self.mixup_scale = ${tuple('mixup_scale', '0.5, 1.5')}`,
    `        self.shear = ${num('shear', 2.0)}`,
    '',
    '        # -------------- training config --------------------- #',
    `        self.warmup_epochs = ${num('warmup_epochs', 5)}`,
    `        self.max_epoch = ${num('max_epoch', 300)}`,
    `        self.warmup_lr = ${num('warmup_lr', 0)}`,
    `        self.min_lr_ratio = ${num('min_lr_ratio', 0.05)}`,
    `        self.basic_lr_per_img = ${num('basic_lr_per_img', 0.01 / 64.0)}`,
    `        self.scheduler = ${toPyString(training.scheduler || 'yoloxwarmcos')}`,
    `        self.no_aug_epochs = ${num('no_aug_epochs', 15)}`,
    `        self.ema = ${bool('ema', true)}`,
    `        self.weight_decay = ${num('weight_decay', 5e-4)}`,
    `        self.momentum = ${num('momentum', 0.9)}`,
    `        self.print_interval = ${num('print_interval', 10)}`,
    `        self.eval_interval = ${num('eval_interval', 10)}`,
    `        self.save_history_ckpt = ${bool('save_history_ckpt', true)}`,
    '        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]',
    '',
    '        # ----------------- testing config ------------------ #',
    `        self.test_size = ${tuple('test_size', '640, 640')}`,
    `        self.test_conf = ${num('test_conf', 0.01)}`,
    `        self.nmsthre = ${num('nmsthre', 0.65)}`,
    '',
    '        # ... rest of the template ...',
    '',
    'def check_exp_value(exp: Exp):',
    '    h, w = exp.input_size',
    '    assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"',
    '',
  ];
  const expSource = lines.join('\n');

  // Save to file in output directory
  const outDir = path.join(
    __dirname,
    '../../',
    training.project_id ? `project_${training.project_id}/${trainingId}` : 'exp_files',
  );
  // Recursive mkdir is a no-op when the directory already exists.
  await fs.promises.mkdir(outDir, { recursive: true });
  const filePath = path.join(outDir, 'exp.py');
  await fs.promises.writeFile(filePath, expSource);
  return filePath;
}
// Public API of this module.
module.exports = { pushYoloxExpToDb, generateYoloxExpFromDb };