diff --git a/CHANGELOG.md b/CHANGELOG.md index e9bd29e1..14cd1a72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ ## Latest Changes +### v0.5.4 (2025-02-01) +Improvements to the JAX frontend. + +**Added**: +- Jacobian-Vector Products (JVP) + for both `TensorProduct` and `TensorProductConv` via custom primitives, in addition to VJP. +- Arbitrary higher-order derivatives in JAX. +- JAX JIT support; in particular, support for + Phonon Fine Tuning in [Nequix](https://github.com/atomicarchitects/nequix). + +**Fixed**: +- Zeroed all output buffers in the backward and double-backward implementations of convolution +before calling kernels. + +### v0.5.1-0.5.3 (2025-02-01) +Minor bugfixes related to packaging and JAX. + +### v0.5.0 (2025-12-25) JAX support is now available in OpenEquivariance for BOTH NVIDIA and diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 9c3884c9..c9765d0d 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -320,7 +320,7 @@ def __init__(self, num_blocks, num_threads, warp_size, smem): self.num_blocks = num_blocks self.num_threads = num_threads self.warp_size = warp_size - self.smem = smem + self.smem = int(smem) class ComputationSchedule: diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 5fd8f81d..50f35bd4 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -7,7 +7,6 @@ import json import tempfile -import hashlib from enum import IntEnum @@ -200,13 +199,3 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis - - -def hash_attributes(attrs): - m = hashlib.sha256() - - for key in sorted(attrs.keys()): - m.update(attrs[key].__repr__().encode("utf-8")) - - hash = int(m.hexdigest()[:16], 16) >> 1 - attrs["hash"] = hash diff --git a/openequivariance/openequivariance/extension/json11/json11.cpp b/openequivariance/openequivariance/extension/json11/json11.cpp new file mode 100644 index 00000000..f3140252 --- /dev/null +++ b/openequivariance/openequivariance/extension/json11/json11.cpp @@ -0,0 +1,790 @@ +/* Copyright (c) 2013 Dropbox, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE.
+ */ + +#include "json11.hpp" +#include +#include +#include +#include +#include + +namespace json11 { + +static const int max_depth = 200; + +using std::string; +using std::vector; +using std::map; +using std::make_shared; +using std::initializer_list; +using std::move; + +/* Helper for representing null - just a do-nothing struct, plus comparison + * operators so the helpers in JsonValue work. We can't use nullptr_t because + * it may not be orderable. + */ +struct NullStruct { + bool operator==(NullStruct) const { return true; } + bool operator<(NullStruct) const { return false; } +}; + +/* * * * * * * * * * * * * * * * * * * * + * Serialization + */ + +static void dump(NullStruct, string &out) { + out += "null"; +} + +static void dump(double value, string &out) { + if (std::isfinite(value)) { + char buf[32]; + snprintf(buf, sizeof buf, "%.17g", value); + out += buf; + } else { + out += "null"; + } +} + +static void dump(int value, string &out) { + char buf[32]; + snprintf(buf, sizeof buf, "%d", value); + out += buf; +} + +static void dump(bool value, string &out) { + out += value ? "true" : "false"; +} + +static void dump(const string &value, string &out) { + out += '"'; + for (size_t i = 0; i < value.length(); i++) { + const char ch = value[i]; + if (ch == '\\') { + out += "\\\\"; + } else if (ch == '"') { + out += "\\\""; + } else if (ch == '\b') { + out += "\\b"; + } else if (ch == '\f') { + out += "\\f"; + } else if (ch == '\n') { + out += "\\n"; + } else if (ch == '\r') { + out += "\\r"; + } else if (ch == '\t') { + out += "\\t"; + } else if (static_cast(ch) <= 0x1f) { + char buf[8]; + snprintf(buf, sizeof buf, "\\u%04x", ch); + out += buf; + } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 + && static_cast(value[i+2]) == 0xa8) { + out += "\\u2028"; + i += 2; + } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 + && static_cast(value[i+2]) == 0xa9) { + out += "\\u2029"; + i += 2; + } else { + out += ch; + } + } + out += '"'; +} + +static void dump(const Json::array &values, string &out) { + bool first = true; + out += "["; + for (const auto &value : values) { + if (!first) + out += ", "; + value.dump(out); + first = false; + } + out += "]"; +} + +static void dump(const Json::object &values, string &out) { + bool first = true; + out += "{"; + for (const auto &kv : values) { + if (!first) + out += ", "; + dump(kv.first, out); + out += ": "; + kv.second.dump(out); + first = false; + } + out += "}"; +} + +void Json::dump(string &out) const { + m_ptr->dump(out); +} + +/* * * * * * * * * * * * * * * * * * * * + * Value wrappers + */ + +template +class Value : public JsonValue { +protected: + + // Constructors + explicit Value(const T &value) : m_value(value) {} + explicit Value(T &&value) : m_value(move(value)) {} + + // Get type tag + Json::Type type() const override { + return tag; + } + + // Comparisons + bool equals(const JsonValue * other) const override { + return m_value == static_cast *>(other)->m_value; + } + bool less(const JsonValue * other) const override { + return m_value < static_cast *>(other)->m_value; + } + + const T m_value; + void dump(string &out) const override { json11::dump(m_value, out); } +}; + +class JsonDouble final : public Value { + double number_value() const override { return m_value; } + int int_value() const override { return static_cast(m_value); } + bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } + bool less(const JsonValue * other) const override { return m_value 
< other->number_value(); } +public: + explicit JsonDouble(double value) : Value(value) {} +}; + +class JsonInt final : public Value { + double number_value() const override { return m_value; } + int int_value() const override { return m_value; } + bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } + bool less(const JsonValue * other) const override { return m_value < other->number_value(); } +public: + explicit JsonInt(int value) : Value(value) {} +}; + +class JsonBoolean final : public Value { + bool bool_value() const override { return m_value; } +public: + explicit JsonBoolean(bool value) : Value(value) {} +}; + +class JsonString final : public Value { + const string &string_value() const override { return m_value; } +public: + explicit JsonString(const string &value) : Value(value) {} + explicit JsonString(string &&value) : Value(move(value)) {} +}; + +class JsonArray final : public Value { + const Json::array &array_items() const override { return m_value; } + const Json & operator[](size_t i) const override; +public: + explicit JsonArray(const Json::array &value) : Value(value) {} + explicit JsonArray(Json::array &&value) : Value(move(value)) {} +}; + +class JsonObject final : public Value { + const Json::object &object_items() const override { return m_value; } + const Json & operator[](const string &key) const override; +public: + explicit JsonObject(const Json::object &value) : Value(value) {} + explicit JsonObject(Json::object &&value) : Value(move(value)) {} +}; + +class JsonNull final : public Value { +public: + JsonNull() : Value({}) {} +}; + +/* * * * * * * * * * * * * * * * * * * * + * Static globals - static-init-safe + */ +struct Statics { + const std::shared_ptr null = make_shared(); + const std::shared_ptr t = make_shared(true); + const std::shared_ptr f = make_shared(false); + const string empty_string; + const vector empty_vector; + const map empty_map; + Statics() {} +}; + +static const Statics & statics() { + static const Statics s {}; + return s; +} + +static const Json & static_null() { + // This has to be separate, not in Statics, because Json() accesses statics().null. + static const Json json_null; + return json_null; +} + +/* * * * * * * * * * * * * * * * * * * * + * Constructors + */ + +Json::Json() noexcept : m_ptr(statics().null) {} +Json::Json(std::nullptr_t) noexcept : m_ptr(statics().null) {} +Json::Json(double value) : m_ptr(make_shared(value)) {} +Json::Json(int value) : m_ptr(make_shared(value)) {} +Json::Json(bool value) : m_ptr(value ? 
statics().t : statics().f) {} +Json::Json(const string &value) : m_ptr(make_shared(value)) {} +Json::Json(string &&value) : m_ptr(make_shared(move(value))) {} +Json::Json(const char * value) : m_ptr(make_shared(value)) {} +Json::Json(const Json::array &values) : m_ptr(make_shared(values)) {} +Json::Json(Json::array &&values) : m_ptr(make_shared(move(values))) {} +Json::Json(const Json::object &values) : m_ptr(make_shared(values)) {} +Json::Json(Json::object &&values) : m_ptr(make_shared(move(values))) {} + +/* * * * * * * * * * * * * * * * * * * * + * Accessors + */ + +Json::Type Json::type() const { return m_ptr->type(); } +double Json::number_value() const { return m_ptr->number_value(); } +int Json::int_value() const { return m_ptr->int_value(); } +bool Json::bool_value() const { return m_ptr->bool_value(); } +const string & Json::string_value() const { return m_ptr->string_value(); } +const vector & Json::array_items() const { return m_ptr->array_items(); } +const map & Json::object_items() const { return m_ptr->object_items(); } +const Json & Json::operator[] (size_t i) const { return (*m_ptr)[i]; } +const Json & Json::operator[] (const string &key) const { return (*m_ptr)[key]; } + +double JsonValue::number_value() const { return 0; } +int JsonValue::int_value() const { return 0; } +bool JsonValue::bool_value() const { return false; } +const string & JsonValue::string_value() const { return statics().empty_string; } +const vector & JsonValue::array_items() const { return statics().empty_vector; } +const map & JsonValue::object_items() const { return statics().empty_map; } +const Json & JsonValue::operator[] (size_t) const { return static_null(); } +const Json & JsonValue::operator[] (const string &) const { return static_null(); } + +const Json & JsonObject::operator[] (const string &key) const { + auto iter = m_value.find(key); + return (iter == m_value.end()) ? static_null() : iter->second; +} +const Json & JsonArray::operator[] (size_t i) const { + if (i >= m_value.size()) return static_null(); + else return m_value[i]; +} + +/* * * * * * * * * * * * * * * * * * * * + * Comparison + */ + +bool Json::operator== (const Json &other) const { + if (m_ptr == other.m_ptr) + return true; + if (m_ptr->type() != other.m_ptr->type()) + return false; + + return m_ptr->equals(other.m_ptr.get()); +} + +bool Json::operator< (const Json &other) const { + if (m_ptr == other.m_ptr) + return false; + if (m_ptr->type() != other.m_ptr->type()) + return m_ptr->type() < other.m_ptr->type(); + + return m_ptr->less(other.m_ptr.get()); +} + +/* * * * * * * * * * * * * * * * * * * * + * Parsing + */ + +/* esc(c) + * + * Format char c suitable for printing in an error message. + */ +static inline string esc(char c) { + char buf[12]; + if (static_cast(c) >= 0x20 && static_cast(c) <= 0x7f) { + snprintf(buf, sizeof buf, "'%c' (%d)", c, c); + } else { + snprintf(buf, sizeof buf, "(%d)", c); + } + return string(buf); +} + +static inline bool in_range(long x, long lower, long upper) { + return (x >= lower && x <= upper); +} + +namespace { +/* JsonParser + * + * Object that tracks all state of an in-progress parse. + */ +struct JsonParser final { + + /* State + */ + const string &str; + size_t i; + string &err; + bool failed; + const JsonParse strategy; + + /* fail(msg, err_ret = Json()) + * + * Mark this parse as failed. 
+ */ + Json fail(string &&msg) { + return fail(move(msg), Json()); + } + + template + T fail(string &&msg, const T err_ret) { + if (!failed) + err = std::move(msg); + failed = true; + return err_ret; + } + + /* consume_whitespace() + * + * Advance until the current character is non-whitespace. + */ + void consume_whitespace() { + while (str[i] == ' ' || str[i] == '\r' || str[i] == '\n' || str[i] == '\t') + i++; + } + + /* consume_comment() + * + * Advance comments (c-style inline and multiline). + */ + bool consume_comment() { + bool comment_found = false; + if (str[i] == '/') { + i++; + if (i == str.size()) + return fail("unexpected end of input after start of comment", false); + if (str[i] == '/') { // inline comment + i++; + // advance until next line, or end of input + while (i < str.size() && str[i] != '\n') { + i++; + } + comment_found = true; + } + else if (str[i] == '*') { // multiline comment + i++; + if (i > str.size()-2) + return fail("unexpected end of input inside multi-line comment", false); + // advance until closing tokens + while (!(str[i] == '*' && str[i+1] == '/')) { + i++; + if (i > str.size()-2) + return fail( + "unexpected end of input inside multi-line comment", false); + } + i += 2; + comment_found = true; + } + else + return fail("malformed comment", false); + } + return comment_found; + } + + /* consume_garbage() + * + * Advance until the current character is non-whitespace and non-comment. + */ + void consume_garbage() { + consume_whitespace(); + if(strategy == JsonParse::COMMENTS) { + bool comment_found = false; + do { + comment_found = consume_comment(); + if (failed) return; + consume_whitespace(); + } + while(comment_found); + } + } + + /* get_next_token() + * + * Return the next non-whitespace character. If the end of the input is reached, + * flag an error and return 0. + */ + char get_next_token() { + consume_garbage(); + if (failed) return static_cast(0); + if (i == str.size()) + return fail("unexpected end of input", static_cast(0)); + + return str[i++]; + } + + /* encode_utf8(pt, out) + * + * Encode pt as UTF-8 and add it to out. + */ + void encode_utf8(long pt, string & out) { + if (pt < 0) + return; + + if (pt < 0x80) { + out += static_cast(pt); + } else if (pt < 0x800) { + out += static_cast((pt >> 6) | 0xC0); + out += static_cast((pt & 0x3F) | 0x80); + } else if (pt < 0x10000) { + out += static_cast((pt >> 12) | 0xE0); + out += static_cast(((pt >> 6) & 0x3F) | 0x80); + out += static_cast((pt & 0x3F) | 0x80); + } else { + out += static_cast((pt >> 18) | 0xF0); + out += static_cast(((pt >> 12) & 0x3F) | 0x80); + out += static_cast(((pt >> 6) & 0x3F) | 0x80); + out += static_cast((pt & 0x3F) | 0x80); + } + } + + /* parse_string() + * + * Parse a string, starting at the current position. 
+ */ + string parse_string() { + string out; + long last_escaped_codepoint = -1; + while (true) { + if (i == str.size()) + return fail("unexpected end of input in string", ""); + + char ch = str[i++]; + + if (ch == '"') { + encode_utf8(last_escaped_codepoint, out); + return out; + } + + if (in_range(ch, 0, 0x1f)) + return fail("unescaped " + esc(ch) + " in string", ""); + + // The usual case: non-escaped characters + if (ch != '\\') { + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = -1; + out += ch; + continue; + } + + // Handle escapes + if (i == str.size()) + return fail("unexpected end of input in string", ""); + + ch = str[i++]; + + if (ch == 'u') { + // Extract 4-byte escape sequence + string esc = str.substr(i, 4); + // Explicitly check length of the substring. The following loop + // relies on std::string returning the terminating NUL when + // accessing str[length]. Checking here reduces brittleness. + if (esc.length() < 4) { + return fail("bad \\u escape: " + esc, ""); + } + for (size_t j = 0; j < 4; j++) { + if (!in_range(esc[j], 'a', 'f') && !in_range(esc[j], 'A', 'F') + && !in_range(esc[j], '0', '9')) + return fail("bad \\u escape: " + esc, ""); + } + + long codepoint = strtol(esc.data(), nullptr, 16); + + // JSON specifies that characters outside the BMP shall be encoded as a pair + // of 4-hex-digit \u escapes encoding their surrogate pair components. Check + // whether we're in the middle of such a beast: the previous codepoint was an + // escaped lead (high) surrogate, and this is a trail (low) surrogate. + if (in_range(last_escaped_codepoint, 0xD800, 0xDBFF) + && in_range(codepoint, 0xDC00, 0xDFFF)) { + // Reassemble the two surrogate pairs into one astral-plane character, per + // the UTF-16 algorithm. + encode_utf8((((last_escaped_codepoint - 0xD800) << 10) + | (codepoint - 0xDC00)) + 0x10000, out); + last_escaped_codepoint = -1; + } else { + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = codepoint; + } + + i += 4; + continue; + } + + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = -1; + + if (ch == 'b') { + out += '\b'; + } else if (ch == 'f') { + out += '\f'; + } else if (ch == 'n') { + out += '\n'; + } else if (ch == 'r') { + out += '\r'; + } else if (ch == 't') { + out += '\t'; + } else if (ch == '"' || ch == '\\' || ch == '/') { + out += ch; + } else { + return fail("invalid escape character " + esc(ch), ""); + } + } + } + + /* parse_number() + * + * Parse a double. + */ + Json parse_number() { + size_t start_pos = i; + + if (str[i] == '-') + i++; + + // Integer part + if (str[i] == '0') { + i++; + if (in_range(str[i], '0', '9')) + return fail("leading 0s not permitted in numbers"); + } else if (in_range(str[i], '1', '9')) { + i++; + while (in_range(str[i], '0', '9')) + i++; + } else { + return fail("invalid " + esc(str[i]) + " in number"); + } + + if (str[i] != '.' 
&& str[i] != 'e' && str[i] != 'E' + && (i - start_pos) <= static_cast(std::numeric_limits::digits10)) { + return std::atoi(str.c_str() + start_pos); + } + + // Decimal part + if (str[i] == '.') { + i++; + if (!in_range(str[i], '0', '9')) + return fail("at least one digit required in fractional part"); + + while (in_range(str[i], '0', '9')) + i++; + } + + // Exponent part + if (str[i] == 'e' || str[i] == 'E') { + i++; + + if (str[i] == '+' || str[i] == '-') + i++; + + if (!in_range(str[i], '0', '9')) + return fail("at least one digit required in exponent"); + + while (in_range(str[i], '0', '9')) + i++; + } + + return std::strtod(str.c_str() + start_pos, nullptr); + } + + /* expect(str, res) + * + * Expect that 'str' starts at the character that was just read. If it does, advance + * the input and return res. If not, flag an error. + */ + Json expect(const string &expected, Json res) { + assert(i != 0); + i--; + if (str.compare(i, expected.length(), expected) == 0) { + i += expected.length(); + return res; + } else { + return fail("parse error: expected " + expected + ", got " + str.substr(i, expected.length())); + } + } + + /* parse_json() + * + * Parse a JSON object. + */ + Json parse_json(int depth) { + if (depth > max_depth) { + return fail("exceeded maximum nesting depth"); + } + + char ch = get_next_token(); + if (failed) + return Json(); + + if (ch == '-' || (ch >= '0' && ch <= '9')) { + i--; + return parse_number(); + } + + if (ch == 't') + return expect("true", true); + + if (ch == 'f') + return expect("false", false); + + if (ch == 'n') + return expect("null", Json()); + + if (ch == '"') + return parse_string(); + + if (ch == '{') { + map data; + ch = get_next_token(); + if (ch == '}') + return data; + + while (1) { + if (ch != '"') + return fail("expected '\"' in object, got " + esc(ch)); + + string key = parse_string(); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch != ':') + return fail("expected ':' in object, got " + esc(ch)); + + data[std::move(key)] = parse_json(depth + 1); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch == '}') + break; + if (ch != ',') + return fail("expected ',' in object, got " + esc(ch)); + + ch = get_next_token(); + } + return data; + } + + if (ch == '[') { + vector data; + ch = get_next_token(); + if (ch == ']') + return data; + + while (1) { + i--; + data.push_back(parse_json(depth + 1)); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch == ']') + break; + if (ch != ',') + return fail("expected ',' in list, got " + esc(ch)); + + ch = get_next_token(); + (void)ch; + } + return data; + } + + return fail("expected value, got " + esc(ch)); + } +}; +}//namespace { + +Json Json::parse(const string &in, string &err, JsonParse strategy) { + JsonParser parser { in, 0, err, false, strategy }; + Json result = parser.parse_json(0); + + // Check for any trailing garbage + parser.consume_garbage(); + if (parser.failed) + return Json(); + if (parser.i != in.size()) + return parser.fail("unexpected trailing " + esc(in[parser.i])); + + return result; +} + +// Documented in json11.hpp +vector Json::parse_multi(const string &in, + std::string::size_type &parser_stop_pos, + string &err, + JsonParse strategy) { + JsonParser parser { in, 0, err, false, strategy }; + parser_stop_pos = 0; + vector json_vec; + while (parser.i != in.size() && !parser.failed) { + json_vec.push_back(parser.parse_json(0)); + if (parser.failed) + break; + + // Check for another object + parser.consume_garbage(); + if (parser.failed) + break; 
+ parser_stop_pos = parser.i; + } + return json_vec; +} + +/* * * * * * * * * * * * * * * * * * * * + * Shape-checking + */ + +bool Json::has_shape(const shape & types, string & err) const { + if (!is_object()) { + err = "expected JSON object, got " + dump(); + return false; + } + + const auto& obj_items = object_items(); + for (auto & item : types) { + const auto it = obj_items.find(item.first); + if (it == obj_items.cend() || it->second.type() != item.second) { + err = "bad type for " + item.first + " in " + dump(); + return false; + } + } + + return true; +} + +} // namespace json11 \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/json11/json11.hpp b/openequivariance/openequivariance/extension/json11/json11.hpp new file mode 100644 index 00000000..965388a7 --- /dev/null +++ b/openequivariance/openequivariance/extension/json11/json11.hpp @@ -0,0 +1,232 @@ +/* json11 + * + * json11 is a tiny JSON library for C++11, providing JSON parsing and serialization. + * + * The core object provided by the library is json11::Json. A Json object represents any JSON + * value: null, bool, number (int or double), string (std::string), array (std::vector), or + * object (std::map). + * + * Json objects act like values: they can be assigned, copied, moved, compared for equality or + * order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and + * Json::parse (static) to parse a std::string as a Json object. + * + * Internally, the various types of Json object are represented by the JsonValue class + * hierarchy. + * + * A note on numbers - JSON specifies the syntax of number formatting but not its semantics, + * so some JSON implementations distinguish between integers and floating-point numbers, while + * some don't. In json11, we choose the latter. Because some JSON implementations (namely + * Javascript itself) treat all numbers as the same type, distinguishing the two leads + * to JSON that will be *silently* changed by a round-trip through those implementations. + * Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also + * provides integer helpers. + * + * Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the + * range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64 + * or long long to avoid the Y2038K problem; a double storing microseconds since some epoch + * will be exact for +/- 275 years.) + */ + +/* Copyright (c) 2013 Dropbox, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef _MSC_VER + #if _MSC_VER <= 1800 // VS 2013 + #ifndef noexcept + #define noexcept throw() + #endif + + #ifndef snprintf + #define snprintf _snprintf_s + #endif + #endif +#endif + +namespace json11 { + +enum JsonParse { + STANDARD, COMMENTS +}; + +class JsonValue; + +class Json final { +public: + // Types + enum Type { + NUL, NUMBER, BOOL, STRING, ARRAY, OBJECT + }; + + // Array and object typedefs + typedef std::vector array; + typedef std::map object; + + // Constructors for the various types of JSON value. + Json() noexcept; // NUL + Json(std::nullptr_t) noexcept; // NUL + Json(double value); // NUMBER + Json(int value); // NUMBER + Json(bool value); // BOOL + Json(const std::string &value); // STRING + Json(std::string &&value); // STRING + Json(const char * value); // STRING + Json(const array &values); // ARRAY + Json(array &&values); // ARRAY + Json(const object &values); // OBJECT + Json(object &&values); // OBJECT + + // Implicit constructor: anything with a to_json() function. + template + Json(const T & t) : Json(t.to_json()) {} + + // Implicit constructor: map-like objects (std::map, std::unordered_map, etc) + template ().begin()->first)>::value + && std::is_constructible().begin()->second)>::value, + int>::type = 0> + Json(const M & m) : Json(object(m.begin(), m.end())) {} + + // Implicit constructor: vector-like objects (std::list, std::vector, std::set, etc) + template ().begin())>::value, + int>::type = 0> + Json(const V & v) : Json(array(v.begin(), v.end())) {} + + // This prevents Json(some_pointer) from accidentally producing a bool. Use + // Json(bool(some_pointer)) if that behavior is desired. + Json(void *) = delete; + + // Accessors + Type type() const; + + bool is_null() const { return type() == NUL; } + bool is_number() const { return type() == NUMBER; } + bool is_bool() const { return type() == BOOL; } + bool is_string() const { return type() == STRING; } + bool is_array() const { return type() == ARRAY; } + bool is_object() const { return type() == OBJECT; } + + // Return the enclosed value if this is a number, 0 otherwise. Note that json11 does not + // distinguish between integer and non-integer numbers - number_value() and int_value() + // can both be applied to a NUMBER-typed object. + double number_value() const; + int int_value() const; + + // Return the enclosed value if this is a boolean, false otherwise. + bool bool_value() const; + // Return the enclosed string if this is a string, "" otherwise. + const std::string &string_value() const; + // Return the enclosed std::vector if this is an array, or an empty vector otherwise. + const array &array_items() const; + // Return the enclosed std::map if this is an object, or an empty map otherwise. + const object &object_items() const; + + // Return a reference to arr[i] if this is an array, Json() otherwise. + const Json & operator[](size_t i) const; + // Return a reference to obj[key] if this is an object, Json() otherwise. + const Json & operator[](const std::string &key) const; + + // Serialize. + void dump(std::string &out) const; + std::string dump() const { + std::string out; + dump(out); + return out; + } + + // Parse. 
If parse fails, return Json() and assign an error message to err. + static Json parse(const std::string & in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD); + static Json parse(const char * in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD) { + if (in) { + return parse(std::string(in), err, strategy); + } else { + err = "null input"; + return nullptr; + } + } + // Parse multiple objects, concatenated or separated by whitespace + static std::vector parse_multi( + const std::string & in, + std::string::size_type & parser_stop_pos, + std::string & err, + JsonParse strategy = JsonParse::STANDARD); + + static inline std::vector parse_multi( + const std::string & in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD) { + std::string::size_type parser_stop_pos; + return parse_multi(in, parser_stop_pos, err, strategy); + } + + bool operator== (const Json &rhs) const; + bool operator< (const Json &rhs) const; + bool operator!= (const Json &rhs) const { return !(*this == rhs); } + bool operator<= (const Json &rhs) const { return !(rhs < *this); } + bool operator> (const Json &rhs) const { return (rhs < *this); } + bool operator>= (const Json &rhs) const { return !(*this < rhs); } + + /* has_shape(types, err) + * + * Return true if this is a JSON object and, for each item in types, has a field of + * the given type. If not, return false and set err to a descriptive message. + */ + typedef std::initializer_list> shape; + bool has_shape(const shape & types, std::string & err) const; + +private: + std::shared_ptr m_ptr; +}; + +// Internal class hierarchy - JsonValue objects are not exposed to users of this API. +class JsonValue { +protected: + friend class Json; + friend class JsonInt; + friend class JsonDouble; + virtual Json::Type type() const = 0; + virtual bool equals(const JsonValue * other) const = 0; + virtual bool less(const JsonValue * other) const = 0; + virtual void dump(std::string &out) const = 0; + virtual double number_value() const; + virtual int int_value() const; + virtual bool bool_value() const; + virtual const std::string &string_value() const; + virtual const Json::array &array_items() const; + virtual const Json &operator[](size_t i) const; + virtual const Json::object &object_items() const; + virtual const Json &operator[](const std::string &key) const; + virtual ~JsonValue() {} +}; + +} // namespace json11 \ No newline at end of file diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 05e4b097..bf7a1445 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -1,116 +1,11 @@ import jax -import jax.numpy as jnp import numpy as np -from functools import partial from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.core.utils import hash_attributes from openequivariance.jax.utils import reorder_jax - - -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def forward(X, Y, W, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call( - "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) - ) - return forward_call(X, Y, W, **attrs) - - -def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs): - return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) - - -def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ): - X, Y, W = inputs - 
return backward(X, Y, W, dZ, irrep_dtype, attrs) - - -forward.defvjp(forward_fwd, forward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) -def backward(X, Y, W, dZ, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call( - "tp_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - ), - ) - return backward_call(X, Y, W, dZ, **attrs) - - -def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs): - return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) - - -def backward_bwd(irrep_dtype, attrs, inputs, derivs): - X, Y, W, dZ = inputs - ddX, ddY, ddW = derivs - return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs) - - -backward.defvjp(backward_fwd, backward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(7, 8)) -def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs): - double_backward_call = jax.ffi.ffi_call( - "tp_double_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), - ), - ) - return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs) - - -def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs): - out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs) - return out, (X, Y, W, dZ, ddX, ddY, ddW) - - -def zeros_like(x): - return jnp.zeros_like(x) - - -def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs): - X, Y, W, dZ, ddX, ddY, ddW = residuals - t_dX, t_dY, t_dW, t_ddZ = tangent_outputs - - op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) - g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs) - - op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) - g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs) - - op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) - g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs) - - op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) - g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs) - - g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs) - g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs) - g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs) - - grad_X = g2_X + g4_X + g6_X + g7_X - grad_Y = g2_Y + g3_Y + g5_Y + g7_Y - grad_W = g1_W + g3_W + g4_W + g5_W + g6_W - grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ - - grad_ddX = g1_ddX + g3_ddX + g5_ddX - grad_ddY = g1_ddY + g4_ddY + g6_ddY - grad_ddW = g2_ddW + g7_ddW - - return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW - - -double_backward.defvjp(double_backward_fwd, triple_backward) +from openequivariance.jax.jvp.tp_prim import tp_fwd_p +import json class TensorProduct(LoopUnrollTP): @@ -124,14 +19,18 @@ def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) - self.attrs = { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernelProp, - } - hash_attributes(self.attrs) + self.kernel = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": 
vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernelProp, + } + ) + self.hash = self.kernel.__hash__() self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -139,7 +38,9 @@ def __init__(self, problem: TPProblem): def forward( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray ) -> jax.numpy.ndarray: - return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) + return tp_fwd_p.bind( + X, Y, W, L3_dim=self.L3_dim, kernel=self.kernel, hash=self.hash + ) def __call__( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray @@ -160,7 +61,7 @@ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - result = self.forward( + result = jax.jit(self.forward)( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), @@ -173,12 +74,14 @@ def backward_cpu( weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.vjp( - lambda X, Y, W: self.forward(X, Y, W), - jax.numpy.asarray(L1_in), - jax.numpy.asarray(L2_in), - jax.numpy.asarray(weights), - )[1] + backward_fn = jax.jit( + jax.vjp( + lambda X, Y, W: self.forward(X, Y, W), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + ) L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) ) @@ -200,14 +103,20 @@ def double_backward_cpu( in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) - in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( - lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[ - 1 - ](o), - in1_jax, - in2_jax, - weights_jax, - out_grad_jax, - )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + dbwd_func = jax.jit( + jax.vjp( + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward(a, b, c), x, y, w + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1] + ) + + in1_grad, in2_grad, weights_grad, out_dgrad = dbwd_func( + (in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax) + ) return in1_grad, in2_grad, weights_grad, out_dgrad diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 7439cd4e..52102294 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -1,184 +1,44 @@ import jax +import json import jax.numpy as jnp import numpy as np -from functools import partial from typing import Optional from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv -from openequivariance.core.utils import hash_attributes from openequivariance.jax.utils import reorder_jax from openequivariance.benchmark.logging_utils import getLogger - -logger = getLogger() - - -def zeros_like(x): - return jnp.zeros_like(x) - - -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call( - "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) - ) - return forward_call(X, Y, W, rows, 
cols, workspace, sender_perm, **attrs) - - -def forward_fwd( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs -): - out = forward( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs - ) - return out, (X, Y, W, rows, cols) - - -def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ): - X, Y, W, rows, cols = res - dX, dY, dW = backward( - X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs - ) - return dX, dY, dW, None, None - - -forward.defvjp(forward_fwd, forward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) -def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call( - "conv_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - ), - ) - return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) - - -def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs) - return out, (X, Y, W, dZ, rows, cols) - - -def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives): - X, Y, W, dZ, rows, cols = res - ddX, ddY, ddW = derivatives - - gX, gY, gW, gdZ = double_backward( - X, - Y, - W, - dZ, - ddX, - ddY, - ddW, - rows, - cols, - workspace, - sender_perm, - irrep_dtype, - attrs, - ) - - return gX, gY, gW, gdZ, None, None - - -backward.defvjp(backward_fwd, backward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) -def double_backward( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs -): - double_backward_call = jax.ffi.ffi_call( - "conv_double_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), - ), - ) - return double_backward_call( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs - ) - - -def double_backward_fwd( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs -): - out = double_backward( - X, - Y, - W, - dZ, - ddX, - ddY, - ddW, - rows, - cols, - workspace, - sender_perm, - irrep_dtype, - attrs, - ) - return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) +from openequivariance.jax.jvp import conv_prim +from openequivariance.jax.vjp import conv_func -def triple_backward( - workspace, - sender_perm, - irrep_dtype, - attrs, - residuals, - tangent_outputs, -): - X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals - t_dX, t_dY, t_dW, t_ddZ = tangent_outputs - - common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs) - - op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) - g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args) - - op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) - g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args) - - op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) - g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args) - - op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) - g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args) - - g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args) - g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args) - g7_X, g7_Y, g7_ddW = 
backward(X, Y, ddW, t_ddZ, *common_args) - - grad_X = g2_X + g4_X + g6_X + g7_X - grad_Y = g2_Y + g3_Y + g5_Y + g7_Y - grad_W = g1_W + g3_W + g4_W + g5_W + g6_W - grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ - - grad_ddX = g1_ddX + g3_ddX + g5_ddX - grad_ddY = g1_ddY + g4_ddY + g6_ddY - grad_ddW = g2_ddW + g7_ddW - - return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None - - -double_backward.defvjp(double_backward_fwd, triple_backward) +logger = getLogger() class TensorProductConv(LoopUnrollConv): + r""" + Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one + key difference: integer arrays passed to this function must have dtype + ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). + + :param problem: Specification of the tensor product. + :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic + fixup-based algorithm. `Default`: ``False``. + :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, + the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. + """ + def __init__( - self, config: TPProblem, deterministic: bool = False, kahan: bool = False + self, + config: TPProblem, + deterministic: bool = False, + kahan: bool = False, + requires_jvp: bool = True, ): dp = extlib.DeviceProp(0) + self.requires_jvp = requires_jvp super().__init__( config, dp, @@ -189,14 +49,18 @@ def __init__( kahan=kahan, ) - self.attrs = { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernel_prop, - } - hash_attributes(self.attrs) + self.kernel = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernel_prop, + } + ) + self.hash = self.kernel.__hash__() self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -223,7 +87,12 @@ def forward( "Must provide sender_perm for deterministic convolutions." 
) - return forward( + func = conv_prim.conv_fwd_p.bind + + if not self.requires_jvp: + func = conv_func.forward + + return func( X, Y, W, @@ -231,9 +100,9 @@ def forward( cols, self.workspace, sender_perm, - self.L3_dim, - self.config.irrep_dtype, - self.attrs, + L3_dim=self.L3_dim, + kernel=self.kernel, + hash=self.hash, ) def __call__( @@ -260,7 +129,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - result = self.forward( + + jit_fwd = jax.jit(self.forward) + result = jit_fwd( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), @@ -288,19 +159,22 @@ def backward_cpu( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.vjp( - lambda X, Y, W: self.forward( - X, - Y, - W, - jax.numpy.asarray(rows), - jax.numpy.asarray(cols), - jax.numpy.asarray(sender_perm), - ), - jax.numpy.asarray(L1_in), - jax.numpy.asarray(L2_in), - jax.numpy.asarray(weights), - )[1] + backward_fn = jax.jit( + jax.vjp( + lambda X, Y, W: self.forward( + X, + Y, + W, + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + ) + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) ) @@ -326,20 +200,22 @@ def double_backward_cpu( cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype)) sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) - in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( - lambda x, y, w, o: jax.vjp( - lambda a, b, c: self.forward( - a, b, c, rows_jax, cols_jax, sender_perm_jax - ), - x, - y, - w, - )[1](o), - in1_jax, - in2_jax, - weights_jax, - out_grad_jax, - )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + in1_grad, in2_grad, weights_grad, out_dgrad = jax.jit( + jax.vjp( + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward( + a, b, c, rows_jax, cols_jax, sender_perm_jax + ), + x, + y, + w, + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1] + )((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) return ( np.asarray(in1_grad), diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py new file mode 100644 index 00000000..9324b785 --- /dev/null +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -0,0 +1,507 @@ +import jax +import jax.numpy as jnp +from jax.extend import core +from jax.interpreters import mlir, ad +from openequivariance.jax.utils import clean_tensors + +# ============================================================================== +# 1. 
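# ------------------------------------------------------------------------------
# Illustrative usage sketch (not part of this diff): differentiating through
# TensorProductConv under jax.jit, per the docstring earlier in this change.
# Assumptions are hedged here: `tp_conv` is a TensorProductConv built from a
# TPProblem elsewhere with the default deterministic=False (so sender_perm is not
# needed), rows/cols are np.int32 edge-index arrays as the docstring requires,
# and shapes follow the usual oeq convention of X: [num_nodes, irreps_in1.dim],
# Y: [num_edges, irreps_in2.dim], W: [num_edges, weight_numel].
import jax
import jax.numpy as jnp


def conv_value_and_grads(tp_conv, X, Y, W, rows, cols):
    # Scalar loss so jax.grad applies; the module call lowers to the conv
    # primitives defined below.
    def loss(X, Y, W):
        out = tp_conv(X, Y, W, rows, cols)
        return jnp.sum(out * out)

    # JIT works because the kernel string and hash are static primitive params.
    return jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2)))(X, Y, W)
# ------------------------------------------------------------------------------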
Forward Primitive +# ============================================================================== + +conv_fwd_p = core.Primitive("conv_fwd") + + +def conv_fwd_impl(X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + irrep_dtype = X.dtype + out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + call = jax.ffi.ffi_call("conv_forward", out_shape) + return call(X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + + +def conv_fwd_abstract_eval( + X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + + +conv_fwd_p.def_impl(conv_fwd_impl) +conv_fwd_p.def_abstract_eval(conv_fwd_abstract_eval) +mlir.register_lowering( + conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="cuda" +) +mlir.register_lowering( + conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="rocm" +) + + +# ============================================================================== +# 2. Backward Primitive +# ============================================================================== + +conv_bwd_p = core.Primitive("conv_bwd") +conv_bwd_p.multiple_results = True + + +def conv_bwd_impl(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("conv_backward", out_shapes) + return call( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + +def conv_bwd_abstract_eval( + X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash +): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + + +conv_bwd_p.def_impl(conv_bwd_impl) +conv_bwd_p.def_abstract_eval(conv_bwd_abstract_eval) +mlir.register_lowering( + conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="rocm" +) + + +# ============================================================================== +# 3. 
Double Backward Primitive +# ============================================================================== + +conv_dbwd_p = core.Primitive("conv_dbwd") +conv_dbwd_p.multiple_results = True + + +def conv_dbwd_impl( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash +): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("conv_double_backward", out_shapes) + return call( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel=kernel, + hash=hash, + ) + + +def conv_dbwd_abstract_eval( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash +): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(dZ.shape, irrep_dtype), + ) + + +conv_dbwd_p.def_impl(conv_dbwd_impl) +conv_dbwd_p.def_abstract_eval(conv_dbwd_abstract_eval) +mlir.register_lowering( + conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="rocm" +) + + +# ============================================================================== +# 4. Forward JVP Primitive Definition +# ============================================================================== + +conv_fwd_jvp_p = core.Primitive("conv_fwd_jvp") + + +def conv_fwd_jvp_impl( + X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): + kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + term1 = conv_fwd_p.bind(dX, Y, W, *args_meta, **kwargs) + term2 = conv_fwd_p.bind(X, dY, W, *args_meta, **kwargs) + term3 = conv_fwd_p.bind(X, Y, dW, *args_meta, **kwargs) + return term1 + term2 + term3 + + +def conv_fwd_jvp_abstract_eval( + X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + + +conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) +conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) + + +# ============================================================================== +# 5. Transpose Rule (Implicit VJP) +# ============================================================================== + + +def conv_fwd_jvp_transpose( + ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): + X, Y, W = clean_tensors(X, Y, W) + + grad_X, grad_Y, grad_W = conv_bwd_p.bind( + X, Y, W, ct, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + return (None, None, None, grad_X, grad_Y, grad_W, None, None, None, None) + + +ad.primitive_transposes[conv_fwd_jvp_p] = conv_fwd_jvp_transpose + + +# ============================================================================== +# 6. 
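# Note on the surrounding rules: conv_forward is linear in each of X, Y, and W for
# fixed rows/cols/workspace/sender_perm, so its tangent (section 4 above) is the
# sum of three forward calls with one tangent substituted at a time. Transposing
# that linear map with respect to (dX, dY, dW) is exactly what the conv_backward
# kernel computes, which is why the transpose rule in section 5 simply binds
# conv_bwd_p; reverse mode then falls out of JAX's linearize-and-transpose
# machinery together with the JVP rule that follows.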
JVP Rule for Original Forward Primitive +# ============================================================================== + + +def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + X, Y, W, rows, cols, workspace, sender_perm = primals + dX, dY, dW, drows, dcols, dworkspace, dsender_perm = tangents + + dX, dY, dW = clean_tensors(dX, dY, dW) + out_primal = conv_fwd_p.bind( + X, + Y, + W, + rows, + cols, + workspace, + sender_perm, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) + out_tangent = conv_fwd_jvp_p.bind( + X, + Y, + W, + dX, + dY, + dW, + rows, + cols, + workspace, + sender_perm, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) + + return out_primal, out_tangent + + +ad.primitive_jvps[conv_fwd_p] = conv_fwd_jvp_rule + + +# ============================================================================== +# 7. JVP Rule for Forward JVP Primitive (Higher Order) +# ============================================================================== + + +def conv_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = tuple(clean_tensors(*tangents)) + + def func(x, y, w, dx, dy, dw, r, c, ws, sp): + return conv_fwd_jvp_impl( + x, y, w, dx, dy, dw, r, c, ws, sp, L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[conv_fwd_jvp_p] = conv_fwd_jvp_jvp_rule + + +# ============================================================================== +# 8. Backward JVP Primitive Definition +# ============================================================================== + +conv_bwd_jvp_p = core.Primitive("conv_bwd_jvp") +conv_bwd_jvp_p.multiple_results = True + + +def conv_bwd_jvp_impl( + X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash +): + kwargs = dict(kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + term_dZ = conv_bwd_p.bind(X, Y, W, tdZ, *args_meta, **kwargs) + term_X = conv_bwd_p.bind(tX, Y, W, dZ, *args_meta, **kwargs) + term_Y = conv_bwd_p.bind(X, tY, W, dZ, *args_meta, **kwargs) + term_W = conv_bwd_p.bind(X, Y, tW, dZ, *args_meta, **kwargs) + + out_dX = term_dZ[0] + term_Y[0] + term_W[0] + out_dY = term_dZ[1] + term_X[1] + term_W[1] + out_dW = term_dZ[2] + term_X[2] + term_Y[2] + + return out_dX, out_dY, out_dW + + +def conv_bwd_jvp_abstract_eval( + X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash +): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + + +conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) +conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) + + +# ============================================================================== +# 9. 
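# The transpose that follows closes the second-order loop: conv_bwd_jvp_p is
# linear in the tangents (tX, tY, tW, tdZ), and its transpose is evaluated by the
# fused conv_double_backward kernel via conv_dbwd_p. Undefined primals are
# replaced with explicit zeros first because the FFI kernel expects a concrete
# buffer for every operand.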
Transpose Rule for Backward JVP +# ============================================================================== + + +def conv_bwd_jvp_transpose( + ct, + X, + Y, + W, + dZ, + tX, + tY, + tW, + tdZ, + rows, + cols, + workspace, + sender_perm, + *, + kernel, + hash, +): + ddX, ddY, ddW = ct + + if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): + dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + + tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) + + g_X, g_Y, g_W, g_dZ = conv_dbwd_p.bind( + *tensors_clean, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + return (None, None, None, None, g_X, g_Y, g_W, g_dZ, None, None, None, None) + + +ad.primitive_transposes[conv_bwd_jvp_p] = conv_bwd_jvp_transpose + + +# ============================================================================== +# 10. JVP Rule for Backward JVP Primitive (Higher Order) +# ============================================================================== + + +def conv_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): + tangents_clean = tuple(clean_tensors(*tangents)) + + def func(x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp): + return conv_bwd_jvp_impl( + x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp, kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[conv_bwd_jvp_p] = conv_bwd_jvp_jvp_rule + + +# ============================================================================== +# 11. JVP Rule for Original Backward Primitive +# ============================================================================== + + +def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): + X, Y, W, dZ, rows, cols, workspace, sender_perm = primals + tX, tY, tW, tdZ, drows, dcols, dworkspace, dsender_perm = tangents + + tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) + + out_primal = conv_bwd_p.bind( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + out_tangent = conv_bwd_jvp_p.bind( + X, + Y, + W, + dZ, + tX, + tY, + tW, + tdZ, + rows, + cols, + workspace, + sender_perm, + kernel=kernel, + hash=hash, + ) + + return out_primal, out_tangent + + +ad.primitive_jvps[conv_bwd_p] = conv_bwd_jvp_rule + + +# ============================================================================== +# 12. 
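# The reference composition that follows expresses the double-backward outputs
# purely in terms of conv_fwd_p and conv_bwd_p. It is not used on the hot path
# (the fused conv_dbwd_p kernel is); it exists so that the JVP and transpose
# rules for conv_dbwd_p in sections 13 and 14 can be derived with jax.jvp and
# jax.vjp, which is what makes derivatives beyond second order possible.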
Slow Double Backward Implementation (Reference) +# ============================================================================== + + +def conv_dbwd_slow( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + *, + L3_dim, + kernel, + hash, +): + kwargs = dict(kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + op1 = conv_bwd_p.bind(ddX, ddY, W, dZ, *args_meta, **kwargs) + op2 = conv_bwd_p.bind(X, Y, ddW, dZ, *args_meta, **kwargs) + + op3 = conv_fwd_p.bind(ddX, Y, W, *args_meta, L3_dim=L3_dim, **kwargs) + op4 = conv_bwd_p.bind(ddX, Y, W, dZ, *args_meta, **kwargs) + op5 = conv_bwd_p.bind(X, ddY, W, dZ, *args_meta, **kwargs) + + op6 = conv_fwd_p.bind(X, ddY, W, *args_meta, L3_dim=L3_dim, **kwargs) + op7 = conv_fwd_p.bind(X, Y, ddW, *args_meta, L3_dim=L3_dim, **kwargs) + + grad_X = op1[0] + op2[0] + grad_Y = op1[1] + op2[1] + grad_W = op4[2] + op5[2] + grad_dZ = op3 + op6 + op7 + + return grad_X, grad_Y, grad_W, grad_dZ + + +# ============================================================================== +# 13. JVP rule for double backward (implicit) +# ============================================================================== + + +def conv_dbwd_jvp_rule(primals, tangents, *, kernel, hash): + dZ = primals[3] # Infer L3_dim from dZ (4th input) + L3_dim = dZ.shape[1] + + def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): + return conv_dbwd_slow( + x, + y, + w, + dz, + ddx, + ddy, + ddw, + r, + c, + ws, + sp, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) + + tangents_clean = tuple(clean_tensors(*tangents)) + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[conv_dbwd_p] = conv_dbwd_jvp_rule + + +# ============================================================================== +# 14. Transpose rule for double backward +# ============================================================================== + + +def conv_dbwd_transpose( + ct, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash +): + L3_dim = dZ.shape[1] + + X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) + + def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): + return conv_dbwd_slow( + x, + y, + w, + dz, + ddx, + ddy, + ddw, + r, + c, + ws, + sp, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) + + _, vjp_fun = jax.vjp( + func, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm + ) + input_grads = vjp_fun(ct) + + return input_grads + + +ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py new file mode 100644 index 00000000..c31c3ec0 --- /dev/null +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -0,0 +1,364 @@ +import jax +import jax.numpy as jnp +from jax.extend import core +from jax.interpreters import mlir, ad +from openequivariance.jax.utils import clean_tensors + +# ============================================================================== +# 1. 
Forward Primitive +# ============================================================================== + +tp_fwd_p = core.Primitive("tp_fwd") + + +def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): + irrep_dtype = X.dtype + out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + call = jax.ffi.ffi_call("tp_forward", out_shape) + return call(X, Y, W, kernel=kernel, hash=hash) + + +def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + + +tp_fwd_p.def_impl(tp_fwd_impl) +tp_fwd_p.def_abstract_eval(tp_fwd_abstract_eval) +mlir.register_lowering( + tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda" +) +mlir.register_lowering( + tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm" +) + + +# ============================================================================== +# 2. Backward Primitive +# ============================================================================== + +tp_bwd_p = core.Primitive("tp_bwd") +tp_bwd_p.multiple_results = True + + +def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("tp_backward", out_shapes) + return call(X, Y, W, dZ, kernel=kernel, hash=hash) + + +def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + + +tp_bwd_p.def_impl(tp_bwd_impl) +tp_bwd_p.def_abstract_eval(tp_bwd_abstract_eval) +mlir.register_lowering( + tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm" +) + + +# ============================================================================== +# 3. Double Backward Primitive +# ============================================================================== + +tp_dbwd_p = core.Primitive("tp_dbwd") +tp_dbwd_p.multiple_results = True + + +def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("tp_double_backward", out_shapes) + return call(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + + +def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(dZ.shape, irrep_dtype), + ) + + +tp_dbwd_p.def_impl(tp_dbwd_impl) +tp_dbwd_p.def_abstract_eval(tp_dbwd_abstract_eval) +mlir.register_lowering( + tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm" +) + + +# ============================================================================== +# 4. 
Forward JVP Primitive Definition +# ============================================================================== + +tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") + + +def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) + + term1 = tp_fwd_p.bind(dX, Y, W, **kwargs) + term2 = tp_fwd_p.bind(X, dY, W, **kwargs) + term3 = tp_fwd_p.bind(X, Y, dW, **kwargs) + return term1 + term2 + term3 + + +def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + + +tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) +tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) + + +# ============================================================================== +# 5. Transpose Rule (Implicit VJP) +# ============================================================================== + + +def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + X, Y, W = clean_tensors(X, Y, W) + + grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) + + return (None, None, None, grad_X, grad_Y, grad_W) + + +ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose + + +# ============================================================================== +# 6. JVP Rule for Original Forward Primitive +# ============================================================================== + + +def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + X, Y, W = primals + dX, dY, dW = tangents + + dX, dY, dW = clean_tensors(dX, dY, dW) + + out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_tangent = tp_fwd_jvp_p.bind( + X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + return out_primal, out_tangent + + +ad.primitive_jvps[tp_fwd_p] = tp_fwd_jvp_rule + + +# ============================================================================== +# 7. JVP Rule for Forward JVP Primitive (Higher Order) +# ============================================================================== + + +def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = tuple(clean_tensors(*tangents)) + + def func(x, y, w, dx, dy, dw): + return tp_fwd_jvp_impl( + x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[tp_fwd_jvp_p] = tp_fwd_jvp_jvp_rule + + +# ============================================================================== +# 8. 
Backward JVP Primitive Definition +# ============================================================================== + +tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") +tp_bwd_jvp_p.multiple_results = True + + +def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + kwargs = dict(kernel=kernel, hash=hash) + + term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, **kwargs) + term_X = tp_bwd_p.bind(tX, Y, W, dZ, **kwargs) + term_Y = tp_bwd_p.bind(X, tY, W, dZ, **kwargs) + term_W = tp_bwd_p.bind(X, Y, tW, dZ, **kwargs) + + out_dX = term_dZ[0] + term_Y[0] + term_W[0] + out_dY = term_dZ[1] + term_X[1] + term_W[1] + out_dW = term_dZ[2] + term_X[2] + term_Y[2] + + return out_dX, out_dY, out_dW + + +def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + + +tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) +tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) + + +# ============================================================================== +# 9. Transpose Rule for Backward JVP +# ============================================================================== + + +def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + ddX, ddY, ddW = ct + + if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): + dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + + tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) + + g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(*tensors_clean, kernel=kernel, hash=hash) + + return (None, None, None, None, g_X, g_Y, g_W, g_dZ) + + +ad.primitive_transposes[tp_bwd_jvp_p] = tp_bwd_jvp_transpose + + +# ============================================================================== +# 10. JVP Rule for Backward JVP Primitive (Higher Order) +# ============================================================================== + + +def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): + tangents_clean = tuple(clean_tensors(*tangents)) + + def func(x, y, w, dz, tx, ty, tw, tdz): + return tp_bwd_jvp_impl(x, y, w, dz, tx, ty, tw, tdz, kernel=kernel, hash=hash) + + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[tp_bwd_jvp_p] = tp_bwd_jvp_jvp_rule + + +# ============================================================================== +# 11. JVP Rule for Original Backward Primitive +# ============================================================================== + + +def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): + X, Y, W, dZ = primals + tX, tY, tW, tdZ = tangents + + tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) + + out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) + out_tangent = tp_bwd_jvp_p.bind( + X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash + ) + + return out_primal, out_tangent + + +ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule + + +# ============================================================================== +# 12. 
Slow Double Backward Implementation (Reference) +# ============================================================================== + + +def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): + kwargs = dict(kernel=kernel, hash=hash) + + op1 = tp_bwd_p.bind(ddX, ddY, W, dZ, **kwargs) + op2 = tp_bwd_p.bind(X, Y, ddW, dZ, **kwargs) + + op3 = tp_fwd_p.bind(ddX, Y, W, L3_dim=L3_dim, **kwargs) + op4 = tp_bwd_p.bind(ddX, Y, W, dZ, **kwargs) + op5 = tp_bwd_p.bind(X, ddY, W, dZ, **kwargs) + + op6 = tp_fwd_p.bind(X, ddY, W, L3_dim=L3_dim, **kwargs) + op7 = tp_fwd_p.bind(X, Y, ddW, L3_dim=L3_dim, **kwargs) + + grad_X = op1[0] + op2[0] + grad_Y = op1[1] + op2[1] + grad_W = op4[2] + op5[2] + grad_dZ = op3 + op6 + op7 + + return grad_X, grad_Y, grad_W, grad_dZ + + +# ============================================================================== +# 13. JVP rule for double backward (implicit) +# ============================================================================== + + +def tp_dbwd_jvp_rule(primals, tangents, *, kernel, hash): + dZ = primals[3] # Infer L3_dim from dZ (4th input) + L3_dim = dZ.shape[1] + + def func(x, y, w, dz, ddx, ddy, ddw): + return tp_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + tangents_clean = tuple(clean_tensors(*tangents)) + return jax.jvp(func, primals, tangents_clean) + + +ad.primitive_jvps[tp_dbwd_p] = tp_dbwd_jvp_rule + + +# ============================================================================== +# 14. Transpose rule for double backward +# ============================================================================== + + +def tp_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + L3_dim = dZ.shape[1] + + X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) + + def func(x, y, w, dz, ddx, ddy, ddw): + return tp_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + _, vjp_fun = jax.vjp(func, X, Y, W, dZ, ddX, ddY, ddW) + input_grads = vjp_fun(ct) + + return input_grads + + +ad.primitive_transposes[tp_dbwd_p] = tp_dbwd_transpose diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py index ae15d1a6..371b0ae5 100644 --- a/openequivariance/openequivariance/jax/utils.py +++ b/openequivariance/openequivariance/jax/utils.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp import numpy as np +from jax.interpreters import ad def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): @@ -61,3 +62,13 @@ def reorder_jax(schedule, weights_in, direction, has_batch_dim): return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim) else: return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) + + +def clean_tensors(*tensors): + tensors_clean = [] + for t in tensors: + result = t + if type(t) is ad.Zero or ad.is_undefined_primal(t): + result = jnp.zeros(t.aval.shape, t.aval.dtype) + tensors_clean.append(result) + return tensors_clean diff --git a/openequivariance/openequivariance/jax/vjp/conv_func.py b/openequivariance/openequivariance/jax/vjp/conv_func.py new file mode 100644 index 00000000..20d7ccc2 --- /dev/null +++ b/openequivariance/openequivariance/jax/vjp/conv_func.py @@ -0,0 +1,162 @@ +import jax +import jax.numpy as jnp +from functools import partial + + +def zeros_like(x): + return jnp.zeros_like(x) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash): + 
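    # Forward pass of the fused tensor-product convolution: dispatch to the
+    # "conv_forward" XLA FFI handler. The output buffer has shape
+    # (X.shape[0], L3_dim) and the same dtype as X; workspace, sender_perm,
+    # L3_dim, kernel, and hash are the non-differentiable arguments of this
+    # custom_vjp function.
+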
forward_call = jax.ffi.ffi_call( + "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), X.dtype) + ) + return forward_call( + X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + +def forward_fwd(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash): + out = forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash) + return out, (X, Y, W, rows, cols) + + +def forward_bwd(workspace, sender_perm, L3_dim, kernel, hash, res, dZ): + X, Y, W, rows, cols = res + dX, dY, dW = backward( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + return dX, dY, dW, None, None + + +forward.defvjp(forward_fwd, forward_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) +def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): + backward_call = jax.ffi.ffi_call( + "conv_backward", + ( + jax.ShapeDtypeStruct(X.shape, X.dtype), + jax.ShapeDtypeStruct(Y.shape, Y.dtype), + jax.ShapeDtypeStruct(W.shape, W.dtype), + ), + ) + return backward_call( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + +def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): + out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash) + return out, (X, Y, W, dZ, rows, cols) + + +def backward_bwd(workspace, sender_perm, kernel, hash, res, derivatives): + X, Y, W, dZ, rows, cols = res + ddX, ddY, ddW = derivatives + + gX, gY, gW, gdZ = double_backward( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel, + hash, + ) + + return gX, gY, gW, gdZ, None, None + + +backward.defvjp(backward_fwd, backward_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) +def double_backward( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash +): + double_backward_call = jax.ffi.ffi_call( + "conv_double_backward", + ( + jax.ShapeDtypeStruct(X.shape, X.dtype), + jax.ShapeDtypeStruct(Y.shape, Y.dtype), + jax.ShapeDtypeStruct(W.shape, W.dtype), + jax.ShapeDtypeStruct(dZ.shape, dZ.dtype), + ), + ) + return double_backward_call( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel=kernel, + hash=hash, + ) + + +def double_backward_fwd( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash +): + out = double_backward( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash + ) + return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) + + +def triple_backward( + workspace, + sender_perm, + kernel, + hash, + residuals, + tangent_outputs, +): + X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals + t_dX, t_dY, t_dW, t_ddZ = tangent_outputs + + common_args = (rows, cols, workspace, sender_perm, kernel, hash) + + op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) + g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args) + + op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) + g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args) + + op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) + g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args) + + op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) + g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args) + + g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args) + g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args) + g7_X, g7_Y, g7_ddW = 
backward(X, Y, ddW, t_ddZ, *common_args) + + grad_X = g2_X + g4_X + g6_X + g7_X + grad_Y = g2_Y + g3_Y + g5_Y + g7_Y + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ + + grad_ddX = g1_ddX + g3_ddX + g5_ddX + grad_ddY = g1_ddY + g4_ddY + g6_ddY + grad_ddW = g2_ddW + g7_ddW + + return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None + + +double_backward.defvjp(double_backward_fwd, triple_backward) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 25fec285..91617d94 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -52,6 +52,7 @@ endif() set(HEADER_DIR "src/extension") set(OEQ_JAX_SOURCES src/libjax_tp_jit.cpp + ${HEADER_DIR}/json11/json11.cpp ) set(OEQ_JAX_HEADERS @@ -60,6 +61,7 @@ set(OEQ_JAX_HEADERS ${HEADER_DIR}/util/backend_cuda.hpp ${HEADER_DIR}/util/backend_hip.hpp ${HEADER_DIR}/util/buffer.hpp + ${HEADER_DIR}/json11/json11.hpp ) nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index c9b2856f..def67e12 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.1.0" +version = "0.2.0" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index ae2035e8..bf3fd69e 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -8,9 +8,11 @@ #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" +#include "json11/json11.hpp" namespace nb = nanobind; namespace ffi = xla::ffi; +using json = json11::Json; #ifdef CUDA_BACKEND #include @@ -131,6 +133,14 @@ void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { } #endif +std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -175,39 +185,8 @@ std::unordered_map> conv_cache; std::mutex mut; -std::vector launch_config_keys = { - "num_blocks", - "num_threads", - "smem"}; -std::vector kernel_prop_keys = { - "L1_dim", - "L2_dim", - "L3_dim", - "weight_numel", - "shared_weights", - "opt_level", - "irrep_dtype", - "weight_dtype", - - // Convolution only - "workspace_size", - "deterministic", - "idx_dtype"}; - -std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { - std::unordered_map result; - for (const auto &key : keys) { - result[key] = dict.get(key).value(); - } - return result; -} - std::pair*, KernelProp> - compile_tp_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_tp_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -215,12 +194,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = tp_cache.find(hash); if (it == tp_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; + json root = 
json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_tp_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); tp_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), @@ -232,11 +220,7 @@ std::pair*, KernelProp> } std::pair*, KernelProp> - compile_conv_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_conv_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -244,12 +228,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = conv_cache.find(hash); if (it == conv_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; + json root = json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_conv_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); conv_cache.insert({hash, std::make_pair(std::move(jit_conv_impl), @@ -300,12 +293,12 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer L2_in, ffi::AnyBuffer W, ffi::Result L3_out, - stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + stream_t stream, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -336,11 +329,11 @@ ffi::Error tp_backward_impl( ffi::Result L2_grad, ffi::Result W_grad, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, 
k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -386,11 +379,11 @@ ffi::Error tp_double_backward_impl( ffi::Result W_grad, ffi::Result L3_dgrad, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -435,7 +428,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled @@ -450,7 +443,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -469,7 +462,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -484,11 +477,11 @@ ffi::Error conv_forward_impl( ffi::AnyBuffer transpose_perm, ffi::Result L3_out, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -539,11 +532,11 @@ ffi::Error conv_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -562,6 +555,8 @@ ffi::Error conv_backward_impl( workspace_ptr = nullptr; } zero_buffer(*L1_grad, stream); + zero_buffer(*L2_grad, stream); + zero_buffer(*W_grad, stream); if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -571,8 +566,6 @@ ffi::Error conv_backward_impl( check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad"); } - if(k.shared_weights) - zero_buffer(*W_grad, stream); jit_kernel->backward( data_ptr(L1_in), @@ -608,11 +601,11 @@ ffi::Error conv_double_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary 
backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -633,8 +626,9 @@ ffi::Error conv_double_backward_impl( workspace_ptr = nullptr; } zero_buffer(*L1_grad, stream); + zero_buffer(*L2_grad, stream); + zero_buffer(*W_grad, stream); zero_buffer(*L3_dgrad, stream); - if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -643,8 +637,6 @@ ffi::Error conv_double_backward_impl( check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); } - if(k.shared_weights) - zero_buffer(*W_grad, stream); jit_kernel->double_backward( data_ptr(L1_in), @@ -689,7 +681,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -708,7 +700,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -731,7 +723,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -764,10 +756,4 @@ NB_MODULE(openequivariance_extjax, m) { .def("start", &GPUTimer::start) .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) .def("clear_L2_cache", &GPUTimer::clear_L2_cache); - - /*nb::class_>(m, "DeviceBuffer") - .def(nb::init()) - .def(nb::init()) - .def("copy_to_host", &PyDeviceBuffer::copy_to_host) - .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ -} +} \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 0e7098e0..323de863 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ os.environ["JAX_ENABLE_X64"] = "True" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" +os.environ["JAX_TRACEBACK_FILTERING"] = "off" def pytest_addoption(parser):
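The four-term accumulation in `tp_bwd_jvp_impl` and `conv_bwd_jvp_impl` above is the product rule applied to the backward (VJP) map of a trilinear operation. The sketch below is illustrative only (a plain einsum stands in for the generated CUDA/HIP kernel, and the shapes and helper names are invented for the example rather than taken from the library); it checks that decomposition against `jax.jvp`.

import jax
import jax.numpy as jnp

# Toy trilinear map standing in for the generated tensor-product kernel:
# Z[b, k] = sum_{i, j} X[b, i] * Y[b, j] * W[b, i, j, k].
def f(X, Y, W):
    return jnp.einsum("bi,bj,bijk->bk", X, Y, W)

def bwd(X, Y, W, dZ):
    # Reference backward pass: VJP of f at (X, Y, W) applied to dZ.
    _, vjp_fn = jax.vjp(f, X, Y, W)
    return vjp_fn(dZ)

key = jax.random.PRNGKey(0)
ks = jax.random.split(key, 8)
X, tX = (jax.random.normal(k, (4, 3)) for k in ks[0:2])
Y, tY = (jax.random.normal(k, (4, 5)) for k in ks[2:4])
W, tW = (jax.random.normal(k, (4, 3, 5, 2)) for k in ks[4:6])
dZ, tdZ = (jax.random.normal(k, (4, 2)) for k in ks[6:8])

# Directional derivative of the backward map, computed by JAX...
_, auto = jax.jvp(bwd, (X, Y, W, dZ), (tX, tY, tW, tdZ))

# ...and by the product rule used in tp_bwd_jvp_impl / conv_bwd_jvp_impl:
# each term re-runs the backward pass with one argument replaced by its tangent.
term_dZ = bwd(X, Y, W, tdZ)
term_X = bwd(tX, Y, W, dZ)
term_Y = bwd(X, tY, W, dZ)
term_W = bwd(X, Y, tW, dZ)
manual = (
    term_dZ[0] + term_Y[0] + term_W[0],  # d(dX): no tX term, since dX does not depend on X
    term_dZ[1] + term_X[1] + term_W[1],  # d(dY)
    term_dZ[2] + term_X[2] + term_Y[2],  # d(dW)
)

for a, m in zip(auto, manual):
    assert jnp.allclose(a, m, atol=1e-5)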