forked from ArduCAM/pico-tflmicro
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_functions.cpp
More file actions
executable file
·174 lines (145 loc) · 6.11 KB
/
main_functions.cpp
File metadata and controls
executable file
·174 lines (145 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
#include "LCD_st7735.h"
#include "pico/stdlib.h"
#include "accelerometer_handler.h"
#include "constants.h"
#include "gesture_predictor.h"
#include "magic_wand_model_data.h"
#include "output_handler.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
// GPIO driven high while inference runs in loop() (GPIO 25 is presumably the
// Pico's on-board LED -- confirm for the target board).
constexpr uint LED_PIN = 25;
// Global variables, used to be compatible with Arduino style sketches.
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *model_input = nullptr;
// Number of float elements in the model's input tensor; filled in by setup().
int input_length = 0;
// Create a memory area for input, output and intermediate arrays.
// The size depends on the model you are using and may need to be determined
// experimentally.
constexpr int kTensorArenaSize = 14 * 1024 + 1332;
// TFLM aligns every allocation inside the arena. Starting the arena on a
// 16-byte boundary means none of this tightly tuned budget is lost to padding
// that would otherwise depend on where the linker happens to place the array.
alignas(16) uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is very important for Arduino compatibility.
void setup() {
ST7735_Init();
ST7735_DrawImage(0, 0, 80, 160, arducam_logo);
// Set up logging.
// Google's style is to avoid global variables or static variables due to the
// uncertainty of the life cycle, but because it has a trivial destructor, it can.
static tflite::MicroErrorReporter micro_error_reporter; // NOLINT
error_reporter = µ_error_reporter;
// Map the model to the available data structure.
// This does not involve any copying or parsing, which is a very lightweight
// operation.
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Only introduce the operation implementation we need.
// It depends on the complete list of all operations required for this graph.
// A simpler method is to use AllOpsResolver only,
// but this will result in a loss of code space for op implementations that are not
// needed in this figure.
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddMaxPool2D();
micro_op_resolver.AddSoftmax();
// Build an interpreter to run the model.
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from tensor_arena for the tensor of the model.
interpreter->AllocateTensors();
// Get a pointer to the input tensor of the model.
model_input = interpreter->input(0);
if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1)
|| (model_input->dims->data[1] != 128)
|| (model_input->dims->data[2] != kChannelNumber)
|| (model_input->type != kTfLiteFloat32)) {
TF_LITE_REPORT_ERROR(error_reporter, "Bad input tensor parameters in model");
return;
}
input_length = model_input->bytes / sizeof(float);
TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
if (setup_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Set up failed\n");
}
ST7735_FillScreen(ST7735_GREEN);
ST7735_WriteString(5, 20, "Magic", Font_11x18, ST7735_BLACK, ST7735_GREEN);
ST7735_WriteString(30, 45, "Wand", Font_11x18, ST7735_BLACK, ST7735_GREEN);
gpio_init(LED_PIN);
gpio_set_dir(LED_PIN, GPIO_OUT);
}
void loop() {
  // setup() can return early (e.g. on a schema-version mismatch) before these
  // pointers are assigned. Running inference would then dereference null.
  if (interpreter == nullptr || model_input == nullptr) {
    return;
  }
#if EXECUTION_TIME
  TF_LITE_MICRO_EXECUTION_TIME_BEGIN
  TF_LITE_MICRO_EXECUTION_TIME_SNIPPET_START(error_reporter)
#endif
  // Try to read new accelerometer data directly into the model's input tensor.
  bool got_data = ReadAccelerometer(error_reporter, model_input->data.f, input_length);
  // If there is no new data, wait for the next call.
  if (!got_data)
    return;
#if EXECUTION_TIME
  TF_LITE_MICRO_EXECUTION_TIME_SNIPPET_END(error_reporter, "ReadAccelerometer")
  TF_LITE_MICRO_EXECUTION_TIME_SNIPPET_START(error_reporter)
#endif
  // Drive the indicator pin high for the duration of inference.
  gpio_put(LED_PIN, 1);
  // Run inference and report any errors.
  TfLiteStatus invoke_status = interpreter->Invoke();
#if EXECUTION_TIME
  TF_LITE_MICRO_EXECUTION_TIME_SNIPPET_END(error_reporter, "Invoke")
#endif
  if (invoke_status != kTfLiteOk) {
    // begin_index is not defined in this file; presumably it is declared in
    // accelerometer_handler.h -- confirm.
    TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n", begin_index);
    // Do not leave the indicator lit after a failed inference.
    gpio_put(LED_PIN, 0);
    return;
  }
  // Analyze the results to get a prediction.
  int gesture_index = PredictGesture(interpreter->output(0)->data.f);
  // Produce output for the predicted gesture.
  HandleOutput(error_reporter, gesture_index);
#if 0
  // Debug dump of the last input sample and the three output scores.
  char s[64];
  float *f = model_input->data.f;
  float *p = interpreter->output(0)->data.f;
  // snprintf: %+6.0f has no maximum width, so large readings could overflow s.
  snprintf(s, sizeof(s), "%+6.0f : %+6.0f : %+6.0f || W %3.2f : R %3.2f : S %3.2f",
           f[381], f[382], f[383], p[0], p[1], p[2]);
  // Pass the text as an argument, never as the format string itself.
  TF_LITE_REPORT_ERROR(error_reporter, "%s", s);
  // for (int i = 0; i < 3; i++) {
  // printf("%d : ", i);
  // int barNum = static_cast<int>(roundf(p[i] * 10));
  // for (int k = 0; k < barNum; k++) {
  // printf("\u2588"); // "█"
  // }
  // for (int k = barNum - 1; k < 10; k++) {
  // printf(" ");
  // }
  // printf(" ");
  // }
  // printf("\n");
#endif
  gpio_put(LED_PIN, 0);
}