diff --git a/README.md b/README.md index cad1abd..d048f1b 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,63 @@ CUDA Rasterizer =============== -[CLICK ME FOR INSTRUCTION OF THIS PROJECT](./INSTRUCTION.md) +![demogif](renders/demo.gif) -**University of Pennsylvania, CIS 565: GPU Programming and Architecture, Project 4** +**A rasterization pipeline build in Nvidia's CUDA parallel API. Features perspective-correct, bilinear filtered texture mapping, backface culling with stream compaction, supersampled antialiasing.** +**University of Pennsylvania, CIS 565: GPU Programming and Architecture, Project 6** -* (TODO) YOUR NAME HERE -* Tested on: (TODO) Windows 22, i7-2222 @ 2.22GHz 22GB, GTX 222 222MB (Moore 2222 Lab) +* Daniel Daley-Mongtomery +* Tested on: MacBook Pro, OSX 10.12, i7 @ 2.3GHz, 16GB RAM, GT 750M 2048MB (Personal Machine) +### Implementation -### (TODO: Your README) + No 3D graphics pipeline can afford to physically model rays of light like a path tracer if it has to produce consistent imagery at interactive speeds. Modern real-time APIs like OpenGL, D3D, and Vulkan generally assume a process wherein vertex data is transformed, assembled into triangles, and projected onto a 2D screen space (rasterized) to be lit on a pixel-by-pixel basis. In this project I will implement such a pipeline in CUDA in order better grasp the subtleties and bottlenecks involved in interactive graphics. -*DO NOT* leave the README to the last minute! It is a crucial part of the -project, and we will not be able to grade you without a good README. +##### Compose Vertices and Assemble primitives + Perhaps the simplest step comes first: having recieved input geometry from a provided gltf reader, I transform each vertex in parallel, storing its screen-space position, depth, and normals, as well as texture and UV invformation. I then traverse the indices specified by the gltf model to assign each vertex to a triangle primitive. + + Before moving onto rasterization, an extremely expensive operation, I make sure to cull geometry facing away form the camera. Using thrust's stream compaction, I can group all front-facing triangles at the beginning of the primitive array, and so launch fewer rasterization kernels later on. Below is a debug image showing *only* those backfaces which my final rasterizer eliminates. + +![](renders/reverseCull.PNG) + +##### Rasterize Primitives + + The biggest challenge I encountered was rasterization. While it seems simple enough, I poured hours into different methods to speed it up to little avail. I parallelized over triangles; each kernel checked the XY bounding box of its triangle to see if any given fragment lay within its boundaries, as shown below: + +![](https://www.scratchapixel.com/images/upload/rasterization/raytracing-raster2.png?) + + When I found a fragment that did fall within the triangle, I checked a screen-sized depth buffer to see if any closer fragments had already been uncovered. If they hadn't, I atomically updated the depth buffer, then packed barycentrically-interpolated normal, UV, and texture data to a fragment buffer for later shading. In this way, no occluded fragments end up shaded, rather like a deferred pipeline. + +Because perspective projection shrinks distant objects, linearly interpolating in screen space yields incorrect UV values. I had to adjust the UV-interpolation [accordingly](http://web.cs.ucdavis.edu/~amenta/s12/perspectiveCorrect.pdf). + +| None | Some | +| ------------- |:-------------:| +| ![](renders/NoPerspectiveCorrection.PNG) | ![](renders/PerspectiveCorrection.PNG)| + + Being low-hanging fruit, I also threw in naive supersampled antialiasing. I uniformly subdivided each fragment and ran both my rasterization and shading step for each, accumulating framebuffer contributions for each. + +| Single Sample | 2x2 Uniform | +| ------------- |:-------------:| +| ![](renders/duckWithNoAntiAliasing.PNG) | ![duck](renders/DuckWithAntiAliasing.PNG)| + +##### Shade Fragments + +The last step was to shade the fragments that made it through. With lambert shading from the normal simple texture lookups based on UVs, I could create colorful triangles in realtime! I added bilinear texure filtering to smooth out lower-res images and more fairly treat fragments whose UVs didn't fall dead-center on a pixel: + +| Single Pixel Sample | Bilinearly Sampled | +| ------------- |:-------------:| +| ![](renders/NoFilteringWith128x128Texture.PNG) | ![](renders/BilinearFilteringWith128x128Texture.PNG)| +| ![](renders/NoFilteringWith512x512Texture.PNG) | ![](renders/BilinearFilteringWith512x512Texture.PNG)| + +##### Performance + +![](renders/RastPerf.png) + + + + The cost of my texture features was not as influential as I might have expected, largely because of the amount of time taken by my rasterization step. But given the weight of rasterization, I was surprised how little backface culling helped. I could demonstrate that I was launching about half as many _generateFragments()_ kernels, but saw a < 10% speedup. Clearly my biggest issue was the duration of these kernels. + +Some remaining hopes I have at improvement are better a [bettter scanline algorithm](http://forum.devmaster.net/t/advanced-rasterization/6145/24), and sorting triangles by their XY bounding-box area before rasterizing. While it might not be cheap enough to be worthwhile, it would put kernels with similar numbers of tests in the same warps and less time is spent waiting. ### Credits diff --git a/gltfs/CesiumMilkTruck/CesiumMilkTruck.png b/gltfs/CesiumMilkTruck/CesiumMilkTruck.png index ba7a47c..b9fb85b 100644 Binary files a/gltfs/CesiumMilkTruck/CesiumMilkTruck.png and b/gltfs/CesiumMilkTruck/CesiumMilkTruck.png differ diff --git a/gltfs/CesiumMilkTruck/CesiumMilkTruckHighRes.png b/gltfs/CesiumMilkTruck/CesiumMilkTruckHighRes.png new file mode 100644 index 0000000..ba7a47c Binary files /dev/null and b/gltfs/CesiumMilkTruck/CesiumMilkTruckHighRes.png differ diff --git a/renders/BilinearFilteringWith128x128Texture.PNG b/renders/BilinearFilteringWith128x128Texture.PNG new file mode 100644 index 0000000..c1b07ab Binary files /dev/null and b/renders/BilinearFilteringWith128x128Texture.PNG differ diff --git a/renders/BilinearFilteringWith512x512Texture.PNG b/renders/BilinearFilteringWith512x512Texture.PNG new file mode 100644 index 0000000..ae7cb32 Binary files /dev/null and b/renders/BilinearFilteringWith512x512Texture.PNG differ diff --git a/renders/DuckWithAntiAliasing.PNG b/renders/DuckWithAntiAliasing.PNG new file mode 100644 index 0000000..2f88c05 Binary files /dev/null and b/renders/DuckWithAntiAliasing.PNG differ diff --git a/renders/NoFilteringWith128x128Texture.PNG b/renders/NoFilteringWith128x128Texture.PNG new file mode 100644 index 0000000..7183107 Binary files /dev/null and b/renders/NoFilteringWith128x128Texture.PNG differ diff --git a/renders/NoFilteringWith512x512Texture.PNG b/renders/NoFilteringWith512x512Texture.PNG new file mode 100644 index 0000000..93d05ad Binary files /dev/null and b/renders/NoFilteringWith512x512Texture.PNG differ diff --git a/renders/NoPerspectiveCorrection.PNG b/renders/NoPerspectiveCorrection.PNG new file mode 100644 index 0000000..96507bd Binary files /dev/null and b/renders/NoPerspectiveCorrection.PNG differ diff --git a/renders/PerspectiveCorrection.PNG b/renders/PerspectiveCorrection.PNG new file mode 100644 index 0000000..5ddfbf1 Binary files /dev/null and b/renders/PerspectiveCorrection.PNG differ diff --git a/renders/RastPerf.png b/renders/RastPerf.png new file mode 100644 index 0000000..967be9d Binary files /dev/null and b/renders/RastPerf.png differ diff --git a/renders/demo.gif b/renders/demo.gif new file mode 100644 index 0000000..1d8ac6f Binary files /dev/null and b/renders/demo.gif differ diff --git a/renders/duckWithNoAntiAliasing.PNG b/renders/duckWithNoAntiAliasing.PNG new file mode 100644 index 0000000..061027d Binary files /dev/null and b/renders/duckWithNoAntiAliasing.PNG differ diff --git a/renders/percent.png b/renders/percent.png new file mode 100644 index 0000000..c3976bd Binary files /dev/null and b/renders/percent.png differ diff --git a/renders/reverseCull.PNG b/renders/reverseCull.PNG new file mode 100644 index 0000000..e3338cf Binary files /dev/null and b/renders/reverseCull.PNG differ diff --git a/src/main.cpp b/src/main.cpp index 7986959..22492e8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -19,16 +19,20 @@ //------------------------------- int main(int argc, char **argv) { + + std::string input_filename; if (argc != 2) { - cout << "Usage: [gltf file]. Press Enter to exit" << endl; - getchar(); - return 0; - } + cout << "No filename included in args." << endl; + cout << "Please provide a path to a .gltf file, or drag on onto this window from the explorer." << endl; + getline(cin, input_filename); + } + else { + input_filename = argv[1]; + } tinygltf::Scene scene; tinygltf::TinyGLTFLoader loader; std::string err; - std::string input_filename(argv[1]); std::string ext = getFilePathExtension(input_filename); bool ret = false; @@ -64,6 +68,7 @@ int main(int argc, char **argv) { } void mainLoop() { + while (!glfwWindowShouldClose(window)) { glfwPollEvents(); runCuda(); @@ -71,12 +76,13 @@ void mainLoop() { time_t seconds2 = time (NULL); if (seconds2 - seconds >= 1) { - + cout << 1000.0 * (seconds2 - seconds) / fpstracker << std::endl; fps = fpstracker / (seconds2 - seconds); fpstracker = 0; seconds = seconds2; } + string title = "CIS565 Rasterizer | " + utilityCore::convertIntToString((int)fps) + " FPS"; glfwSetWindowTitle(window, title.c_str()); @@ -97,8 +103,8 @@ void mainLoop() { //---------RUNTIME STUFF--------- //------------------------------- float scale = 1.0f; -float x_trans = 0.0f, y_trans = 0.0f, z_trans = -10.0f; -float x_angle = 0.0f, y_angle = 0.0f; +float x_trans = -0.5f, y_trans = -2.0f, z_trans = -7.8f; +float x_angle = 0.0f, y_angle = 3.14159 * 2.0 * -0.105f; void runCuda() { // Map OpenGL buffer object for writing from CUDA on a single GPU // No data is moved (Win & Linux). When mapped to CUDA, OpenGL should not use this buffer @@ -109,6 +115,7 @@ void runCuda() { -scale, scale, 1.0, 1000.0); glm::mat4 V = glm::mat4(1.0f); + y_angle += (float)0.02; glm::mat4 M = glm::translate(glm::vec3(x_trans, y_trans, z_trans)) @@ -120,10 +127,13 @@ void runCuda() { glm::mat4 MVP = P * MV; cudaGLMapBufferObject((void **)&dptr, pbo); + + clock_t begin; rasterize(dptr, MVP, MV, MV_normal); cudaGLUnmapBufferObject(pbo); frame++; + iterations++; fpstracker++; } diff --git a/src/main.hpp b/src/main.hpp index 4816fa1..d8b0d82 100644 --- a/src/main.hpp +++ b/src/main.hpp @@ -40,6 +40,10 @@ const char *attributeLocations[] = { "Position", "Tex" }; GLuint pbo = (GLuint)NULL; GLuint displayImage; uchar4 *dptr; +int iterations = -100; +double cumulative = 0; + + GLFWwindow *window; diff --git a/src/rasterize.cu b/src/rasterize.cu index 1262a09..f140815 100644 --- a/src/rasterize.cu +++ b/src/rasterize.cu @@ -18,65 +18,66 @@ #include #include +#include +#include +#include +#include + +#define BACKFACE_CULLING 1 +#define PERSP_CORRECT 1 +#define SORT_BY_AREA 1 +#define BILINEAR 1 +#define SAMPLES 1 +#define SAMPLE_WEIGHT (1.0f / (SAMPLES * SAMPLES)) +#define SAMPLE_JITTER (1.0f / SAMPLES) +#pragma region Assumed namespace { typedef unsigned short VertexIndex; + typedef unsigned char TextureData; + typedef unsigned char BufferByte; typedef glm::vec3 VertexAttributePosition; typedef glm::vec3 VertexAttributeNormal; typedef glm::vec2 VertexAttributeTexcoord; - typedef unsigned char TextureData; - typedef unsigned char BufferByte; - - enum PrimitiveType{ + enum PrimitiveType { Point = 1, Line = 2, Triangle = 3 }; struct VertexOut { - glm::vec4 pos; - - // TODO: add new attributes to your VertexOut - // The attributes listed below might be useful, - // but always feel free to modify on your own - - glm::vec3 eyePos; // eye space position used for shading - glm::vec3 eyeNor; // eye space normal used for shading, cuz normal will go wrong after perspective transformation - // glm::vec3 col; - glm::vec2 texcoord0; - TextureData* dev_diffuseTex = NULL; - // int texWidth, texHeight; - // ... + + glm::vec4 vertexEyePos; //the position before perspective distort (Eye Space) + glm::vec4 vertexPerspPos; //the position in Final Device Coords + glm::vec3 vertexNormal; //Eye Space Normal + glm::vec2 vertexUV; //Built-in UV + TextureData* diffuseTexture = NULL; + int texWidth, texHeight; }; struct Primitive { PrimitiveType primitiveType = Triangle; // C++ 11 init VertexOut v[3]; + bool cull; }; struct Fragment { - glm::vec3 color; - - // TODO: add new attributes to your Fragment - // The attributes listed below might be useful, - // but always feel free to modify on your own - - // glm::vec3 eyePos; // eye space position used for shading - // glm::vec3 eyeNor; - // VertexAttributeTexcoord texcoord0; - // TextureData* dev_diffuseTex; - // ... + glm::vec3 eyeNormal; + glm::vec2 UV; + TextureData* diffuseTexture = NULL; + int texWidth, texHeight; }; struct PrimitiveDevBufPointers { - int primitiveMode; //from tinygltfloader macro - PrimitiveType primitiveType; + //from tinygltfloader macro + int primitiveMode; int numPrimitives; int numIndices; int numVertices; // Vertex In, const after loaded + PrimitiveType primitiveType; VertexIndex* dev_indices; VertexAttributePosition* dev_position; VertexAttributeNormal* dev_normal; @@ -86,83 +87,131 @@ namespace { TextureData* dev_diffuseTex; int diffuseTexWidth; int diffuseTexHeight; - // TextureData* dev_specularTex; - // TextureData* dev_normalTex; - // ... // Vertex Out, vertex used for rasterization, this is changing every frame VertexOut* dev_verticesOut; - - // TODO: add more attributes when needed }; } static std::map> mesh2PrimitivesMap; - static int width = 0; static int height = 0; static int totalNumPrimitives = 0; static Primitive *dev_primitives = NULL; +thrust::device_ptr dev_thrust_primitives; + static Fragment *dev_fragmentBuffer = NULL; static glm::vec3 *dev_framebuffer = NULL; +static int * dev_depth = NULL; -static int * dev_depth = NULL; // you might need this buffer when doing depth test - -/** - * Kernel that writes the image to the OpenGL PBO directly. - */ -__global__ +//write to PBO +__global__ void sendImageToPBO(uchar4 *pbo, int w, int h, glm::vec3 *image) { - int x = (blockIdx.x * blockDim.x) + threadIdx.x; - int y = (blockIdx.y * blockDim.y) + threadIdx.y; - int index = x + (y * w); - - if (x < w && y < h) { - glm::vec3 color; - color.x = glm::clamp(image[index].x, 0.0f, 1.0f) * 255.0; - color.y = glm::clamp(image[index].y, 0.0f, 1.0f) * 255.0; - color.z = glm::clamp(image[index].z, 0.0f, 1.0f) * 255.0; - // Each thread writes one pixel location in the texture (textel) - pbo[index].w = 0; - pbo[index].x = color.x; - pbo[index].y = color.y; - pbo[index].z = color.z; - } + int x = (blockIdx.x * blockDim.x) + threadIdx.x; + int y = (blockIdx.y * blockDim.y) + threadIdx.y; + int index = x + (y * w); + + if (x < w && y < h) { + glm::vec3 color; + color.x = glm::clamp(image[index].x, 0.0f, 1.0f) * 255.0; + color.y = glm::clamp(image[index].y, 0.0f, 1.0f) * 255.0; + color.z = glm::clamp(image[index].z, 0.0f, 1.0f) * 255.0; + // Each thread writes one pixel location in the texture (textel) + pbo[index].w = 0; + pbo[index].x = color.x; + pbo[index].y = color.y; + pbo[index].z = color.z; + } +} + +#define inverse255 0.00392156863f //1/255 for conversion from unit to 8bit color +__device__ glm::vec3 colorAtPoint(TextureData* t, int index) { + return inverse255 * glm::vec3( + (float)t[index * 3], + (float)t[index * 3 + 1], + (float)t[index * 3 + 2] + ); +} +#pragma endregion + +//baseline sample floors the uv and returns a single color +__device__ glm::vec3 sample(glm::vec2 uv, TextureData* tex, int texWidth, int texHeight) { + int uIndex = uv[0] * texWidth; + int vIndex = uv[1] * texHeight; + int Index1D = uIndex + texWidth * vIndex; + + return colorAtPoint(tex, Index1D); +} + +//biliear sample does four color reads and weights them according to the continuous uv value +__device__ glm::vec3 sampleBilinear(glm::vec2 uv, TextureData* tex, int texWidth, int texHeight) { + float uFraction = uv[0] * texWidth; + int uIndex = uFraction; + uFraction -= uIndex; + + float vFraction = uv[1] * texHeight; + int vIndex = vFraction; + vFraction -= vIndex; + + int Index1D = uIndex + texWidth * vIndex; + + return + (1.0f - uFraction)*(1.0f - vFraction) * colorAtPoint(tex, Index1D) + //current pixel + uFraction * (1.0f - vFraction) * colorAtPoint(tex, uIndex + 1 < texWidth ? Index1D + 1 : Index1D) + //next u pixel (if exists) + (1.0f - uFraction) * vFraction * colorAtPoint(tex, vIndex + 1 < texHeight ? Index1D + texWidth : Index1D) + //next y pixel (if exists) + uFraction * vFraction * colorAtPoint(tex, uIndex + 1 < texWidth && vIndex + 1 < texHeight ? Index1D + texWidth + 1 : Index1D); + //^next diagonal pixel (if both right and up exist) } -/** -* Writes fragment colors to the framebuffer -*/ __global__ void render(int w, int h, Fragment *fragmentBuffer, glm::vec3 *framebuffer) { - int x = (blockIdx.x * blockDim.x) + threadIdx.x; - int y = (blockIdx.y * blockDim.y) + threadIdx.y; - int index = x + (y * w); - - if (x < w && y < h) { - framebuffer[index] = fragmentBuffer[index].color; + int x = (blockIdx.x * blockDim.x) + threadIdx.x; + int y = (blockIdx.y * blockDim.y) + threadIdx.y; + int index = x + (y * w); - // TODO: add your fragment shader code here + if (x < w && y < h) { + Fragment f = fragmentBuffer[index]; + if (f.eyeNormal == glm::vec3(0)) { + framebuffer[index] = glm::vec3(0); + return; + } - } + glm::vec3 lightDir(1, 1, 1); + glm::vec3 color; + if (false || f.diffuseTexture == NULL) color = glm::vec3(1, 1, 1); + else +#if BILINEAR + color = sampleBilinear(f.UV, f.diffuseTexture, f.texWidth, f.texHeight); +#else + color = sample(f.UV, f.diffuseTexture, f.texWidth, f.texHeight); +#endif + + //compute lambert value by angle; keep it positive + float lightValue = glm::max(glm::dot(fragmentBuffer[index].eyeNormal, lightDir), 0.4f); + glm::vec3 diffuse = glm::clamp(color * lightValue, 0.0f, 1.0f); + framebuffer[index] += SAMPLE_WEIGHT * diffuse; + } } +#pragma region Assumed /** * Called once at the beginning of the program to allocate memory. */ void rasterizeInit(int w, int h) { - width = w; - height = h; + width = w; + height = h; + cudaFree(dev_fragmentBuffer); cudaMalloc(&dev_fragmentBuffer, width * height * sizeof(Fragment)); cudaMemset(dev_fragmentBuffer, 0, width * height * sizeof(Fragment)); - cudaFree(dev_framebuffer); - cudaMalloc(&dev_framebuffer, width * height * sizeof(glm::vec3)); - cudaMemset(dev_framebuffer, 0, width * height * sizeof(glm::vec3)); - + + cudaFree(dev_framebuffer); + cudaMalloc(&dev_framebuffer, width * height * sizeof(glm::vec3)); + cudaMemset(dev_framebuffer, 0, width * height * sizeof(glm::vec3)); + cudaFree(dev_depth); cudaMalloc(&dev_depth, width * height * sizeof(int)); @@ -187,9 +236,9 @@ void initDepth(int w, int h, int * depth) * kern function with support for stride to sometimes replace cudaMemcpy * One thread is responsible for copying one component */ -__global__ +__global__ void _deviceBufferCopy(int N, BufferByte* dev_dst, const BufferByte* dev_src, int n, int byteStride, int byteOffset, int componentTypeByteSize) { - + // Attribute (vec3 position) // component (3 * float) // byte (4 * byte) @@ -202,20 +251,20 @@ void _deviceBufferCopy(int N, BufferByte* dev_dst, const BufferByte* dev_src, in int offset = i - count * n; // which component of the attribute for (int j = 0; j < componentTypeByteSize; j++) { - - dev_dst[count * componentTypeByteSize * n - + offset * componentTypeByteSize + + dev_dst[count * componentTypeByteSize * n + + offset * componentTypeByteSize + j] - = + = - dev_src[byteOffset - + count * (byteStride == 0 ? componentTypeByteSize * n : byteStride) - + offset * componentTypeByteSize + dev_src[byteOffset + + count * (byteStride == 0 ? componentTypeByteSize * n : byteStride) + + offset * componentTypeByteSize + j]; } } - + } @@ -235,7 +284,7 @@ void _nodeMatrixTransform( } glm::mat4 getMatrixFromNodeMatrixVector(const tinygltf::Node & n) { - + glm::mat4 curMatrix(1.0); const std::vector &m = n.matrix; @@ -247,7 +296,8 @@ glm::mat4 getMatrixFromNodeMatrixVector(const tinygltf::Node & n) { curMatrix[i][j] = (float)m.at(4 * i + j); } } - } else { + } + else { // no matrix, use rotation, scale, translation if (n.translation.size() > 0) { @@ -275,12 +325,12 @@ glm::mat4 getMatrixFromNodeMatrixVector(const tinygltf::Node & n) { return curMatrix; } -void traverseNode ( +void traverseNode( std::map & n2m, const tinygltf::Scene & scene, const std::string & nodeString, const glm::mat4 & parentMatrix - ) +) { const tinygltf::Node & n = scene.nodes.at(nodeString); glm::mat4 M = parentMatrix * getMatrixFromNodeMatrixVector(n); @@ -537,7 +587,7 @@ void rasterizeSetBuffers(const tinygltf::Scene & scene) { size_t s = image.image.size() * sizeof(TextureData); cudaMalloc(&dev_diffuseTex, s); cudaMemcpy(dev_diffuseTex, &image.image.at(0), s, cudaMemcpyHostToDevice); - + diffuseTexWidth = image.width; diffuseTexHeight = image.height; @@ -554,7 +604,7 @@ void rasterizeSetBuffers(const tinygltf::Scene & scene) { // ---------Node hierarchy transform-------- cudaDeviceSynchronize(); - + dim3 numBlocksNodeTransform((numVertices + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); _nodeMatrixTransform << > > ( numVertices, @@ -569,11 +619,11 @@ void rasterizeSetBuffers(const tinygltf::Scene & scene) { // push dev pointers to map primitiveVector.push_back(PrimitiveDevBufPointers{ primitive.mode, - primitiveType, numPrimitives, numIndices, numVertices, + primitiveType, dev_indices, dev_position, dev_normal, @@ -595,21 +645,22 @@ void rasterizeSetBuffers(const tinygltf::Scene & scene) { } // for each node } - + // 3. Malloc for dev_primitives { cudaMalloc(&dev_primitives, totalNumPrimitives * sizeof(Primitive)); + dev_thrust_primitives = thrust::device_ptr(dev_primitives); } - + // Finally, cudaFree raw dev_bufferViews { std::map::const_iterator it(bufferViewDevPointers.begin()); std::map::const_iterator itEnd(bufferViewDevPointers.end()); - - //bufferViewDevPointers + + //bufferViewDevPointers for (; it != itEnd; it++) { cudaFree(it->second); @@ -621,13 +672,31 @@ void rasterizeSetBuffers(const tinygltf::Scene & scene) { } +static int curPrimitiveBeginId = 0; + +#pragma endregion + +__global__ +void _primitiveAssembly(int numIndices, int curPrimitiveBeginId, Primitive* dev_primitives, PrimitiveDevBufPointers primitive) { + // index id + int iid = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (iid < numIndices) { + int pid; // id for cur primitives vector + pid = iid / (int)primitive.primitiveType + curPrimitiveBeginId; + VertexOut v = primitive.dev_verticesOut[primitive.dev_indices[iid]]; + dev_primitives[pid].v[iid % (int)primitive.primitiveType] = v; + dev_primitives[pid].cull = v.vertexNormal[2] < 0; + } -__global__ +} + +__global__ void _vertexTransformAndAssembly( - int numVertices, - PrimitiveDevBufPointers primitive, - glm::mat4 MVP, glm::mat4 MV, glm::mat3 MV_normal, + int numVertices, + PrimitiveDevBufPointers primitive, + glm::mat4 MVP, glm::mat4 MV, glm::mat3 MV_normal, int width, int height) { // vertex id @@ -638,108 +707,174 @@ void _vertexTransformAndAssembly( // Multiply the MVP matrix for each vertex position, this will transform everything into clipping space // Then divide the pos by its w element to transform into NDC space // Finally transform x and y to viewport space - - // TODO: Apply vertex assembly here - // Assemble all attribute arraies into the primitive array - + glm::vec4 world = glm::vec4(primitive.dev_position[vid], 1.0f); + glm::vec4 clip = MVP * world; + clip /= clip.w; + + //NDC -> Screen + clip.x = (1.0f - clip.x) * width / 2.0f; + clip.y = (1.0f - clip.y) * height / 2.0f; +#if PERSP_CORRECT + primitive.dev_verticesOut[vid].vertexUV = (1 / clip.z) * primitive.dev_texcoord0[vid]; +#else + primitive.dev_verticesOut[vid].vertexUV = primitive.dev_texcoord0[vid]; +#endif + //transfer all other necessary attributes + primitive.dev_verticesOut[vid].vertexPerspPos = clip; + primitive.dev_verticesOut[vid].vertexEyePos = MV * world; + primitive.dev_verticesOut[vid].diffuseTexture = primitive.dev_diffuseTex; + primitive.dev_verticesOut[vid].vertexNormal = glm::normalize(MV_normal * primitive.dev_normal[vid]); + primitive.dev_verticesOut[vid].texHeight = primitive.diffuseTexHeight; + primitive.dev_verticesOut[vid].texWidth = primitive.diffuseTexWidth; } } +__device__ static glm::vec3 interpolateBarycentric(glm::vec3 p0, glm::vec3 p1, glm::vec3 p2, glm::vec3 coords) { + return coords.x * p0 + + coords.y * p1 + + coords.z * p2; +} +__device__ static glm::vec2 interpolateBarycentric(glm::vec2 p0, glm::vec2 p1, glm::vec2 p2, glm::vec3 coords) { + return coords.x * p0 + + coords.y * p1 + + coords.z * p2; +} -static int curPrimitiveBeginId = 0; - -__global__ -void _primitiveAssembly(int numIndices, int curPrimitiveBeginId, Primitive* dev_primitives, PrimitiveDevBufPointers primitive) { - - // index id - int iid = (blockIdx.x * blockDim.x) + threadIdx.x; +__device__ static glm::vec2 interpolatePerspective(Primitive p, glm::vec3 coords, float pointDepth) { + return pointDepth * ( + coords.x * p.v[0].vertexUV + + coords.y * p.v[1].vertexUV + + coords.z * p.v[2].vertexUV); +} - if (iid < numIndices) { +//VertexOuts in Primitive Buffer -> Fragments in Fragment Buffer +__global__ void generateFragments(int numPrimitives, Primitive* primitiveBuffer, int width, int height, Fragment* fragmentBuffer, int* depthBuffer, glm::vec2 sampleOffset) { + int primIdx = (blockIdx.x * blockDim.x) + threadIdx.x; + if (primIdx >= numPrimitives) return; - // TODO: uncomment the following code for a start - // This is primitive assembly for triangles + //get the intersected triangle + Primitive p = primitiveBuffer[primIdx]; + VertexOut p0 = p.v[0]; + VertexOut p1 = p.v[1]; + VertexOut p2 = p.v[2]; + glm::vec3 triangle[3] = { glm::vec3(p0.vertexPerspPos), glm::vec3(p1.vertexPerspPos), glm::vec3(p2.vertexPerspPos) }; - //int pid; // id for cur primitives vector - //if (primitive.primitiveMode == TINYGLTF_MODE_TRIANGLES) { - // pid = iid / (int)primitive.primitiveType; - // dev_primitives[pid + curPrimitiveBeginId].v[iid % (int)primitive.primitiveType] - // = primitive.dev_verticesOut[primitive.dev_indices[iid]]; - //} + //get upper and lower triangle bounds and restrict them to frustum + AABB triBounds = getAABBForTriangle(triangle, width, height); + //simple loop + for (int x = triBounds.min.x; x < triBounds.max.x; x++) { + for (int y = triBounds.min.y; y < triBounds.max.y; y++) { - // TODO: other primitive types (point, line) + int fragIndex = y*width + x; + + glm::vec3 barycentricCoord = calculateBarycentricCoordinate(triangle, glm::vec2(x, y)+sampleOffset); + //see if the given pixel is within the triangle's bounds from the current view + if (isBarycentricCoordInBounds(barycentricCoord)) { + int fragIndex = y*width + x; + if (fragIndex > width * height) return; + float depth = 1.0f / getZAtCoordinate(barycentricCoord, triangle); + atomicMin(&depthBuffer[fragIndex], (int)(depth * INT_MAX)); + glm::vec3 eyeNormal = interpolateBarycentric(p0.vertexNormal, p1.vertexNormal, p2.vertexNormal, barycentricCoord); + if (depth * INT_MAX == depthBuffer[fragIndex]) { + Fragment f; +#if PERSP_CORRECT + f.UV = interpolatePerspective(p, barycentricCoord, depth); +#else + f.UV = interpolateBarycentric(p0.vertexUV, p1.vertexUV, p2.vertexUV, barycentricCoord); +#endif + f.texWidth = p0.texWidth; + f.texHeight = p0.texHeight; + f.diffuseTexture = p0.diffuseTexture; + f.eyeNormal = eyeNormal; + fragmentBuffer[fragIndex] = f; + } + } + } } - } +//helper for remove-if +struct toCull { + __host__ __device__ bool operator()(Primitive p) { + return p.cull; + } +}; - -/** - * Perform rasterization. - */ +//put it all together void rasterize(uchar4 *pbo, const glm::mat4 & MVP, const glm::mat4 & MV, const glm::mat3 MV_normal) { - int sideLength2d = 8; - dim3 blockSize2d(sideLength2d, sideLength2d); - dim3 blockCount2d((width - 1) / blockSize2d.x + 1, + int sideLength2d = 8; + dim3 blockSize2d(sideLength2d, sideLength2d); + dim3 blockCount2d((width - 1) / blockSize2d.x + 1, (height - 1) / blockSize2d.y + 1); - // Execute your rasterization pipeline here - // (See README for rasterization pipeline outline.) - // Vertex Process & primitive assembly - { - curPrimitiveBeginId = 0; - dim3 numThreadsPerBlock(128); - - auto it = mesh2PrimitivesMap.begin(); - auto itEnd = mesh2PrimitivesMap.end(); - - for (; it != itEnd; ++it) { - auto p = (it->second).begin(); // each primitive - auto pEnd = (it->second).end(); - for (; p != pEnd; ++p) { - dim3 numBlocksForVertices((p->numVertices + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); - dim3 numBlocksForIndices((p->numIndices + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); - - _vertexTransformAndAssembly << < numBlocksForVertices, numThreadsPerBlock >> >(p->numVertices, *p, MVP, MV, MV_normal, width, height); - checkCUDAError("Vertex Processing"); - cudaDeviceSynchronize(); - _primitiveAssembly << < numBlocksForIndices, numThreadsPerBlock >> > - (p->numIndices, - curPrimitiveBeginId, - dev_primitives, + curPrimitiveBeginId = 0; + dim3 numThreadsPerBlock(128); + + auto it = mesh2PrimitivesMap.begin(); + auto itEnd = mesh2PrimitivesMap.end(); + + for (; it != itEnd; ++it) { + auto p = (it->second).begin(); // each primitive + auto pEnd = (it->second).end(); + for (; p != pEnd; ++p) { + dim3 numBlocksForVertices((p->numVertices + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); + dim3 numBlocksForIndices((p->numIndices + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); + + _vertexTransformAndAssembly << < numBlocksForVertices, numThreadsPerBlock >> > (p->numVertices, *p, MVP, MV, MV_normal, width, height); + checkCUDAError("Vertex Processing"); + cudaDeviceSynchronize(); + _primitiveAssembly << < numBlocksForIndices, numThreadsPerBlock >> > + (p->numIndices, + curPrimitiveBeginId, + dev_primitives, *p); - checkCUDAError("Primitive Assembly"); + checkCUDAError("Primitive Assembly"); - curPrimitiveBeginId += p->numPrimitives; - } + curPrimitiveBeginId += p->numPrimitives; } - - checkCUDAError("Vertex Processing and Primitive Assembly"); } - - cudaMemset(dev_fragmentBuffer, 0, width * height * sizeof(Fragment)); - initDepth << > >(width, height, dev_depth); - - // TODO: rasterize + checkCUDAError("Vertex Processing and Primitive Assembly"); + +#if BACKFACE_CULLING + //remove if for culled primitives + Primitive* new_primitive_end = thrust::remove_if(thrust::device, dev_primitives, dev_primitives + totalNumPrimitives, toCull());//-- 2: cull those paths that don't need any more shading + int frontPrims = new_primitive_end - dev_primitives; +#else + int frontPrims = totalNumPrimitives; +#endif + + dim3 numBlocksForPrims((frontPrims + numThreadsPerBlock.x - 1) / numThreadsPerBlock.x); + + cudaMemset(dev_framebuffer, 0, width*height * sizeof(glm::vec3)); + + for (float i = 0; i < SAMPLES; i++) { + for (float j = 0; j < SAMPLES; j++) { + glm::vec2 sampleOffset = glm::vec2(i * SAMPLE_JITTER, j * SAMPLE_JITTER); + cudaMemset(dev_fragmentBuffer, 0, width * height * sizeof(Fragment)); + initDepth << > > (width, height, dev_depth); + // TODO: rasterize + generateFragments << > > (frontPrims, dev_primitives, width, height, dev_fragmentBuffer, dev_depth, sampleOffset); + checkCUDAError("rasterization problem"); - // Copy depthbuffer colors into framebuffer - render << > >(width, height, dev_fragmentBuffer, dev_framebuffer); - checkCUDAError("fragment shader"); - // Copy framebuffer into OpenGL buffer for OpenGL previewing - sendImageToPBO<<>>(pbo, width, height, dev_framebuffer); - checkCUDAError("copy render result to pbo"); + // Copy depthbuffer colors into framebuffer + render << > > (width, height, dev_fragmentBuffer, dev_framebuffer); + checkCUDAError("fragment shader"); + } + } + // Copy framebuffer into OpenGL buffer for OpenGL previewing + sendImageToPBO << > > (pbo, width, height, dev_framebuffer); + checkCUDAError("copy render result to pbo"); } -/** - * Called once at the end of the program to free CUDA memory. - */ +//clean up void rasterizeFree() { - // deconstruct primitives attribute/indices device buffer + // deconstruct primitives attribute/indices device buffer auto it(mesh2PrimitivesMap.begin()); auto itEnd(mesh2PrimitivesMap.end()); @@ -753,24 +888,24 @@ void rasterizeFree() { cudaFree(p->dev_verticesOut); - + //TODO: release other attributes and materials } } //////////// - cudaFree(dev_primitives); - dev_primitives = NULL; + cudaFree(dev_primitives); + dev_primitives = NULL; cudaFree(dev_fragmentBuffer); dev_fragmentBuffer = NULL; - cudaFree(dev_framebuffer); - dev_framebuffer = NULL; + cudaFree(dev_framebuffer); + dev_framebuffer = NULL; cudaFree(dev_depth); dev_depth = NULL; - checkCUDAError("rasterize Free"); + checkCUDAError("rasterize Free"); } diff --git a/src/rasterizeTools.h b/src/rasterizeTools.h index 46c701e..bd113ae 100644 --- a/src/rasterizeTools.h +++ b/src/rasterizeTools.h @@ -13,8 +13,8 @@ #include struct AABB { - glm::vec3 min; - glm::vec3 max; + glm::vec2 min; + glm::vec2 max; }; /** @@ -30,16 +30,14 @@ glm::vec3 multiplyMV(glm::mat4 m, glm::vec4 v) { * Finds the axis aligned bounding box for a given triangle. */ __host__ __device__ static -AABB getAABBForTriangle(const glm::vec3 tri[3]) { +AABB getAABBForTriangle(const glm::vec3 tri[3], int width, int height) { AABB aabb; - aabb.min = glm::vec3( - min(min(tri[0].x, tri[1].x), tri[2].x), - min(min(tri[0].y, tri[1].y), tri[2].y), - min(min(tri[0].z, tri[1].z), tri[2].z)); - aabb.max = glm::vec3( - max(max(tri[0].x, tri[1].x), tri[2].x), - max(max(tri[0].y, tri[1].y), tri[2].y), - max(max(tri[0].z, tri[1].z), tri[2].z)); + aabb.min = glm::vec2( + max(min(min(min(tri[0].x, tri[1].x), tri[2].x), (float)width), 0.0f), + max(min(min(min(tri[0].y, tri[1].y), tri[2].y), (float)height), 0.0f)); + aabb.max = glm::vec2( + max(min(max(max(tri[0].x, tri[1].x), tri[2].x), (float)width), 0.0f), + max(min(max(max(tri[0].y, tri[1].y), tri[2].y), (float)height), 0.0f)); return aabb; } @@ -95,7 +93,7 @@ bool isBarycentricCoordInBounds(const glm::vec3 barycentricCoord) { */ __host__ __device__ static float getZAtCoordinate(const glm::vec3 barycentricCoord, const glm::vec3 tri[3]) { - return -(barycentricCoord.x * tri[0].z - + barycentricCoord.y * tri[1].z - + barycentricCoord.z * tri[2].z); + return (barycentricCoord.x / tri[0].z + + barycentricCoord.y / tri[1].z + + barycentricCoord.z / tri[2].z); }