diff --git a/MyLibs/experiment.cpp b/MyLibs/experiment.cpp index e3177ad..dc9cc76 100644 --- a/MyLibs/experiment.cpp +++ b/MyLibs/experiment.cpp @@ -259,3 +259,71 @@ void ExperimentData::runBlueTest() cout << "Experiment ended. " << endl; } + +bool fishAngleAnalysis_test(String fishVideoAddress, bool isGrey) { + VideoCapture capture(fishVideoAddress); + Point testTail; + Mat curImg; + + namedWindow("output", CV_WINDOW_NORMAL); + //namedWindow("org", CV_WINDOW_NORMAL); + capture >> curImg; + //findFishHeadAndCenter(curImg); + //getchar(); + for (int i = 0; i < 1400; i++) { + capture >> curImg; + } + + + int n = 0; + int checkPoint=0; + int boutStart = 0; + int predict; + for (int i = 0; i < 10000; i++) { + + Mat grey; + Point fishTail = Point(-1, -1); + double fishAngle[10000]; + capture >> curImg; + + if (!isGrey) { + cvtColor(curImg, grey, CV_BGR2GRAY); + if (!fishAngleAnalysis(grey, fishHead, fishCenter, &fishTail, &fishAngle[i], threshold_val)) { + cout << "AngleAnalysis error!" << endl; + return false; + } + } + else { + if (!fishAngleAnalysis(curImg, fishHead, fishCenter, &fishTail, &fishAngle[i], threshold_val)) { + cout << "AngleAnalysis error!" << endl; + return false; + } + } + + + if (n == checkPoint) { + if (boutStart > 0) { + predict=predict_left(&fishAngle[boutStart]); + cout << "predict:" << predict << endl; + boutStart = 0; + } + + if (fishAngle[i] > 0.2|| fishAngle[i] < -0.2) { + boutStart = i - 4; + checkPoint = checkPoint + 40; + cout << "Find bout!" << endl; + } + } + else + n++; + + cout << "fishAngle:" << fishAngle[i] << endl; + circle(curImg, fishTail, 1, Scalar(255, 0, 0), -1); + circle(curImg, tailPt_a, 1, Scalar(255, 0, 0), -1); + circle(curImg, tailPt_b, 1, Scalar(255, 0, 0), -1); + circle(curImg, topEnd, 1, Scalar(255, 0, 0), -1); + imshow("output", curImg); + waitKey(); + } + return true; +} diff --git a/MyLibs/experiment.h b/MyLibs/experiment.h index 2232cf0..3e02656 100644 --- a/MyLibs/experiment.h +++ b/MyLibs/experiment.h @@ -71,6 +71,6 @@ class Experiment FileWriter fileWriterObj; }; - +bool fishAngleAnalysis_test(String fishVideoAddress, bool isGrey); #endif _GUARD_EXPERIMENT_H diff --git a/MyLibs/fishAnalysis.cpp b/MyLibs/fishAnalysis.cpp index f5a5c32..5adcf55 100644 --- a/MyLibs/fishAnalysis.cpp +++ b/MyLibs/fishAnalysis.cpp @@ -228,7 +228,7 @@ void Arena::getImgFromVideo(cv::VideoCapture cap) cap >> opencvImg;//TODO: test this usage } -void Arena::alignImg(int deg2rotate) +void Arena::alignImg(int deg2rotate) { Mat rotatedImg; rotateImg(opencvImg, rotatedImg, deg2rotate); //TODO: write the implementation @@ -296,7 +296,7 @@ bool Arena::findSingleFish() int contourSize = contour.size(); if (contourSize < contourSizeThre) continue; // skip this turn - + if (maxContourSize < contour.size()) { maxContourSize = contour.size(); @@ -494,7 +494,7 @@ bool Fish::checkIfGiveShock(int sElapsed) { if (head.x == -1) // invalid frame return false; if (idxCase) // patternIdx == 1, since 2 is already excluded - { + { if (head.y < yDiv) // in non-CS area shockOn = false; else { @@ -538,7 +538,7 @@ bool Fish::checkIfReversePattern(int sElapsed) { idxCase = !idxCase; return true; - } + } else return false; } @@ -643,4 +643,82 @@ void rot90CW(Mat src, Mat dst) transpose(temp, dst); } +/*This function return a radian to describe the fishtailing motion */ +bool fishAngleAnalysis(Mat fishImg, Point fishHead, Point fishCenter, Point * fishTail_return, double* fishAngle,int threshold_val) { + //Find the contour of fish + Mat binaryzation; + double max_val = 255, maxFishArea = 10000, minFishArea = 1000; + vector> allContours, fishContours; + threshold(fishImg, binaryzation, threshold_val, max_val, CV_THRESH_BINARY); + findContours(binaryzation, allContours, CV_RETR_LIST, CHAIN_APPROX_NONE); + for (int i = 0; i < allContours.size(); i++) { + if (contourArea(allContours[i]) < maxFishArea && contourArea(allContours[i]) > minFishArea) + fishContours.push_back(allContours[i]); + } + if (fishContours.size() != 1) { + cout << "Can't find contour of fish!Area of all contours:"; + for (int i = 0; i < allContours.size(); i++) { + cout << contourArea(allContours[i]) << ','; + } + cout << endl; + return false; + } + + //Find the tail of fish + double Pt2center = norm(fishContours[0][0] - fishCenter); + topEnd = fishContours[0][0]; + + for (int i = 1; i < fishContours[0].size(); i++) + { + double curPt2center = norm(fishContours[0][i] - fishCenter); + if (Pt2center < curPt2center) { + topEnd = fishContours[0][i]; + Pt2center = curPt2center; + //circle(fishImg, topEnd, 1, Scalar(255), -1); + + } + + + } + Point tailAxis = topEnd - fishCenter; + tailPt_a = fishCenter + tailAxis * 9 / 10 + Point(tailAxis.y, -tailAxis.x)/4; + tailPt_b = fishCenter + tailAxis * 9 / 10 + Point(-tailAxis.y, tailAxis.x)/4; + vector fishTail = findPtsLineIntersectContour(fishContours[0], tailPt_a, tailPt_b); + + //Calculate the angle + Point fishHeadVector, fishTailVector; + fishHeadVector = fishCenter - fishHead; + fishTailVector = (fishContours[0][fishTail[0]]+ fishContours[0][fishTail[1]])/2 - fishCenter; + double sinfi; + sinfi = -(fishHeadVector.x * fishTailVector.y - fishTailVector.x * fishHeadVector.y) / (norm(fishHeadVector) * norm(fishTailVector)); + *fishAngle = asin(sinfi); + *fishTail_return = (fishContours[0][fishTail[0]] + fishContours[0][fishTail[1]]) / 2; + + return true; +} + +int predict_left(double* boutStart) { + Py_Initialize(); + if (Py_IsInitialized() == 0) { + cout << "Py_Initialize failed." << endl; + } + PyObject* pModule = PyImport_ImportModule("predict"); + if (pModule == NULL) + cout << "Py_ImportModule failed." << endl; + PyObject * pFunc = PyObject_GetAttrString(pModule, "predict_left"); + PyObject * PyList = PyList_New(40); + PyObject * ArgList = PyTuple_New(1); + for (int Index_i = 0; Index_i < PyList_Size(PyList); Index_i++) { + PyList_SetItem(PyList, Index_i, PyFloat_FromDouble(boutStart[Index_i])); + } + PyTuple_SetItem(ArgList, 0, PyList); + PyObject* pReturn = NULL; + pReturn = PyObject_CallObject(pFunc, ArgList); + int result; + PyArg_Parse(pReturn, "i", &result); + cout << "predict:" << result << endl; + Py_Finalize(); + return result; + +} diff --git a/MyLibs/fishAnalysis.h b/MyLibs/fishAnalysis.h index ae8e546..2ea6fd4 100644 --- a/MyLibs/fishAnalysis.h +++ b/MyLibs/fishAnalysis.h @@ -33,6 +33,8 @@ // Include user-defined libraries #include "errorHandling.h" +#include + /* Define related methods and properties for a single fish */ // Write every frame updated info at here? No! Create another class in fileWriter class class Fish { @@ -215,7 +217,7 @@ class FishAnalysis { { 223, 223, 588, 588 }, { 223, 223, 588, 588 } }; // TODO: make this variable private - + aIdx = 0; } @@ -227,7 +229,7 @@ class FishAnalysis { void initialize(std::vector numFishInArenas); /* Process image from camera */ - void preprocessImg(); + void preprocessImg(); /* Get image from camera */ void getImgFromCamera(int width, int height, uint8_t* buffer); @@ -281,7 +283,7 @@ class FishAnalysis { int numArenas; int aIdx; // index of arena to process private: - std::vector> yDivs; + std::vector> yDivs; }; @@ -305,7 +307,9 @@ double getPt2LineDistance(cv::Point2f P, cv::Point2f A, cv::Point2f B); /* Find 2 intersection points of a line (AB) and contour */ std::vector findPtsLineIntersectContour(std::vector& contour, cv::Point2f A, cv::Point2f B); - - +/*This function return a radian to describe the fishtailing motion */ +bool fishAngleAnalysis(Mat fishImg, Point fishHead, Point fishCenter, Point * fishTail_return, double* fishAngle,int threshold_val); +/*This function return the direction of motion(is or not left),which predicted from 40 radians*/ +int predict_left(double* boutStart); #endif // !_GUARD_FISHANALYSIS_H diff --git a/MyLibs/predict.py b/MyLibs/predict.py new file mode 100644 index 0000000..2c471d4 --- /dev/null +++ b/MyLibs/predict.py @@ -0,0 +1,14 @@ +import h2o +import numpy as np +import pandas as pd + +def predict_left(List): + PyList=List + #print(PyList) + df=pd.DataFrame(np.zeros((1,40)),columns=range(1,41)) + df.iloc[0]=PyList + test_predict=h2o.mojo_predict_pandas(df, 'E:/autoML/left_true_0513_accuracy_09516/GBM_grid_1_AutoML_20190513_170806_model_7.zip') + print(test_predict) + return int(test_predict.iat[0,0]) + +#def predict_right(List): diff --git a/MyLibs/userInterface.cpp b/MyLibs/userInterface.cpp index 96a3254..99ef697 100644 --- a/MyLibs/userInterface.cpp +++ b/MyLibs/userInterface.cpp @@ -57,13 +57,13 @@ void UserInterface::enquireDevice2use(std::istream& is) { devices2use[1] = 1; enquirePattern2use(is); - } + } else if (!s.compare("3")) { devices2use[2] = 1; enquireCameras2use(is); } - + else { cout << "Invalid input! Please enter again." << endl; @@ -87,7 +87,7 @@ void UserInterface::enquireCameras2use(std::istream& is) for (int i = 0; i < numCameras; i++) cameras2open[i] = stoi(tempStrVec[i]); - + for (int i = 0; i < cameras2open.size(); i++) { if (cameras2open[i]) @@ -227,7 +227,7 @@ void UserInterface::generateBasenames() string UserInterface::generateBasename(int idxFile) { string baseName = - startTimeStr + "_" + "Arena" + startTimeStr + "_" + "Arena" + to_string(arenaIDs[idxFile]) + "_" + strainName + "_" + to_string(fishAge)+ "dpf_" @@ -251,7 +251,7 @@ void showWelcomeMsg() Copyright 2018 Wenbin Yang */ cout << "Welcome to BLITZ (Behavioral Learning In The Zebrafish)." << endl - << "This program is under GNU 3.0 License." << endl + << "This program is under GNU 3.0 License." << endl << "Most updated code and other resources can be found at " << endl << "https://github.com/Wenlab/BLITZ" << endl << "Please cite (Wenbin Yang et al., 2019) if you use any portion of this program." << endl @@ -283,7 +283,7 @@ vector getStrVecFromCMD(std::istream& is) string inputStr; getline(is, inputStr); cout << endl; // separated with an empty line - + istringstream ss; ss.clear(); @@ -315,3 +315,59 @@ string getCurDateTime() timeStr = buffer; return timeStr; } + +static void on_trackbar_setThreshold(int, void*) { + int max_val = 255; + Mat binaryzation = Mat::zeros(cur_img.size(), CV_8UC1); + threshold(cur_img, binaryzation, threshold_val, max_val, CV_THRESH_BINARY); + + imshow("setThreshold", binaryzation); + +} + +bool setThreshold() { + namedWindow("setThreshold", CV_WINDOW_NORMAL); + createTrackbar("Threshold", "setThreshold", &threshold_val, 255, on_trackbar_setThreshold); + on_trackbar_setThreshold(threshold_val, 0); + cout<<"Press 'q' to exit."< #include // to get the current date and time +// Include OpenCV libraries +#include +#include +#include +using namespace cv; + // TODO: make a GUI, alternative1: GMU readline command line interface /* Talk to users via command line interface */ class UserInterface @@ -38,7 +44,7 @@ class UserInterface public: UserInterface() { - + } /* Ask the user about the experiment infos */ @@ -49,43 +55,43 @@ class UserInterface /* Ask which cameras to use. */ void enquireCameras2use(std::istream& is); - + /* Ask which visual pattern to use */ void enquirePattern2use(std::istream& is); - + /* Ask how many fish under a camera */ int enquireNumFishForACam(std::istream& is, int idxCamera); - + /* Generate fish IDs for all fish */ std::vector generateFishIDs(int numFish); - + /* Ask for what strain of Fish is using, assume all fish are the same strain */ void enquireFishStrain(std::istream& is); - + /* Ask for the age for all fish, assume all fish are at the same age */ void enquireFishAge(std::istream& is); - + /* Ask for what experiment task for fish */ void enquireExpTask(std::istream& is); - + /* Generate basenames for all output files */ void generateBasenames(); - + /* Generate basenames for the output files */ std::string generateBasename(int idxFile); - + // Properties // TODO: get the number of cameras and pass it to fileWriterObj @@ -102,6 +108,7 @@ class UserInterface std::vector baseNames; int numOpenCameras; + }; @@ -119,11 +126,18 @@ std::vector getStrVecFromCMD(std::istream& is); /* Get current date and time string from chrono system clock, depends on */ std::string getCurDateTime(); - - - - - +//some global variables used in trackbar and mouse_findHeadAndCenter +cv::Mat cur_img; +Point fishEye1,fishEye2,fishCenter, fishHead; +int threshold_val = 16; +//callback function used in setThreshold() +static void on_trackbar_setThreshold(int, void*); +//create a trackbar from a Mat to set threshold value +bool setThreshold(); +//callback function used in findHeadAndCenter() +void on_mouse_findHeadAndCenter(int event, int x, int y, int flags, void* ustc); +//use click to select head and center +bool findHeadAndCenter();