From 3f5294ade3ba6066e93b139c7ce8cd2cb15e08d8 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 24 Jan 2026 23:47:56 -0800 Subject: [PATCH 01/18] Added a JSN parser. --- .../extension/json11/json11.cpp | 790 ++++++++++++++++++ .../extension/json11/json11.hpp | 232 +++++ openequivariance_extjax/CMakeLists.txt | 2 + openequivariance_extjax/src/libjax_tp_jit.cpp | 131 ++- 4 files changed, 1083 insertions(+), 72 deletions(-) create mode 100644 openequivariance/openequivariance/extension/json11/json11.cpp create mode 100644 openequivariance/openequivariance/extension/json11/json11.hpp diff --git a/openequivariance/openequivariance/extension/json11/json11.cpp b/openequivariance/openequivariance/extension/json11/json11.cpp new file mode 100644 index 00000000..f3140252 --- /dev/null +++ b/openequivariance/openequivariance/extension/json11/json11.cpp @@ -0,0 +1,790 @@ +/* Copyright (c) 2013 Dropbox, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "json11.hpp" +#include +#include +#include +#include +#include + +namespace json11 { + +static const int max_depth = 200; + +using std::string; +using std::vector; +using std::map; +using std::make_shared; +using std::initializer_list; +using std::move; + +/* Helper for representing null - just a do-nothing struct, plus comparison + * operators so the helpers in JsonValue work. We can't use nullptr_t because + * it may not be orderable. + */ +struct NullStruct { + bool operator==(NullStruct) const { return true; } + bool operator<(NullStruct) const { return false; } +}; + +/* * * * * * * * * * * * * * * * * * * * + * Serialization + */ + +static void dump(NullStruct, string &out) { + out += "null"; +} + +static void dump(double value, string &out) { + if (std::isfinite(value)) { + char buf[32]; + snprintf(buf, sizeof buf, "%.17g", value); + out += buf; + } else { + out += "null"; + } +} + +static void dump(int value, string &out) { + char buf[32]; + snprintf(buf, sizeof buf, "%d", value); + out += buf; +} + +static void dump(bool value, string &out) { + out += value ? 
"true" : "false"; +} + +static void dump(const string &value, string &out) { + out += '"'; + for (size_t i = 0; i < value.length(); i++) { + const char ch = value[i]; + if (ch == '\\') { + out += "\\\\"; + } else if (ch == '"') { + out += "\\\""; + } else if (ch == '\b') { + out += "\\b"; + } else if (ch == '\f') { + out += "\\f"; + } else if (ch == '\n') { + out += "\\n"; + } else if (ch == '\r') { + out += "\\r"; + } else if (ch == '\t') { + out += "\\t"; + } else if (static_cast(ch) <= 0x1f) { + char buf[8]; + snprintf(buf, sizeof buf, "\\u%04x", ch); + out += buf; + } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 + && static_cast(value[i+2]) == 0xa8) { + out += "\\u2028"; + i += 2; + } else if (static_cast(ch) == 0xe2 && static_cast(value[i+1]) == 0x80 + && static_cast(value[i+2]) == 0xa9) { + out += "\\u2029"; + i += 2; + } else { + out += ch; + } + } + out += '"'; +} + +static void dump(const Json::array &values, string &out) { + bool first = true; + out += "["; + for (const auto &value : values) { + if (!first) + out += ", "; + value.dump(out); + first = false; + } + out += "]"; +} + +static void dump(const Json::object &values, string &out) { + bool first = true; + out += "{"; + for (const auto &kv : values) { + if (!first) + out += ", "; + dump(kv.first, out); + out += ": "; + kv.second.dump(out); + first = false; + } + out += "}"; +} + +void Json::dump(string &out) const { + m_ptr->dump(out); +} + +/* * * * * * * * * * * * * * * * * * * * + * Value wrappers + */ + +template +class Value : public JsonValue { +protected: + + // Constructors + explicit Value(const T &value) : m_value(value) {} + explicit Value(T &&value) : m_value(move(value)) {} + + // Get type tag + Json::Type type() const override { + return tag; + } + + // Comparisons + bool equals(const JsonValue * other) const override { + return m_value == static_cast *>(other)->m_value; + } + bool less(const JsonValue * other) const override { + return m_value < static_cast *>(other)->m_value; + } + + const T m_value; + void dump(string &out) const override { json11::dump(m_value, out); } +}; + +class JsonDouble final : public Value { + double number_value() const override { return m_value; } + int int_value() const override { return static_cast(m_value); } + bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } + bool less(const JsonValue * other) const override { return m_value < other->number_value(); } +public: + explicit JsonDouble(double value) : Value(value) {} +}; + +class JsonInt final : public Value { + double number_value() const override { return m_value; } + int int_value() const override { return m_value; } + bool equals(const JsonValue * other) const override { return m_value == other->number_value(); } + bool less(const JsonValue * other) const override { return m_value < other->number_value(); } +public: + explicit JsonInt(int value) : Value(value) {} +}; + +class JsonBoolean final : public Value { + bool bool_value() const override { return m_value; } +public: + explicit JsonBoolean(bool value) : Value(value) {} +}; + +class JsonString final : public Value { + const string &string_value() const override { return m_value; } +public: + explicit JsonString(const string &value) : Value(value) {} + explicit JsonString(string &&value) : Value(move(value)) {} +}; + +class JsonArray final : public Value { + const Json::array &array_items() const override { return m_value; } + const Json & operator[](size_t i) const override; +public: + explicit 
JsonArray(const Json::array &value) : Value(value) {} + explicit JsonArray(Json::array &&value) : Value(move(value)) {} +}; + +class JsonObject final : public Value { + const Json::object &object_items() const override { return m_value; } + const Json & operator[](const string &key) const override; +public: + explicit JsonObject(const Json::object &value) : Value(value) {} + explicit JsonObject(Json::object &&value) : Value(move(value)) {} +}; + +class JsonNull final : public Value { +public: + JsonNull() : Value({}) {} +}; + +/* * * * * * * * * * * * * * * * * * * * + * Static globals - static-init-safe + */ +struct Statics { + const std::shared_ptr null = make_shared(); + const std::shared_ptr t = make_shared(true); + const std::shared_ptr f = make_shared(false); + const string empty_string; + const vector empty_vector; + const map empty_map; + Statics() {} +}; + +static const Statics & statics() { + static const Statics s {}; + return s; +} + +static const Json & static_null() { + // This has to be separate, not in Statics, because Json() accesses statics().null. + static const Json json_null; + return json_null; +} + +/* * * * * * * * * * * * * * * * * * * * + * Constructors + */ + +Json::Json() noexcept : m_ptr(statics().null) {} +Json::Json(std::nullptr_t) noexcept : m_ptr(statics().null) {} +Json::Json(double value) : m_ptr(make_shared(value)) {} +Json::Json(int value) : m_ptr(make_shared(value)) {} +Json::Json(bool value) : m_ptr(value ? statics().t : statics().f) {} +Json::Json(const string &value) : m_ptr(make_shared(value)) {} +Json::Json(string &&value) : m_ptr(make_shared(move(value))) {} +Json::Json(const char * value) : m_ptr(make_shared(value)) {} +Json::Json(const Json::array &values) : m_ptr(make_shared(values)) {} +Json::Json(Json::array &&values) : m_ptr(make_shared(move(values))) {} +Json::Json(const Json::object &values) : m_ptr(make_shared(values)) {} +Json::Json(Json::object &&values) : m_ptr(make_shared(move(values))) {} + +/* * * * * * * * * * * * * * * * * * * * + * Accessors + */ + +Json::Type Json::type() const { return m_ptr->type(); } +double Json::number_value() const { return m_ptr->number_value(); } +int Json::int_value() const { return m_ptr->int_value(); } +bool Json::bool_value() const { return m_ptr->bool_value(); } +const string & Json::string_value() const { return m_ptr->string_value(); } +const vector & Json::array_items() const { return m_ptr->array_items(); } +const map & Json::object_items() const { return m_ptr->object_items(); } +const Json & Json::operator[] (size_t i) const { return (*m_ptr)[i]; } +const Json & Json::operator[] (const string &key) const { return (*m_ptr)[key]; } + +double JsonValue::number_value() const { return 0; } +int JsonValue::int_value() const { return 0; } +bool JsonValue::bool_value() const { return false; } +const string & JsonValue::string_value() const { return statics().empty_string; } +const vector & JsonValue::array_items() const { return statics().empty_vector; } +const map & JsonValue::object_items() const { return statics().empty_map; } +const Json & JsonValue::operator[] (size_t) const { return static_null(); } +const Json & JsonValue::operator[] (const string &) const { return static_null(); } + +const Json & JsonObject::operator[] (const string &key) const { + auto iter = m_value.find(key); + return (iter == m_value.end()) ? 
static_null() : iter->second; +} +const Json & JsonArray::operator[] (size_t i) const { + if (i >= m_value.size()) return static_null(); + else return m_value[i]; +} + +/* * * * * * * * * * * * * * * * * * * * + * Comparison + */ + +bool Json::operator== (const Json &other) const { + if (m_ptr == other.m_ptr) + return true; + if (m_ptr->type() != other.m_ptr->type()) + return false; + + return m_ptr->equals(other.m_ptr.get()); +} + +bool Json::operator< (const Json &other) const { + if (m_ptr == other.m_ptr) + return false; + if (m_ptr->type() != other.m_ptr->type()) + return m_ptr->type() < other.m_ptr->type(); + + return m_ptr->less(other.m_ptr.get()); +} + +/* * * * * * * * * * * * * * * * * * * * + * Parsing + */ + +/* esc(c) + * + * Format char c suitable for printing in an error message. + */ +static inline string esc(char c) { + char buf[12]; + if (static_cast(c) >= 0x20 && static_cast(c) <= 0x7f) { + snprintf(buf, sizeof buf, "'%c' (%d)", c, c); + } else { + snprintf(buf, sizeof buf, "(%d)", c); + } + return string(buf); +} + +static inline bool in_range(long x, long lower, long upper) { + return (x >= lower && x <= upper); +} + +namespace { +/* JsonParser + * + * Object that tracks all state of an in-progress parse. + */ +struct JsonParser final { + + /* State + */ + const string &str; + size_t i; + string &err; + bool failed; + const JsonParse strategy; + + /* fail(msg, err_ret = Json()) + * + * Mark this parse as failed. + */ + Json fail(string &&msg) { + return fail(move(msg), Json()); + } + + template + T fail(string &&msg, const T err_ret) { + if (!failed) + err = std::move(msg); + failed = true; + return err_ret; + } + + /* consume_whitespace() + * + * Advance until the current character is non-whitespace. + */ + void consume_whitespace() { + while (str[i] == ' ' || str[i] == '\r' || str[i] == '\n' || str[i] == '\t') + i++; + } + + /* consume_comment() + * + * Advance comments (c-style inline and multiline). + */ + bool consume_comment() { + bool comment_found = false; + if (str[i] == '/') { + i++; + if (i == str.size()) + return fail("unexpected end of input after start of comment", false); + if (str[i] == '/') { // inline comment + i++; + // advance until next line, or end of input + while (i < str.size() && str[i] != '\n') { + i++; + } + comment_found = true; + } + else if (str[i] == '*') { // multiline comment + i++; + if (i > str.size()-2) + return fail("unexpected end of input inside multi-line comment", false); + // advance until closing tokens + while (!(str[i] == '*' && str[i+1] == '/')) { + i++; + if (i > str.size()-2) + return fail( + "unexpected end of input inside multi-line comment", false); + } + i += 2; + comment_found = true; + } + else + return fail("malformed comment", false); + } + return comment_found; + } + + /* consume_garbage() + * + * Advance until the current character is non-whitespace and non-comment. + */ + void consume_garbage() { + consume_whitespace(); + if(strategy == JsonParse::COMMENTS) { + bool comment_found = false; + do { + comment_found = consume_comment(); + if (failed) return; + consume_whitespace(); + } + while(comment_found); + } + } + + /* get_next_token() + * + * Return the next non-whitespace character. If the end of the input is reached, + * flag an error and return 0. 
+ */ + char get_next_token() { + consume_garbage(); + if (failed) return static_cast(0); + if (i == str.size()) + return fail("unexpected end of input", static_cast(0)); + + return str[i++]; + } + + /* encode_utf8(pt, out) + * + * Encode pt as UTF-8 and add it to out. + */ + void encode_utf8(long pt, string & out) { + if (pt < 0) + return; + + if (pt < 0x80) { + out += static_cast(pt); + } else if (pt < 0x800) { + out += static_cast((pt >> 6) | 0xC0); + out += static_cast((pt & 0x3F) | 0x80); + } else if (pt < 0x10000) { + out += static_cast((pt >> 12) | 0xE0); + out += static_cast(((pt >> 6) & 0x3F) | 0x80); + out += static_cast((pt & 0x3F) | 0x80); + } else { + out += static_cast((pt >> 18) | 0xF0); + out += static_cast(((pt >> 12) & 0x3F) | 0x80); + out += static_cast(((pt >> 6) & 0x3F) | 0x80); + out += static_cast((pt & 0x3F) | 0x80); + } + } + + /* parse_string() + * + * Parse a string, starting at the current position. + */ + string parse_string() { + string out; + long last_escaped_codepoint = -1; + while (true) { + if (i == str.size()) + return fail("unexpected end of input in string", ""); + + char ch = str[i++]; + + if (ch == '"') { + encode_utf8(last_escaped_codepoint, out); + return out; + } + + if (in_range(ch, 0, 0x1f)) + return fail("unescaped " + esc(ch) + " in string", ""); + + // The usual case: non-escaped characters + if (ch != '\\') { + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = -1; + out += ch; + continue; + } + + // Handle escapes + if (i == str.size()) + return fail("unexpected end of input in string", ""); + + ch = str[i++]; + + if (ch == 'u') { + // Extract 4-byte escape sequence + string esc = str.substr(i, 4); + // Explicitly check length of the substring. The following loop + // relies on std::string returning the terminating NUL when + // accessing str[length]. Checking here reduces brittleness. + if (esc.length() < 4) { + return fail("bad \\u escape: " + esc, ""); + } + for (size_t j = 0; j < 4; j++) { + if (!in_range(esc[j], 'a', 'f') && !in_range(esc[j], 'A', 'F') + && !in_range(esc[j], '0', '9')) + return fail("bad \\u escape: " + esc, ""); + } + + long codepoint = strtol(esc.data(), nullptr, 16); + + // JSON specifies that characters outside the BMP shall be encoded as a pair + // of 4-hex-digit \u escapes encoding their surrogate pair components. Check + // whether we're in the middle of such a beast: the previous codepoint was an + // escaped lead (high) surrogate, and this is a trail (low) surrogate. + if (in_range(last_escaped_codepoint, 0xD800, 0xDBFF) + && in_range(codepoint, 0xDC00, 0xDFFF)) { + // Reassemble the two surrogate pairs into one astral-plane character, per + // the UTF-16 algorithm. + encode_utf8((((last_escaped_codepoint - 0xD800) << 10) + | (codepoint - 0xDC00)) + 0x10000, out); + last_escaped_codepoint = -1; + } else { + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = codepoint; + } + + i += 4; + continue; + } + + encode_utf8(last_escaped_codepoint, out); + last_escaped_codepoint = -1; + + if (ch == 'b') { + out += '\b'; + } else if (ch == 'f') { + out += '\f'; + } else if (ch == 'n') { + out += '\n'; + } else if (ch == 'r') { + out += '\r'; + } else if (ch == 't') { + out += '\t'; + } else if (ch == '"' || ch == '\\' || ch == '/') { + out += ch; + } else { + return fail("invalid escape character " + esc(ch), ""); + } + } + } + + /* parse_number() + * + * Parse a double. 
+ */ + Json parse_number() { + size_t start_pos = i; + + if (str[i] == '-') + i++; + + // Integer part + if (str[i] == '0') { + i++; + if (in_range(str[i], '0', '9')) + return fail("leading 0s not permitted in numbers"); + } else if (in_range(str[i], '1', '9')) { + i++; + while (in_range(str[i], '0', '9')) + i++; + } else { + return fail("invalid " + esc(str[i]) + " in number"); + } + + if (str[i] != '.' && str[i] != 'e' && str[i] != 'E' + && (i - start_pos) <= static_cast(std::numeric_limits::digits10)) { + return std::atoi(str.c_str() + start_pos); + } + + // Decimal part + if (str[i] == '.') { + i++; + if (!in_range(str[i], '0', '9')) + return fail("at least one digit required in fractional part"); + + while (in_range(str[i], '0', '9')) + i++; + } + + // Exponent part + if (str[i] == 'e' || str[i] == 'E') { + i++; + + if (str[i] == '+' || str[i] == '-') + i++; + + if (!in_range(str[i], '0', '9')) + return fail("at least one digit required in exponent"); + + while (in_range(str[i], '0', '9')) + i++; + } + + return std::strtod(str.c_str() + start_pos, nullptr); + } + + /* expect(str, res) + * + * Expect that 'str' starts at the character that was just read. If it does, advance + * the input and return res. If not, flag an error. + */ + Json expect(const string &expected, Json res) { + assert(i != 0); + i--; + if (str.compare(i, expected.length(), expected) == 0) { + i += expected.length(); + return res; + } else { + return fail("parse error: expected " + expected + ", got " + str.substr(i, expected.length())); + } + } + + /* parse_json() + * + * Parse a JSON object. + */ + Json parse_json(int depth) { + if (depth > max_depth) { + return fail("exceeded maximum nesting depth"); + } + + char ch = get_next_token(); + if (failed) + return Json(); + + if (ch == '-' || (ch >= '0' && ch <= '9')) { + i--; + return parse_number(); + } + + if (ch == 't') + return expect("true", true); + + if (ch == 'f') + return expect("false", false); + + if (ch == 'n') + return expect("null", Json()); + + if (ch == '"') + return parse_string(); + + if (ch == '{') { + map data; + ch = get_next_token(); + if (ch == '}') + return data; + + while (1) { + if (ch != '"') + return fail("expected '\"' in object, got " + esc(ch)); + + string key = parse_string(); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch != ':') + return fail("expected ':' in object, got " + esc(ch)); + + data[std::move(key)] = parse_json(depth + 1); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch == '}') + break; + if (ch != ',') + return fail("expected ',' in object, got " + esc(ch)); + + ch = get_next_token(); + } + return data; + } + + if (ch == '[') { + vector data; + ch = get_next_token(); + if (ch == ']') + return data; + + while (1) { + i--; + data.push_back(parse_json(depth + 1)); + if (failed) + return Json(); + + ch = get_next_token(); + if (ch == ']') + break; + if (ch != ',') + return fail("expected ',' in list, got " + esc(ch)); + + ch = get_next_token(); + (void)ch; + } + return data; + } + + return fail("expected value, got " + esc(ch)); + } +}; +}//namespace { + +Json Json::parse(const string &in, string &err, JsonParse strategy) { + JsonParser parser { in, 0, err, false, strategy }; + Json result = parser.parse_json(0); + + // Check for any trailing garbage + parser.consume_garbage(); + if (parser.failed) + return Json(); + if (parser.i != in.size()) + return parser.fail("unexpected trailing " + esc(in[parser.i])); + + return result; +} + +// Documented in json11.hpp +vector 
Json::parse_multi(const string &in, + std::string::size_type &parser_stop_pos, + string &err, + JsonParse strategy) { + JsonParser parser { in, 0, err, false, strategy }; + parser_stop_pos = 0; + vector json_vec; + while (parser.i != in.size() && !parser.failed) { + json_vec.push_back(parser.parse_json(0)); + if (parser.failed) + break; + + // Check for another object + parser.consume_garbage(); + if (parser.failed) + break; + parser_stop_pos = parser.i; + } + return json_vec; +} + +/* * * * * * * * * * * * * * * * * * * * + * Shape-checking + */ + +bool Json::has_shape(const shape & types, string & err) const { + if (!is_object()) { + err = "expected JSON object, got " + dump(); + return false; + } + + const auto& obj_items = object_items(); + for (auto & item : types) { + const auto it = obj_items.find(item.first); + if (it == obj_items.cend() || it->second.type() != item.second) { + err = "bad type for " + item.first + " in " + dump(); + return false; + } + } + + return true; +} + +} // namespace json11 \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/json11/json11.hpp b/openequivariance/openequivariance/extension/json11/json11.hpp new file mode 100644 index 00000000..965388a7 --- /dev/null +++ b/openequivariance/openequivariance/extension/json11/json11.hpp @@ -0,0 +1,232 @@ +/* json11 + * + * json11 is a tiny JSON library for C++11, providing JSON parsing and serialization. + * + * The core object provided by the library is json11::Json. A Json object represents any JSON + * value: null, bool, number (int or double), string (std::string), array (std::vector), or + * object (std::map). + * + * Json objects act like values: they can be assigned, copied, moved, compared for equality or + * order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and + * Json::parse (static) to parse a std::string as a Json object. + * + * Internally, the various types of Json object are represented by the JsonValue class + * hierarchy. + * + * A note on numbers - JSON specifies the syntax of number formatting but not its semantics, + * so some JSON implementations distinguish between integers and floating-point numbers, while + * some don't. In json11, we choose the latter. Because some JSON implementations (namely + * Javascript itself) treat all numbers as the same type, distinguishing the two leads + * to JSON that will be *silently* changed by a round-trip through those implementations. + * Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also + * provides integer helpers. + * + * Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the + * range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64 + * or long long to avoid the Y2038K problem; a double storing microseconds since some epoch + * will be exact for +/- 275 years.) + */ + +/* Copyright (c) 2013 Dropbox, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef _MSC_VER + #if _MSC_VER <= 1800 // VS 2013 + #ifndef noexcept + #define noexcept throw() + #endif + + #ifndef snprintf + #define snprintf _snprintf_s + #endif + #endif +#endif + +namespace json11 { + +enum JsonParse { + STANDARD, COMMENTS +}; + +class JsonValue; + +class Json final { +public: + // Types + enum Type { + NUL, NUMBER, BOOL, STRING, ARRAY, OBJECT + }; + + // Array and object typedefs + typedef std::vector array; + typedef std::map object; + + // Constructors for the various types of JSON value. + Json() noexcept; // NUL + Json(std::nullptr_t) noexcept; // NUL + Json(double value); // NUMBER + Json(int value); // NUMBER + Json(bool value); // BOOL + Json(const std::string &value); // STRING + Json(std::string &&value); // STRING + Json(const char * value); // STRING + Json(const array &values); // ARRAY + Json(array &&values); // ARRAY + Json(const object &values); // OBJECT + Json(object &&values); // OBJECT + + // Implicit constructor: anything with a to_json() function. + template + Json(const T & t) : Json(t.to_json()) {} + + // Implicit constructor: map-like objects (std::map, std::unordered_map, etc) + template ().begin()->first)>::value + && std::is_constructible().begin()->second)>::value, + int>::type = 0> + Json(const M & m) : Json(object(m.begin(), m.end())) {} + + // Implicit constructor: vector-like objects (std::list, std::vector, std::set, etc) + template ().begin())>::value, + int>::type = 0> + Json(const V & v) : Json(array(v.begin(), v.end())) {} + + // This prevents Json(some_pointer) from accidentally producing a bool. Use + // Json(bool(some_pointer)) if that behavior is desired. + Json(void *) = delete; + + // Accessors + Type type() const; + + bool is_null() const { return type() == NUL; } + bool is_number() const { return type() == NUMBER; } + bool is_bool() const { return type() == BOOL; } + bool is_string() const { return type() == STRING; } + bool is_array() const { return type() == ARRAY; } + bool is_object() const { return type() == OBJECT; } + + // Return the enclosed value if this is a number, 0 otherwise. Note that json11 does not + // distinguish between integer and non-integer numbers - number_value() and int_value() + // can both be applied to a NUMBER-typed object. 
+ double number_value() const; + int int_value() const; + + // Return the enclosed value if this is a boolean, false otherwise. + bool bool_value() const; + // Return the enclosed string if this is a string, "" otherwise. + const std::string &string_value() const; + // Return the enclosed std::vector if this is an array, or an empty vector otherwise. + const array &array_items() const; + // Return the enclosed std::map if this is an object, or an empty map otherwise. + const object &object_items() const; + + // Return a reference to arr[i] if this is an array, Json() otherwise. + const Json & operator[](size_t i) const; + // Return a reference to obj[key] if this is an object, Json() otherwise. + const Json & operator[](const std::string &key) const; + + // Serialize. + void dump(std::string &out) const; + std::string dump() const { + std::string out; + dump(out); + return out; + } + + // Parse. If parse fails, return Json() and assign an error message to err. + static Json parse(const std::string & in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD); + static Json parse(const char * in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD) { + if (in) { + return parse(std::string(in), err, strategy); + } else { + err = "null input"; + return nullptr; + } + } + // Parse multiple objects, concatenated or separated by whitespace + static std::vector parse_multi( + const std::string & in, + std::string::size_type & parser_stop_pos, + std::string & err, + JsonParse strategy = JsonParse::STANDARD); + + static inline std::vector parse_multi( + const std::string & in, + std::string & err, + JsonParse strategy = JsonParse::STANDARD) { + std::string::size_type parser_stop_pos; + return parse_multi(in, parser_stop_pos, err, strategy); + } + + bool operator== (const Json &rhs) const; + bool operator< (const Json &rhs) const; + bool operator!= (const Json &rhs) const { return !(*this == rhs); } + bool operator<= (const Json &rhs) const { return !(rhs < *this); } + bool operator> (const Json &rhs) const { return (rhs < *this); } + bool operator>= (const Json &rhs) const { return !(*this < rhs); } + + /* has_shape(types, err) + * + * Return true if this is a JSON object and, for each item in types, has a field of + * the given type. If not, return false and set err to a descriptive message. + */ + typedef std::initializer_list> shape; + bool has_shape(const shape & types, std::string & err) const; + +private: + std::shared_ptr m_ptr; +}; + +// Internal class hierarchy - JsonValue objects are not exposed to users of this API. 
+class JsonValue { +protected: + friend class Json; + friend class JsonInt; + friend class JsonDouble; + virtual Json::Type type() const = 0; + virtual bool equals(const JsonValue * other) const = 0; + virtual bool less(const JsonValue * other) const = 0; + virtual void dump(std::string &out) const = 0; + virtual double number_value() const; + virtual int int_value() const; + virtual bool bool_value() const; + virtual const std::string &string_value() const; + virtual const Json::array &array_items() const; + virtual const Json &operator[](size_t i) const; + virtual const Json::object &object_items() const; + virtual const Json &operator[](const std::string &key) const; + virtual ~JsonValue() {} +}; + +} // namespace json11 \ No newline at end of file diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 25fec285..91617d94 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -52,6 +52,7 @@ endif() set(HEADER_DIR "src/extension") set(OEQ_JAX_SOURCES src/libjax_tp_jit.cpp + ${HEADER_DIR}/json11/json11.cpp ) set(OEQ_JAX_HEADERS @@ -60,6 +61,7 @@ set(OEQ_JAX_HEADERS ${HEADER_DIR}/util/backend_cuda.hpp ${HEADER_DIR}/util/backend_hip.hpp ${HEADER_DIR}/util/buffer.hpp + ${HEADER_DIR}/json11/json11.hpp ) nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index ae2035e8..2d48ea0a 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -8,9 +8,11 @@ #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" +#include "json11/json11.hpp" namespace nb = nanobind; namespace ffi = xla::ffi; +using json = json11::Json; #ifdef CUDA_BACKEND #include @@ -131,6 +133,14 @@ void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { } #endif +std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -175,39 +185,8 @@ std::unordered_map> conv_cache; std::mutex mut; -std::vector launch_config_keys = { - "num_blocks", - "num_threads", - "smem"}; -std::vector kernel_prop_keys = { - "L1_dim", - "L2_dim", - "L3_dim", - "weight_numel", - "shared_weights", - "opt_level", - "irrep_dtype", - "weight_dtype", - - // Convolution only - "workspace_size", - "deterministic", - "idx_dtype"}; - -std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { - std::unordered_map result; - for (const auto &key : keys) { - result[key] = dict.get(key).value(); - } - return result; -} - std::pair*, KernelProp> - compile_tp_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_tp_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -215,12 +194,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = tp_cache.find(hash); if (it == tp_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; + json root = json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string 
kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_tp_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); tp_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), @@ -232,11 +220,7 @@ std::pair*, KernelProp> } std::pair*, KernelProp> - compile_conv_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_conv_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -244,12 +228,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = conv_cache.find(hash); if (it == conv_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; + json root = json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_conv_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); conv_cache.insert({hash, std::make_pair(std::move(jit_conv_impl), @@ -301,11 +294,11 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer W, ffi::Result L3_out, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -336,11 +329,11 @@ ffi::Error tp_backward_impl( ffi::Result L2_grad, ffi::Result W_grad, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -386,11 +379,11 @@ ffi::Error tp_double_backward_impl( ffi::Result W_grad, ffi::Result L3_dgrad, stream_t stream, - std::string_view kernel, 
ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -435,7 +428,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled @@ -450,7 +443,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -469,7 +462,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -484,11 +477,11 @@ ffi::Error conv_forward_impl( ffi::AnyBuffer transpose_perm, ffi::Result L3_out, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -539,11 +532,11 @@ ffi::Error conv_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -608,11 +601,11 @@ ffi::Error conv_double_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -689,7 +682,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -708,7 
+701,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -731,7 +724,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -764,10 +757,4 @@ NB_MODULE(openequivariance_extjax, m) { .def("start", &GPUTimer::start) .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) .def("clear_L2_cache", &GPUTimer::clear_L2_cache); - - /*nb::class_>(m, "DeviceBuffer") - .def(nb::init()) - .def(nb::init()) - .def("copy_to_host", &PyDeviceBuffer::copy_to_host) - .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ -} +} \ No newline at end of file From 7f49104c7969e9d9995147a2da54fcf9ee4582ac Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 27 Jan 2026 20:51:48 -0800 Subject: [PATCH 02/18] Debugging in progress. --- .../core/ComputationSchedule.py | 2 +- .../openequivariance/jax/TensorProduct.py | 335 ++++++++++++++---- .../openequivariance/jax/__init__.py | 8 +- .../openequivariance/jax/utils.py | 70 +++- openequivariance_extjax/src/libjax_tp_jit.cpp | 4 +- tests/example_test.py | 34 +- 6 files changed, 341 insertions(+), 112 deletions(-) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 9c3884c9..c9765d0d 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -320,7 +320,7 @@ def __init__(self, num_blocks, num_threads, warp_size, smem): self.num_blocks = num_blocks self.num_threads = num_threads self.warp_size = warp_size - self.smem = smem + self.smem = int(smem) class ComputationSchedule: diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 05e4b097..4dd7aec3 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -5,112 +5,299 @@ from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.core.utils import hash_attributes -from openequivariance.jax.utils import reorder_jax +from openequivariance.jax.utils import reorder_jax, trace +import json +import jax +import jax.numpy as jnp +from jax.extend import core +from jax.interpreters import mlir, ad -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def forward(X, Y, W, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call( - "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) - ) - return forward_call(X, Y, W, **attrs) +# ============================================================================== +# 1. 
Forward Primitive +# ============================================================================== +tp_fwd_p = core.Primitive("tp_fwd") -def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs): - return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) +def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): + irrep_dtype = X.dtype + out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + call = jax.ffi.ffi_call("tp_forward", out_shape) + return call(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) +def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) -def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ): - X, Y, W = inputs - return backward(X, Y, W, dZ, irrep_dtype, attrs) +tp_fwd_p.def_impl(tp_fwd_impl) +tp_fwd_p.def_abstract_eval(tp_fwd_abstract_eval) +mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda") +mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm") -forward.defvjp(forward_fwd, forward_bwd) +# ============================================================================== +# 2. Backward Primitive +# ============================================================================== +tp_bwd_p = core.Primitive("tp_bwd") -@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) -def backward(X, Y, W, dZ, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call( - "tp_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - ), +def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): + print("hihi, Am here") + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("tp_backward", out_shapes) + + print("MADE IT HERE!") + print(X.shape, Y.shape, W.shape, dZ.shape) + print(type(kernel)) + print(hash) + print(type(X)) + print(type(Y)) + print(type(W)) + print(type(dZ)) + + result = call(X, Y, W, dZ, kernel=kernel, hash=hash) + print("Made call successfully") + print(type(result)) + return result + +def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + core.ShapedArray(X.shape, irrep_dtype), + core.ShapedArray(Y.shape, irrep_dtype), + core.ShapedArray(W.shape, irrep_dtype), ) - return backward_call(X, Y, W, dZ, **attrs) +tp_bwd_p.def_impl(tp_bwd_impl) +tp_bwd_p.def_abstract_eval(tp_bwd_abstract_eval) +mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm") + + +# ============================================================================== +# 3. 
Double Backward Primitive +# ============================================================================== + +tp_dbwd_p = core.Primitive("tp_dbwd") -def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs): - return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) +def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("tp_double_backward", out_shapes) + return call(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + +def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(dZ.shape, irrep_dtype), + ) +tp_dbwd_p.def_impl(tp_dbwd_impl) +tp_dbwd_p.def_abstract_eval(tp_dbwd_abstract_eval) +mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm") -def backward_bwd(irrep_dtype, attrs, inputs, derivs): - X, Y, W, dZ = inputs - ddX, ddY, ddW = derivs - return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs) +# ============================================================================== +# 4. Forward JVP Primitive Definition +# ============================================================================== +tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") -backward.defvjp(backward_fwd, backward_bwd) +def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + term1 = tp_fwd_p.bind(dX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + term2 = tp_fwd_p.bind(X, dY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + term3 = tp_fwd_p.bind(X, Y, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + return term1 + term2 + term3 +def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) -@partial(jax.custom_vjp, nondiff_argnums=(7, 8)) -def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs): - double_backward_call = jax.ffi.ffi_call( - "tp_double_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), - ), +tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) +tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) + + +# ============================================================================== +# 5. Transpose Rule (Implicit VJP) +# ============================================================================== + +def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + # This transpose corresponds to the Backward pass. + # We assert that we are differentiating with respect to the input tangents. + assert ad.is_undefined_primal(dX) + assert ad.is_undefined_primal(dY) + assert ad.is_undefined_primal(dW) + + # If the primals X, Y, W are being differentiated (undefined), we replace + # them with zeros for the purpose of this kernel call. 
+ if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + + print("STARTED HERE!") + grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) + print("GOT OUT HERE!") + + # Return gradients for (X, Y, W, dX, dY, dW). + # Primals get None, tangents get the computed gradients. + return (None, None, None, grad_X, grad_Y, grad_W) + +ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose + +def ensure_array(tan, primal): + if type(tan) is ad.Zero: + return jnp.zeros_like(primal) + return tan + +# ============================================================================== +# 6. JVP Rule for Original Forward Primitive +# ============================================================================== + +def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + X, Y, W = primals + dX, dY, dW = tangents + + dX = ensure_array(dX, X) + dY = ensure_array(dY, Y) + dW = ensure_array(dW, W) + + out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_tangent = tp_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return out_primal, out_tangent + +ad.primitive_jvps[tp_fwd_p] = tp_fwd_jvp_rule + + +# ============================================================================== +# 7. JVP Rule for Forward JVP Primitive (Higher Order) +# ============================================================================== + +def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dx, dy, dw): + return tp_fwd_jvp_impl(x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[tp_fwd_jvp_p] = tp_fwd_jvp_jvp_rule + + +# ============================================================================== +# 8. Backward JVP Primitive Definition +# ============================================================================== + +tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") + +def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, kernel=kernel, hash=hash) + term_X = tp_bwd_p.bind(tX, Y, W, dZ, kernel=kernel, hash=hash) + term_Y = tp_bwd_p.bind(X, tY, W, dZ, kernel=kernel, hash=hash) + term_W = tp_bwd_p.bind(X, Y, tW, dZ, kernel=kernel, hash=hash) + + out_dX = term_dZ[0] + term_Y[0] + term_W[0] + out_dY = term_dZ[1] + term_X[1] + term_W[1] + out_dW = term_dZ[2] + term_X[2] + term_Y[2] + + return out_dX, out_dY, out_dW + +def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), ) - return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs) +tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) +tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) + + +# ============================================================================== +# 9. 
Transpose Rule for Backward JVP +# ============================================================================== + +def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + ddX, ddY, ddW = ct + + assert ad.is_undefined_primal(tX) + assert ad.is_undefined_primal(tY) + assert ad.is_undefined_primal(tW) + assert ad.is_undefined_primal(tdZ) + + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) -def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs): - out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs) - return out, (X, Y, W, dZ, ddX, ddY, ddW) + g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + return (None, None, None, None, g_X, g_Y, g_W, g_dZ) -def zeros_like(x): - return jnp.zeros_like(x) +ad.primitive_transposes[tp_bwd_jvp_p] = tp_bwd_jvp_transpose -def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs): - X, Y, W, dZ, ddX, ddY, ddW = residuals - t_dX, t_dY, t_dW, t_ddZ = tangent_outputs +# ============================================================================== +# 10. JVP Rule for Backward JVP Primitive (Higher Order) +# ============================================================================== - op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) - g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs) +def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) - op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) - g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs) + def func(x, y, w, dz, tx, ty, tw, tdz): + return tp_bwd_jvp_impl(x, y, w, dz, tx, ty, tw, tdz, kernel=kernel, hash=hash) - op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) - g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs) + return jax.jvp(func, primals, tangents_clean) - op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) - g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs) +ad.primitive_jvps[tp_bwd_jvp_p] = tp_bwd_jvp_jvp_rule - g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs) - g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs) - g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs) - grad_X = g2_X + g4_X + g6_X + g7_X - grad_Y = g2_Y + g3_Y + g5_Y + g7_Y - grad_W = g1_W + g3_W + g4_W + g5_W + g6_W - grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ +# ============================================================================== +# 11. 
JVP Rule for Original Backward Primitive +# ============================================================================== - grad_ddX = g1_ddX + g3_ddX + g5_ddX - grad_ddY = g1_ddY + g4_ddY + g6_ddY - grad_ddW = g2_ddW + g7_ddW +def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): + X, Y, W, dZ = primals + tX, tY, tW, tdZ = tangents - return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW + tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip(tangents, primals)] + print("AM HERE!") + print(type(kernel)) + print(hash) + out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) + #out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) + #return out_primal, out_tangent + return out_primal, None -double_backward.defvjp(double_backward_fwd, triple_backward) +ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule class TensorProduct(LoopUnrollTP): @@ -124,14 +311,14 @@ def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) - self.attrs = { + self.kernel = json.dumps({ "kernel": self.jit_kernel, "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), "kernel_prop": self.kernelProp, - } - hash_attributes(self.attrs) + }) + self.hash = self.kernel.__hash__() self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -139,7 +326,7 @@ def __init__(self, problem: TPProblem): def forward( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray ) -> jax.numpy.ndarray: - return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) + return tp_fwd_p.bind(X, Y, W, L3_dim=self.L3_dim, kernel=self.kernel, hash=self.hash) def __call__( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index 410e5dbf..6e1a9087 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,6 +1,6 @@ from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct -from openequivariance.jax.TensorProductConv import ( - TensorProductConv as TensorProductConv, -) +#from openequivariance.jax.TensorProductConv import ( +# TensorProductConv as TensorProductConv, +#) -__all__ = ["TensorProduct", "TensorProductConv"] +__all__ = ["TensorProduct"] #"TensorProductConv"] diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py index ae15d1a6..00bcee7a 100644 --- a/openequivariance/openequivariance/jax/utils.py +++ b/openequivariance/openequivariance/jax/utils.py @@ -1,7 +1,8 @@ import jax import jax.numpy as jnp import numpy as np - +import functools +import traceback def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): assert direction in ["forward", "backward"] @@ -61,3 +62,70 @@ def reorder_jax(schedule, weights_in, direction, has_batch_dim): return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim) else: return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) + + +_indentation = 0 +def _trace(msg=None): + """Print a message at current indentation.""" + if msg is not None: + print(" " * _indentation + msg) + +def _trace_indent(msg=None): + """Print a message and then indent the rest.""" + 
global _indentation + _trace(msg) + _indentation = 1 + _indentation + +def _trace_unindent(msg=None): + """Unindent then print a message.""" + global _indentation + _indentation = _indentation - 1 + _trace(msg) + +def trace(name): + """A decorator for functions to trace arguments and results.""" + + def trace_func(func): # pylint: disable=missing-docstring + def pp(v): + """Print certain values more succinctly""" + vtype = str(type(v)) + if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: + return "" + elif "jaxlib._jax_.XlaOp" in vtype: + return "".format(id(v)) + elif ("partial_eval.JaxprTracer" in vtype or + "batching.BatchTracer" in vtype or + "ad.JVPTracer" in vtype): + return "Traced<{}>".format(v.aval) + elif isinstance(v, tuple): + return "({})".format(pp_values(v)) + else: + return str(v) + def pp_values(args): + return ", ".join([pp(arg) for arg in args]) + + @functools.wraps(func) + def func_wrapper(*args): + _trace_indent("call {}({})".format(name, pp_values(args))) + res = func(*args) + _trace_unindent("|<- {} = {}".format(name, pp(res))) + return res + + return func_wrapper + + return trace_func + +class expectNotImplementedError(object): + """Context manager to check for NotImplementedError.""" + def __enter__(self): pass + def __exit__(self, type, value, tb): + global _indentation + _indentation = 0 + if type is NotImplementedError: + print("\nFound expected exception:") + traceback.print_exc(limit=3) + return True + elif type is None: # No exception + assert False, "Expected NotImplementedError" + else: + return False \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 2d48ea0a..cb975329 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -293,7 +293,8 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer L2_in, ffi::AnyBuffer W, ffi::Result L3_out, - stream_t stream, + stream_t stream, + int64_t L3_dim, std::string_view kernel_json, int64_t hash) { @@ -428,6 +429,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() + .Attr("L3_dim") .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled diff --git a/tests/example_test.py b/tests/example_test.py index e8d23cb7..2ab6f6ab 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -132,35 +132,7 @@ def test_tutorial_jax(with_jax): Z = tp_fast(X, Y, W) print(jax.numpy.linalg.norm(Z)) - edge_index = jax.numpy.array( - [ - [0, 1, 1, 2], - [1, 0, 2, 1], - ], - dtype=jax.numpy.int32, # NOTE: This int32, not int64 - ) - - node_ct, nonzero_ct = 3, 4 - X = jax.random.uniform( - key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32 - ) - Y = jax.random.uniform( - key, - shape=(nonzero_ct, Y_ir.dim), - minval=0.0, - maxval=1.0, - dtype=jax.numpy.float32, - ) - W = jax.random.uniform( - key, - shape=(nonzero_ct, problem.weight_numel), - minval=0.0, - maxval=1.0, - dtype=jax.numpy.float32, + result = jax.vjp(lambda X, Y, W: tp_fast(X, Y, W), X, Y, W)[1]( + jax.numpy.ones_like(Z) ) - tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) - Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) - print(jax.numpy.linalg.norm(Z)) - - jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)) - print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1]))) + print(result) \ No newline at end of file From b78a4ad2dd7ca6b0cf8e8627ee0b3d41df5665d1 Mon Sep 17 00:00:00 2001 From: 
Vivek Bharadwaj Date: Tue, 27 Jan 2026 22:41:46 -0800 Subject: [PATCH 03/18] Bound the double backward pass. --- .../openequivariance/jax/TensorProduct.py | 27 +++++-------------- tests/conftest.py | 2 +- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 4dd7aec3..4dfcf5fe 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -39,29 +39,18 @@ def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): # ============================================================================== tp_bwd_p = core.Primitive("tp_bwd") +tp_bwd_p.multiple_results = True def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): - print("hihi, Am here") irrep_dtype = X.dtype out_shapes = ( jax.ShapeDtypeStruct(X.shape, irrep_dtype), jax.ShapeDtypeStruct(Y.shape, irrep_dtype), jax.ShapeDtypeStruct(W.shape, irrep_dtype), ) - call = jax.ffi.ffi_call("tp_backward", out_shapes) - - print("MADE IT HERE!") - print(X.shape, Y.shape, W.shape, dZ.shape) - print(type(kernel)) - print(hash) - print(type(X)) - print(type(Y)) - print(type(W)) - print(type(dZ)) + call = jax.ffi.ffi_call("tp_backward", out_shapes) result = call(X, Y, W, dZ, kernel=kernel, hash=hash) - print("Made call successfully") - print(type(result)) return result def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): @@ -83,6 +72,7 @@ def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): # ============================================================================== tp_dbwd_p = core.Primitive("tp_dbwd") +tp_dbwd_p.multiple_results = True def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): irrep_dtype = X.dtype @@ -148,9 +138,7 @@ def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - print("STARTED HERE!") grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) - print("GOT OUT HERE!") # Return gradients for (X, Y, W, dX, dY, dW). # Primals get None, tangents get the computed gradients. 
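# Descriptive note on the transpose rule in the hunk above: tp_fwd_jvp_p is linear
# in its tangent arguments (dX, dY, dW), so transposing it with respect to those
# arguments is exactly the backward kernel. When jax.grad linearizes tp_fwd_p via
# its JVP rule and then transposes the resulting linear computation, this rule
# receives the output cotangent `ct` and returns (grad_X, grad_Y, grad_W) in the
# tangent slots; the primal slots (X, Y, W) get None because they act as constants
# of the linear map and are not differentiated along this path.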
@@ -209,6 +197,7 @@ def func(x, y, w, dx, dy, dw): # ============================================================================== tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") +tp_bwd_jvp_p.multiple_results = True def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, kernel=kernel, hash=hash) @@ -288,14 +277,10 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): tX, tY, tW, tdZ = tangents tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip(tangents, primals)] - print("AM HERE!") - print(type(kernel)) - print(hash) out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) - #out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) + out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) - #return out_primal, out_tangent - return out_primal, None + return out_primal, out_tangent ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule diff --git a/tests/conftest.py b/tests/conftest.py index 0e7098e0..ad6cb9fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ os.environ["JAX_ENABLE_X64"] = "True" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" - +os.environ["JAX_TRACEBACK_FILTERING"] = "off" def pytest_addoption(parser): parser.addoption( From d395e6b150e5830bdda9a8883257dfa41a3fd40b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 27 Jan 2026 23:37:14 -0800 Subject: [PATCH 04/18] Ready to try conv next. --- .../openequivariance/jax/TensorProduct.py | 280 +----------------- .../openequivariance/jax/jvp/conv_prim.py | 0 .../openequivariance/jax/jvp/tp_prim.py | 279 +++++++++++++++++ 3 files changed, 281 insertions(+), 278 deletions(-) create mode 100644 openequivariance/openequivariance/jax/jvp/conv_prim.py create mode 100644 openequivariance/openequivariance/jax/jvp/tp_prim.py diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 4dfcf5fe..ca8d0999 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -5,285 +5,9 @@ from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.jax.utils import reorder_jax, trace - +from openequivariance.jax.utils import reorder_jax +from openequivariance.jax.jvp.tp_prim import tp_fwd_p import json -import jax -import jax.numpy as jnp -from jax.extend import core -from jax.interpreters import mlir, ad - -# ============================================================================== -# 1. 
Forward Primitive -# ============================================================================== - -tp_fwd_p = core.Primitive("tp_fwd") - -def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): - irrep_dtype = X.dtype - out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) - call = jax.ffi.ffi_call("tp_forward", out_shape) - return call(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - -def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): - return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) - -tp_fwd_p.def_impl(tp_fwd_impl) -tp_fwd_p.def_abstract_eval(tp_fwd_abstract_eval) -mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda") -mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm") - - -# ============================================================================== -# 2. Backward Primitive -# ============================================================================== - -tp_bwd_p = core.Primitive("tp_bwd") -tp_bwd_p.multiple_results = True - -def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): - irrep_dtype = X.dtype - out_shapes = ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - ) - - call = jax.ffi.ffi_call("tp_backward", out_shapes) - result = call(X, Y, W, dZ, kernel=kernel, hash=hash) - return result - -def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): - irrep_dtype = X.dtype - return ( - core.ShapedArray(X.shape, irrep_dtype), - core.ShapedArray(Y.shape, irrep_dtype), - core.ShapedArray(W.shape, irrep_dtype), - ) - -tp_bwd_p.def_impl(tp_bwd_impl) -tp_bwd_p.def_abstract_eval(tp_bwd_abstract_eval) -mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm") - - -# ============================================================================== -# 3. Double Backward Primitive -# ============================================================================== - -tp_dbwd_p = core.Primitive("tp_dbwd") -tp_dbwd_p.multiple_results = True - -def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): - irrep_dtype = X.dtype - out_shapes = ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), - ) - call = jax.ffi.ffi_call("tp_double_backward", out_shapes) - return call(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) - -def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): - irrep_dtype = X.dtype - return ( - jax.core.ShapedArray(X.shape, irrep_dtype), - jax.core.ShapedArray(Y.shape, irrep_dtype), - jax.core.ShapedArray(W.shape, irrep_dtype), - jax.core.ShapedArray(dZ.shape, irrep_dtype), - ) - -tp_dbwd_p.def_impl(tp_dbwd_impl) -tp_dbwd_p.def_abstract_eval(tp_dbwd_abstract_eval) -mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm") - -# ============================================================================== -# 4. 
Forward JVP Primitive Definition -# ============================================================================== - -tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") - -def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - term1 = tp_fwd_p.bind(dX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - term2 = tp_fwd_p.bind(X, dY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - term3 = tp_fwd_p.bind(X, Y, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) - return term1 + term2 + term3 - -def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) - -tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) -tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) - - -# ============================================================================== -# 5. Transpose Rule (Implicit VJP) -# ============================================================================== - -def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - # This transpose corresponds to the Backward pass. - # We assert that we are differentiating with respect to the input tangents. - assert ad.is_undefined_primal(dX) - assert ad.is_undefined_primal(dY) - assert ad.is_undefined_primal(dW) - - # If the primals X, Y, W are being differentiated (undefined), we replace - # them with zeros for the purpose of this kernel call. - if ad.is_undefined_primal(X): - X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): - Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): - W = jnp.zeros(W.aval.shape, W.aval.dtype) - - grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) - - # Return gradients for (X, Y, W, dX, dY, dW). - # Primals get None, tangents get the computed gradients. - return (None, None, None, grad_X, grad_Y, grad_W) - -ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose - -def ensure_array(tan, primal): - if type(tan) is ad.Zero: - return jnp.zeros_like(primal) - return tan - -# ============================================================================== -# 6. JVP Rule for Original Forward Primitive -# ============================================================================== - -def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): - X, Y, W = primals - dX, dY, dW = tangents - - dX = ensure_array(dX, X) - dY = ensure_array(dY, Y) - dW = ensure_array(dW, W) - - out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - out_tangent = tp_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) - - return out_primal, out_tangent - -ad.primitive_jvps[tp_fwd_p] = tp_fwd_jvp_rule - - -# ============================================================================== -# 7. JVP Rule for Forward JVP Primitive (Higher Order) -# ============================================================================== - -def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) - - def func(x, y, w, dx, dy, dw): - return tp_fwd_jvp_impl(x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash) - - return jax.jvp(func, primals, tangents_clean) - -ad.primitive_jvps[tp_fwd_jvp_p] = tp_fwd_jvp_jvp_rule - - -# ============================================================================== -# 8. 
Backward JVP Primitive Definition -# ============================================================================== - -tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") -tp_bwd_jvp_p.multiple_results = True - -def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): - term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, kernel=kernel, hash=hash) - term_X = tp_bwd_p.bind(tX, Y, W, dZ, kernel=kernel, hash=hash) - term_Y = tp_bwd_p.bind(X, tY, W, dZ, kernel=kernel, hash=hash) - term_W = tp_bwd_p.bind(X, Y, tW, dZ, kernel=kernel, hash=hash) - - out_dX = term_dZ[0] + term_Y[0] + term_W[0] - out_dY = term_dZ[1] + term_X[1] + term_W[1] - out_dW = term_dZ[2] + term_X[2] + term_Y[2] - - return out_dX, out_dY, out_dW - -def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): - irrep_dtype = X.dtype - return ( - jax.core.ShapedArray(X.shape, irrep_dtype), - jax.core.ShapedArray(Y.shape, irrep_dtype), - jax.core.ShapedArray(W.shape, irrep_dtype), - ) - -tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) -tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) - - -# ============================================================================== -# 9. Transpose Rule for Backward JVP -# ============================================================================== - -def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): - ddX, ddY, ddW = ct - - assert ad.is_undefined_primal(tX) - assert ad.is_undefined_primal(tY) - assert ad.is_undefined_primal(tW) - assert ad.is_undefined_primal(tdZ) - - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) - - g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) - - return (None, None, None, None, g_X, g_Y, g_W, g_dZ) - -ad.primitive_transposes[tp_bwd_jvp_p] = tp_bwd_jvp_transpose - - -# ============================================================================== -# 10. JVP Rule for Backward JVP Primitive (Higher Order) -# ============================================================================== - -def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) - - def func(x, y, w, dz, tx, ty, tw, tdz): - return tp_bwd_jvp_impl(x, y, w, dz, tx, ty, tw, tdz, kernel=kernel, hash=hash) - - return jax.jvp(func, primals, tangents_clean) - -ad.primitive_jvps[tp_bwd_jvp_p] = tp_bwd_jvp_jvp_rule - - -# ============================================================================== -# 11. 
JVP Rule for Original Backward Primitive -# ============================================================================== - -def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): - X, Y, W, dZ = primals - tX, tY, tW, tdZ = tangents - - tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip(tangents, primals)] - out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) - out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) - - return out_primal, out_tangent - -ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule - class TensorProduct(LoopUnrollTP): r""" diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py new file mode 100644 index 00000000..e69de29b diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py new file mode 100644 index 00000000..deabc60b --- /dev/null +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -0,0 +1,279 @@ +import json +import jax +import jax.numpy as jnp +from jax.extend import core +from jax.interpreters import mlir, ad + +# Implements the ladder of derivatives for tensor product +# via primitives, JVP, and transpose instead of custom_vjp. + +# ============================================================================== +# 1. Forward Primitive +# ============================================================================== + +tp_fwd_p = core.Primitive("tp_fwd") + +def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): + irrep_dtype = X.dtype + out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + call = jax.ffi.ffi_call("tp_forward", out_shape) + return call(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + +def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + +tp_fwd_p.def_impl(tp_fwd_impl) +tp_fwd_p.def_abstract_eval(tp_fwd_abstract_eval) +mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda") +mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm") + + +# ============================================================================== +# 2. Backward Primitive +# ============================================================================== + +tp_bwd_p = core.Primitive("tp_bwd") +tp_bwd_p.multiple_results = True + +def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ) + + call = jax.ffi.ffi_call("tp_backward", out_shapes) + result = call(X, Y, W, dZ, kernel=kernel, hash=hash) + return result + +def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + core.ShapedArray(X.shape, irrep_dtype), + core.ShapedArray(Y.shape, irrep_dtype), + core.ShapedArray(W.shape, irrep_dtype), + ) + +tp_bwd_p.def_impl(tp_bwd_impl) +tp_bwd_p.def_abstract_eval(tp_bwd_abstract_eval) +mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm") + + +# ============================================================================== +# 3. 
Double Backward Primitive +# ============================================================================== + +tp_dbwd_p = core.Primitive("tp_dbwd") +tp_dbwd_p.multiple_results = True + +def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("tp_double_backward", out_shapes) + return call(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + +def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(dZ.shape, irrep_dtype), + ) + +tp_dbwd_p.def_impl(tp_dbwd_impl) +tp_dbwd_p.def_abstract_eval(tp_dbwd_abstract_eval) +mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm") + +# ============================================================================== +# 4. Forward JVP Primitive Definition +# ============================================================================== + +tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") + +def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + term1 = tp_fwd_p.bind(dX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + term2 = tp_fwd_p.bind(X, dY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + term3 = tp_fwd_p.bind(X, Y, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + return term1 + term2 + term3 + +def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + +tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) +tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) + + +# ============================================================================== +# 5. Transpose Rule (Implicit VJP) +# ============================================================================== + +def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): + # This transpose corresponds to the Backward pass. + # We assert that we are differentiating with respect to the input tangents. + assert ad.is_undefined_primal(dX) + assert ad.is_undefined_primal(dY) + assert ad.is_undefined_primal(dW) + + # If the primals X, Y, W are being differentiated (undefined), we replace + # them with zeros for the purpose of this kernel call. + if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + + grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) + + # Return gradients for (X, Y, W, dX, dY, dW). + # Primals get None, tangents get the computed gradients. + return (None, None, None, grad_X, grad_Y, grad_W) + +ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose + +def ensure_array(tan, primal): + if type(tan) is ad.Zero: + return jnp.zeros_like(primal) + return tan + +# ============================================================================== +# 6. 
JVP Rule for Original Forward Primitive +# ============================================================================== + +def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + X, Y, W = primals + dX, dY, dW = tangents + + dX = ensure_array(dX, X) + dY = ensure_array(dY, Y) + dW = ensure_array(dW, W) + + out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_tangent = tp_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return out_primal, out_tangent + +ad.primitive_jvps[tp_fwd_p] = tp_fwd_jvp_rule + + +# ============================================================================== +# 7. JVP Rule for Forward JVP Primitive (Higher Order) +# ============================================================================== + +def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dx, dy, dw): + return tp_fwd_jvp_impl(x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[tp_fwd_jvp_p] = tp_fwd_jvp_jvp_rule + + +# ============================================================================== +# 8. Backward JVP Primitive Definition +# ============================================================================== + +tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") +tp_bwd_jvp_p.multiple_results = True + +def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, kernel=kernel, hash=hash) + term_X = tp_bwd_p.bind(tX, Y, W, dZ, kernel=kernel, hash=hash) + term_Y = tp_bwd_p.bind(X, tY, W, dZ, kernel=kernel, hash=hash) + term_W = tp_bwd_p.bind(X, Y, tW, dZ, kernel=kernel, hash=hash) + + out_dX = term_dZ[0] + term_Y[0] + term_W[0] + out_dY = term_dZ[1] + term_X[1] + term_W[1] + out_dW = term_dZ[2] + term_X[2] + term_Y[2] + + return out_dX, out_dY, out_dW + +def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + +tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) +tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) + + +# ============================================================================== +# 9. Transpose Rule for Backward JVP +# ============================================================================== + +def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): + ddX, ddY, ddW = ct + + assert ad.is_undefined_primal(tX) + assert ad.is_undefined_primal(tY) + assert ad.is_undefined_primal(tW) + assert ad.is_undefined_primal(tdZ) + + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + + g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + + return (None, None, None, None, g_X, g_Y, g_W, g_dZ) + +ad.primitive_transposes[tp_bwd_jvp_p] = tp_bwd_jvp_transpose + + +# ============================================================================== +# 10. 
JVP Rule for Backward JVP Primitive (Higher Order) +# ============================================================================== + +def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dz, tx, ty, tw, tdz): + return tp_bwd_jvp_impl(x, y, w, dz, tx, ty, tw, tdz, kernel=kernel, hash=hash) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[tp_bwd_jvp_p] = tp_bwd_jvp_jvp_rule + + +# ============================================================================== +# 11. JVP Rule for Original Backward Primitive +# ============================================================================== + +def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): + X, Y, W, dZ = primals + tX, tY, tW, tdZ = tangents + + tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip(tangents, primals)] + out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) + out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) + + return out_primal, out_tangent + +ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule \ No newline at end of file From 40dab78b3e9ae45122b1406bf43acd9af6be4291 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 28 Jan 2026 21:14:51 -0800 Subject: [PATCH 05/18] Adding the convolution primitives. --- .../openequivariance/core/utils.py | 10 - .../openequivariance/jax/TensorProductConv.py | 179 +------- .../openequivariance/jax/jvp/conv_prim.py | 381 ++++++++++++++++++ .../openequivariance/jax/jvp/tp_prim.py | 65 ++- 4 files changed, 456 insertions(+), 179 deletions(-) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 5fd8f81d..15a5ca25 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -200,13 +200,3 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis - - -def hash_attributes(attrs): - m = hashlib.sha256() - - for key in sorted(attrs.keys()): - m.update(attrs[key].__repr__().encode("utf-8")) - - hash = int(m.hexdigest()[:16], 16) >> 1 - attrs["hash"] = hash diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 7439cd4e..a8487268 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -1,4 +1,5 @@ import jax +import json import jax.numpy as jnp import numpy as np from functools import partial @@ -7,171 +8,13 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv -from openequivariance.core.utils import hash_attributes from openequivariance.jax.utils import reorder_jax from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.jax.jvp.conv_prim import conv_fwd_p -logger = getLogger() - - -def zeros_like(x): - return jnp.zeros_like(x) - - -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call( - "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) - ) - return forward_call(X, Y, W, rows, cols, workspace, 
sender_perm, **attrs) - - -def forward_fwd( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs -): - out = forward( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs - ) - return out, (X, Y, W, rows, cols) - - -def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ): - X, Y, W, rows, cols = res - dX, dY, dW = backward( - X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs - ) - return dX, dY, dW, None, None - - -forward.defvjp(forward_fwd, forward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) -def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call( - "conv_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - ), - ) - return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) - - -def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs) - return out, (X, Y, W, dZ, rows, cols) - - -def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives): - X, Y, W, dZ, rows, cols = res - ddX, ddY, ddW = derivatives - - gX, gY, gW, gdZ = double_backward( - X, - Y, - W, - dZ, - ddX, - ddY, - ddW, - rows, - cols, - workspace, - sender_perm, - irrep_dtype, - attrs, - ) - - return gX, gY, gW, gdZ, None, None - -backward.defvjp(backward_fwd, backward_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) -def double_backward( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs -): - double_backward_call = jax.ffi.ffi_call( - "conv_double_backward", - ( - jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype), - jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), - ), - ) - return double_backward_call( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs - ) - - -def double_backward_fwd( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs -): - out = double_backward( - X, - Y, - W, - dZ, - ddX, - ddY, - ddW, - rows, - cols, - workspace, - sender_perm, - irrep_dtype, - attrs, - ) - return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) - - -def triple_backward( - workspace, - sender_perm, - irrep_dtype, - attrs, - residuals, - tangent_outputs, -): - X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals - t_dX, t_dY, t_dW, t_ddZ = tangent_outputs - - common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs) - - op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) - g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args) - - op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) - g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args) - - op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) - g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args) - - op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) - g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args) - - g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args) - g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args) - g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args) - - grad_X = g2_X + g4_X + g6_X + g7_X - grad_Y = g2_Y + g3_Y + g5_Y + g7_Y 
- grad_W = g1_W + g3_W + g4_W + g5_W + g6_W - grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ - - grad_ddX = g1_ddX + g3_ddX + g5_ddX - grad_ddY = g1_ddY + g4_ddY + g6_ddY - grad_ddW = g2_ddW + g7_ddW - - return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None - - -double_backward.defvjp(double_backward_fwd, triple_backward) +logger = getLogger() class TensorProductConv(LoopUnrollConv): @@ -189,14 +32,14 @@ def __init__( kahan=kahan, ) - self.attrs = { + self.kernel = json.dumps({ "kernel": self.jit_kernel, "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernel_prop, - } - hash_attributes(self.attrs) + "kernel_prop": self.kernelProp, + }) + self.hash = self.kernel.__hash__() self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -223,7 +66,7 @@ def forward( "Must provide sender_perm for deterministic convolutions." ) - return forward( + return conv_fwd_p.bind( X, Y, W, @@ -231,9 +74,9 @@ def forward( cols, self.workspace, sender_perm, - self.L3_dim, - self.config.irrep_dtype, - self.attrs, + L3_dim=self.L3_dim, + kernel=self.kernel, + hash=self.hash ) def __call__( diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index e69de29b..f56fe8df 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -0,0 +1,381 @@ +import jax +import jax.numpy as jnp +from jax.extend import core +from jax.interpreters import mlir, ad + +# ============================================================================== +# 0. Helpers +# ============================================================================== + +def ensure_array(tan, primal): + if type(tan) is ad.Zero: + return jnp.zeros_like(primal) + return tan + +# ============================================================================== +# 1. Forward Primitive +# ============================================================================== + +conv_fwd_p = core.Primitive("conv_fwd") + +def conv_fwd_impl(X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + irrep_dtype = X.dtype + out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + call = jax.ffi.ffi_call("conv_forward", out_shape) + return call(X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + +def conv_fwd_abstract_eval(X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + +conv_fwd_p.def_impl(conv_fwd_impl) +conv_fwd_p.def_abstract_eval(conv_fwd_abstract_eval) +mlir.register_lowering(conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="cuda") +mlir.register_lowering(conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="rocm") + + +# ============================================================================== +# 2. 
Backward Primitive +# ============================================================================== + +conv_bwd_p = core.Primitive("conv_bwd") +conv_bwd_p.multiple_results = True + +def conv_bwd_impl(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("conv_backward", out_shapes) + return call(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + +def conv_bwd_abstract_eval(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + return ( + core.ShapedArray(X.shape, irrep_dtype), + core.ShapedArray(Y.shape, irrep_dtype), + core.ShapedArray(W.shape, irrep_dtype), + ) + +conv_bwd_p.def_impl(conv_bwd_impl) +conv_bwd_p.def_abstract_eval(conv_bwd_abstract_eval) +mlir.register_lowering(conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="rocm") + + +# ============================================================================== +# 3. Double Backward Primitive +# ============================================================================== + +conv_dbwd_p = core.Primitive("conv_dbwd") +conv_dbwd_p.multiple_results = True + +def conv_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + out_shapes = ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), + ) + call = jax.ffi.ffi_call("conv_double_backward", out_shapes) + return call(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + +def conv_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(dZ.shape, irrep_dtype), + ) + +conv_dbwd_p.def_impl(conv_dbwd_impl) +conv_dbwd_p.def_abstract_eval(conv_dbwd_abstract_eval) +mlir.register_lowering(conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="cuda") +mlir.register_lowering(conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="rocm") + + +# ============================================================================== +# 4. 
Forward JVP Primitive Definition +# ============================================================================== + +conv_fwd_jvp_p = core.Primitive("conv_fwd_jvp") + +def conv_fwd_jvp_impl(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + term1 = conv_fwd_p.bind(dX, Y, W, *args_meta, **kwargs) + term2 = conv_fwd_p.bind(X, dY, W, *args_meta, **kwargs) + term3 = conv_fwd_p.bind(X, Y, dW, *args_meta, **kwargs) + return term1 + term2 + term3 + +def conv_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + +conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) +conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) + + +# ============================================================================== +# 5. Transpose Rule (Implicit VJP) +# ============================================================================== + +def conv_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + assert ad.is_undefined_primal(dX) + assert ad.is_undefined_primal(dY) + assert ad.is_undefined_primal(dW) + + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + + grad_X, grad_Y, grad_W = conv_bwd_p.bind( + X, Y, W, ct, + rows, cols, workspace, sender_perm, + kernel=kernel, hash=hash + ) + + return (None, None, None, grad_X, grad_Y, grad_W, None, None, None, None) + +ad.primitive_transposes[conv_fwd_jvp_p] = conv_fwd_jvp_transpose + + +# ============================================================================== +# 6. JVP Rule for Original Forward Primitive +# ============================================================================== + +def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + X, Y, W, rows, cols, workspace, sender_perm = primals + dX, dY, dW, drows, dcols, dworkspace, dsender_perm = tangents + + dX = ensure_array(dX, X) + dY = ensure_array(dY, Y) + dW = ensure_array(dW, W) + + out_primal = conv_fwd_p.bind(X, Y, W, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_tangent = conv_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return out_primal, out_tangent + +ad.primitive_jvps[conv_fwd_p] = conv_fwd_jvp_rule + + +# ============================================================================== +# 7. JVP Rule for Forward JVP Primitive (Higher Order) +# ============================================================================== + +def conv_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dx, dy, dw, r, c, ws, sp): + return conv_fwd_jvp_impl( + x, y, w, dx, dy, dw, r, c, ws, sp, + L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[conv_fwd_jvp_p] = conv_fwd_jvp_jvp_rule + + +# ============================================================================== +# 8. 
Backward JVP Primitive Definition +# ============================================================================== + +conv_bwd_jvp_p = core.Primitive("conv_bwd_jvp") +conv_bwd_jvp_p.multiple_results = True + +def conv_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + kwargs = dict(kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + term_dZ = conv_bwd_p.bind(X, Y, W, tdZ, *args_meta, **kwargs) + term_X = conv_bwd_p.bind(tX, Y, W, dZ, *args_meta, **kwargs) + term_Y = conv_bwd_p.bind(X, tY, W, dZ, *args_meta, **kwargs) + term_W = conv_bwd_p.bind(X, Y, tW, dZ, *args_meta, **kwargs) + + out_dX = term_dZ[0] + term_Y[0] + term_W[0] + out_dY = term_dZ[1] + term_X[1] + term_W[1] + out_dW = term_dZ[2] + term_X[2] + term_Y[2] + + return out_dX, out_dY, out_dW + +def conv_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + irrep_dtype = X.dtype + return ( + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), + ) + +conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) +conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) + + +# ============================================================================== +# 9. Transpose Rule for Backward JVP +# ============================================================================== + +def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + ddX, ddY, ddW = ct + + assert ad.is_undefined_primal(tX) + assert ad.is_undefined_primal(tY) + assert ad.is_undefined_primal(tW) + assert ad.is_undefined_primal(tdZ) + + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + + g_X, g_Y, g_W, g_dZ = conv_dbwd_p.bind( + X, Y, W, dZ, ddX, ddY, ddW, + rows, cols, workspace, sender_perm, + kernel=kernel, hash=hash + ) + + return (None, None, None, None, g_X, g_Y, g_W, g_dZ, None, None, None, None) + +ad.primitive_transposes[conv_bwd_jvp_p] = conv_bwd_jvp_transpose + + +# ============================================================================== +# 10. JVP Rule for Backward JVP Primitive (Higher Order) +# ============================================================================== + +def conv_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp): + return conv_bwd_jvp_impl( + x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp, + kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[conv_bwd_jvp_p] = conv_bwd_jvp_jvp_rule + + +# ============================================================================== +# 11. 
JVP Rule for Original Backward Primitive +# ============================================================================== + +def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): + X, Y, W, dZ, rows, cols, workspace, sender_perm = primals + tX, tY, tW, tdZ, drows, dcols, dworkspace, dsender_perm = tangents + + tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip((tX, tY, tW, tdZ), (X, Y, W, dZ))] + + out_primal = conv_bwd_p.bind( + X, Y, W, dZ, rows, cols, workspace, sender_perm, + kernel=kernel, hash=hash + ) + out_tangent = conv_bwd_jvp_p.bind( + X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, + kernel=kernel, hash=hash + ) + + return out_primal, out_tangent + +ad.primitive_jvps[conv_bwd_p] = conv_bwd_jvp_rule + + +# ============================================================================== +# 12. Slow Double Backward Implementation (Reference) +# ============================================================================== + +def conv_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + kwargs = dict(kernel=kernel, hash=hash) + args_meta = (rows, cols, workspace, sender_perm) + + op1 = conv_bwd_p.bind(ddX, ddY, W, dZ, *args_meta, **kwargs) + op2 = conv_bwd_p.bind(X, Y, ddW, dZ, *args_meta, **kwargs) + + op3 = conv_fwd_p.bind(ddX, Y, W, *args_meta, L3_dim=L3_dim, **kwargs) + op4 = conv_bwd_p.bind(ddX, Y, W, dZ, *args_meta, **kwargs) + op5 = conv_bwd_p.bind(X, ddY, W, dZ, *args_meta, **kwargs) + + op6 = conv_fwd_p.bind(X, ddY, W, *args_meta, L3_dim=L3_dim, **kwargs) + op7 = conv_fwd_p.bind(X, Y, ddW, *args_meta, L3_dim=L3_dim, **kwargs) + + grad_X = op1[0] + op2[0] + grad_Y = op1[1] + op2[1] + grad_W = op4[2] + op5[2] + grad_dZ = op3 + op6 + op7 + + return grad_X, grad_Y, grad_W, grad_dZ + + +# ============================================================================== +# 13. JVP rule for double backward (implicit) +# ============================================================================== + +def conv_dbwd_jvp_rule(primals, tangents, *, kernel, hash): + # Infer L3_dim from dZ (4th input) + dZ = primals[3] + L3_dim = dZ.shape[1] + + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): + return conv_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp, + L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[conv_dbwd_p] = conv_dbwd_jvp_rule + + +# ============================================================================== +# 14. 
Transpose rule for double backward +# ============================================================================== + +def conv_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + # Infer L3_dim from dZ + L3_dim = dZ.shape[1] + + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + if ad.is_undefined_primal(ddX): ddX = jnp.zeros(ddX.aval.shape, ddX.aval.dtype) + if ad.is_undefined_primal(ddY): ddY = jnp.zeros(ddY.aval.shape, ddY.aval.dtype) + if ad.is_undefined_primal(ddW): ddW = jnp.zeros(ddW.aval.shape, ddW.aval.dtype) + + def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): + return conv_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp, + L3_dim=L3_dim, kernel=kernel, hash=hash + ) + + _, vjp_fun = jax.vjp(func, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm) + input_grads = vjp_fun(ct) + + return input_grads + +ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose \ No newline at end of file diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index deabc60b..9ec96a61 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -276,4 +276,67 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): return out_primal, out_tangent -ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule \ No newline at end of file +ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule + +# ============================================================================== +# 12. Slow Double Backward Implementation (Reference) +# ============================================================================== + +def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): + op1 = tp_bwd_p.bind(ddX, ddY, W, dZ, kernel=kernel, hash=hash) + op2 = tp_bwd_p.bind(X, Y, ddW, dZ, kernel=kernel, hash=hash) + op3 = tp_fwd_p.bind(ddX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + op4 = tp_bwd_p.bind(ddX, Y, W, dZ, kernel=kernel, hash=hash) + op5 = tp_bwd_p.bind(X, ddY, W, dZ, kernel=kernel, hash=hash) + op6 = tp_fwd_p.bind(X, ddY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + op7 = tp_fwd_p.bind(X, Y, ddW, L3_dim=L3_dim, kernel=kernel, hash=hash) + + grad_X = op1[0] + op2[0] + grad_Y = op1[1] + op2[1] + grad_W = op4[2] + op5[2] + grad_dZ = op3 + op6 + op7 + + return grad_X, grad_Y, grad_W, grad_dZ + +# ============================================================================== +# 12. JVP rule for double backward (implicit) +# ============================================================================== + +def tp_dbwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): + tangents_clean = [] + for t, p in zip(tangents, primals): + if type(t) is ad.Zero: + tangents_clean.append(jnp.zeros_like(p)) + else: + tangents_clean.append(t) + tangents_clean = tuple(tangents_clean) + + def func(x, y, w, dz, ddx, ddy, ddw): + return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) + + return jax.jvp(func, primals, tangents_clean) + +ad.primitive_jvps[tp_dbwd_p] = tp_dbwd_jvp_rule + +# ============================================================================== +# 12. 
Transpose rule for double backward +# ============================================================================== + +def tp_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): + if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + if ad.is_undefined_primal(ddX): ddX = jnp.zeros(ddX.aval.shape, ddX.aval.dtype) + if ad.is_undefined_primal(ddY): ddY = jnp.zeros(ddY.aval.shape, ddY.aval.dtype) + if ad.is_undefined_primal(ddW): ddW = jnp.zeros(ddW.aval.shape, ddW.aval.dtype) + + def func(x, y, w, dz, ddx, ddy, ddw): + return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) + + _, vjp_fun = jax.vjp(func, X, Y, W, dZ, ddX, ddY, ddW) + input_grads = vjp_fun(ct) + + return input_grads + +ad.primitive_transposes[tp_dbwd_p] = tp_dbwd_transpose \ No newline at end of file From 08530445c5ddd60ff6981153f6770d6c581ea5f4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 28 Jan 2026 21:30:50 -0800 Subject: [PATCH 06/18] Convolution working. --- openequivariance/openequivariance/jax/TensorProductConv.py | 2 +- openequivariance/openequivariance/jax/__init__.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index a8487268..1298a23b 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -37,7 +37,7 @@ def __init__( "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernelProp, + "kernel_prop": self.kernel_prop, }) self.hash = self.kernel.__hash__() diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index 6e1a9087..b4e14b24 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,6 +1,4 @@ from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct -#from openequivariance.jax.TensorProductConv import ( -# TensorProductConv as TensorProductConv, -#) +from openequivariance.jax.TensorProductConv import TensorProductConv as TensorProductConv -__all__ = ["TensorProduct"] #"TensorProductConv"] +__all__ = ["TensorProduct", "TensorProductConv"] From 389771a21e6637cfdb6ff8f96bdc7852a2d1cfaa Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 28 Jan 2026 21:43:33 -0800 Subject: [PATCH 07/18] JITted. 
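
Wrap the CPU-facing reference helpers (forward_cpu, backward_cpu,
double_backward_cpu) in jax.jit so the forward and pulled-back VJP calls are
compiled instead of being re-traced eagerly, and switch the backward
abstract-eval rules from core.ShapedArray to jax.core.ShapedArray.

For reference, the pattern being jitted is the usual compile-the-pullback
idiom. The sketch below is illustrative only: f, x, y, and w are toy
stand-ins, not the tensor-product kernels, and it just shows jax.jit applied
to the function returned by jax.vjp.

    import jax
    import jax.numpy as jnp

    def f(x, y, w):
        # Stand-in for a forward map; any differentiable function works here.
        return jnp.sin(x) * y + w

    x = jnp.ones(8)
    y = jnp.full(8, 2.0)
    w = jnp.full(8, 0.5)

    out, pullback = jax.vjp(f, x, y, w)   # pullback maps dZ -> (dX, dY, dW)
    pullback = jax.jit(pullback)          # compile the backward pass once
    dX, dY, dW = pullback(jnp.ones_like(out))
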
--- openequivariance/openequivariance/jax/TensorProduct.py | 6 ++++-- .../openequivariance/jax/TensorProductConv.py | 9 ++++++--- openequivariance/openequivariance/jax/jvp/conv_prim.py | 6 +++--- openequivariance/openequivariance/jax/jvp/tp_prim.py | 6 +++--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index ca8d0999..a9af4d64 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -96,7 +96,7 @@ def double_backward_cpu( in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) - in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + dbwd_func = jax.jit(jax.vjp( lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[ 1 ](o), @@ -104,6 +104,8 @@ def double_backward_cpu( in2_jax, weights_jax, out_grad_jax, - )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + )[1]) + + in1_grad, in2_grad, weights_grad, out_dgrad = dbwd_func((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) return in1_grad, in2_grad, weights_grad, out_dgrad diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 1298a23b..7ff45e1e 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -103,7 +103,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - result = self.forward( + + jit_fwd = jax.jit(self.forward) + result = jit_fwd( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), @@ -131,7 +133,7 @@ def backward_cpu( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.vjp( + backward_fn = jax.jit(jax.vjp( lambda X, Y, W: self.forward( X, Y, @@ -143,7 +145,8 @@ def backward_cpu( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), - )[1] + )[1]) + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) ) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index f56fe8df..a8bdb74d 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -53,9 +53,9 @@ def conv_bwd_impl(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, ha def conv_bwd_abstract_eval(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): irrep_dtype = X.dtype return ( - core.ShapedArray(X.shape, irrep_dtype), - core.ShapedArray(Y.shape, irrep_dtype), - core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), ) conv_bwd_p.def_impl(conv_bwd_impl) diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index 9ec96a61..6ad04dd7 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -50,9 +50,9 @@ def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): irrep_dtype = X.dtype return ( - core.ShapedArray(X.shape, irrep_dtype), - core.ShapedArray(Y.shape, 
irrep_dtype), - core.ShapedArray(W.shape, irrep_dtype), + jax.core.ShapedArray(X.shape, irrep_dtype), + jax.core.ShapedArray(Y.shape, irrep_dtype), + jax.core.ShapedArray(W.shape, irrep_dtype), ) tp_bwd_p.def_impl(tp_bwd_impl) From 85bfb9ae919d804c613d5fce6659fcb9979e1a0b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 28 Jan 2026 21:46:54 -0800 Subject: [PATCH 08/18] Tests JIT also. --- openequivariance/openequivariance/jax/TensorProduct.py | 6 +++--- openequivariance/openequivariance/jax/TensorProductConv.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index a9af4d64..ae8617d4 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -56,7 +56,7 @@ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - result = self.forward( + result = jax.jit(self.forward)( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), @@ -69,12 +69,12 @@ def backward_cpu( weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.vjp( + backward_fn = jax.jit(jax.vjp( lambda X, Y, W: self.forward(X, Y, W), jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), - )[1] + )[1]) L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) ) diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 7ff45e1e..e3814fc9 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -172,7 +172,7 @@ def double_backward_cpu( cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype)) sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) - in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + in1_grad, in2_grad, weights_grad, out_dgrad = jax.jit(jax.vjp( lambda x, y, w, o: jax.vjp( lambda a, b, c: self.forward( a, b, c, rows_jax, cols_jax, sender_perm_jax @@ -185,7 +185,7 @@ def double_backward_cpu( in2_jax, weights_jax, out_grad_jax, - )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + )[1])((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) return ( np.asarray(in1_grad), From ca4959423c153774d9a1b8049b871ffeede3d501 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 28 Jan 2026 23:54:26 -0800 Subject: [PATCH 09/18] Third derivative is failing. 
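
This commit factors the symbolic-zero handling into a clean_tensors helper in
conv_prim.py and drops the is_undefined_primal asserts from the transpose
rules. The third derivative still fails after this change; the next commit
corrects the cotangents (ddX, ddY, ddW) passed to conv_dbwd_p in
conv_bwd_jvp_transpose. The helper relies on the fact that both ad.Zero
tangents and undefined primals expose only an abstract value, which is enough
to materialize a concrete zero buffer for the FFI kernels. A small
illustration (constructing ad.Zero by hand is for demonstration only; inside
the rules it is produced by JAX's autodiff machinery):

    import jax
    import jax.numpy as jnp
    from jax.interpreters import ad

    # A symbolic zero carries no data, only an abstract value.
    sym_zero = ad.Zero(jax.core.ShapedArray((4, 3), jnp.float32))

    # .aval provides the shape and dtype needed to materialize a real buffer.
    concrete = jnp.zeros(sym_zero.aval.shape, sym_zero.aval.dtype)
    assert concrete.shape == (4, 3)
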
--- .../openequivariance/jax/jvp/conv_prim.py | 32 ++++++++----------- .../openequivariance/jax/jvp/tp_prim.py | 15 --------- 2 files changed, 13 insertions(+), 34 deletions(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index a8bdb74d..983a8f23 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -12,6 +12,15 @@ def ensure_array(tan, primal): return jnp.zeros_like(primal) return tan +def clean_tensors(*tensors): + tensors_clean = [] + for t in tensors: + result = t + if type(t) is ad.Zero or ad.is_undefined_primal(t): + result = jnp.zeros(t.aval.shape, t.aval.dtype) + tensors_clean.append(result) + return tensors_clean + # ============================================================================== # 1. Forward Primitive # ============================================================================== @@ -124,10 +133,6 @@ def conv_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, rows, cols, workspace, sende # ============================================================================== def conv_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): - assert ad.is_undefined_primal(dX) - assert ad.is_undefined_primal(dY) - assert ad.is_undefined_primal(dW) - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) @@ -228,19 +233,15 @@ def conv_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): ddX, ddY, ddW = ct - assert ad.is_undefined_primal(tX) - assert ad.is_undefined_primal(tY) - assert ad.is_undefined_primal(tW) - assert ad.is_undefined_primal(tdZ) - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + tensors_clean = clean_tensors(X, Y, W, dZ, tX, tY, tW) + g_X, g_Y, g_W, g_dZ = conv_dbwd_p.bind( - X, Y, W, dZ, ddX, ddY, ddW, - rows, cols, workspace, sender_perm, + *tensors_clean, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash ) @@ -332,20 +333,13 @@ def conv_dbwd_jvp_rule(primals, tangents, *, kernel, hash): dZ = primals[3] L3_dim = dZ.shape[1] - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) - def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): return conv_dbwd_slow( x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp, L3_dim=L3_dim, kernel=kernel, hash=hash ) + tangents_clean = tuple(clean_tensors(*tangents)) return jax.jvp(func, primals, tangents_clean) ad.primitive_jvps[conv_dbwd_p] = conv_dbwd_jvp_rule diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index 6ad04dd7..ac914b54 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -117,14 +117,6 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): # 
============================================================================== def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - # This transpose corresponds to the Backward pass. - # We assert that we are differentiating with respect to the input tangents. - assert ad.is_undefined_primal(dX) - assert ad.is_undefined_primal(dY) - assert ad.is_undefined_primal(dW) - - # If the primals X, Y, W are being differentiated (undefined), we replace - # them with zeros for the purpose of this kernel call. if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) if ad.is_undefined_primal(Y): @@ -134,8 +126,6 @@ def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) - # Return gradients for (X, Y, W, dX, dY, dW). - # Primals get None, tangents get the computed gradients. return (None, None, None, grad_X, grad_Y, grad_W) ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose @@ -224,11 +214,6 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): ddX, ddY, ddW = ct - assert ad.is_undefined_primal(tX) - assert ad.is_undefined_primal(tY) - assert ad.is_undefined_primal(tW) - assert ad.is_undefined_primal(tdZ) - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) From 43751ed3fa8dae2270f00cd51353a98a502f7d76 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 29 Jan 2026 00:05:30 -0800 Subject: [PATCH 10/18] Fixed failing third derivative. --- openequivariance/openequivariance/jax/jvp/conv_prim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 983a8f23..3686c3eb 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -238,7 +238,7 @@ def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) - tensors_clean = clean_tensors(X, Y, W, dZ, tX, tY, tW) + tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) g_X, g_Y, g_W, g_dZ = conv_dbwd_p.bind( *tensors_clean, rows, cols, workspace, sender_perm, From 7f627d8e92be67ab817b5aae7e64980d846b05aa Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 29 Jan 2026 00:24:00 -0800 Subject: [PATCH 11/18] Fixed up some more stuff. 
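
This commit replaces the remaining inline ad.Zero loops in the JVP-of-JVP
rules with the shared clean_tensors helper. The rules keep the same shape:
materialize the tangents, then call jax.jvp on a reference implementation that
is itself written in terms of the registered primitives, so JAX can recurse
through their rules for arbitrarily high derivative orders. An illustrative
miniature of that pattern (fwd and fwd_jvp are stand-ins, not the library's
functions):

    import jax
    import jax.numpy as jnp

    def fwd(x, y):
        # Stand-in for a forward primitive, bilinear in its inputs.
        return x * y

    def fwd_jvp(x, y, dx, dy):
        # JVP expressed via the same forward call, as conv_fwd_jvp_impl does.
        return fwd(dx, y) + fwd(x, dy)

    primals = (jnp.ones(3), 2.0 * jnp.ones(3), 0.5 * jnp.ones(3), jnp.zeros(3))
    tangents = tuple(jnp.ones_like(p) for p in primals)

    # Differentiating the JVP again only needs materialized tangents;
    # jax.jvp then recurses through fwd's own differentiation rules.
    out, dout = jax.jvp(fwd_jvp, primals, tangents)
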
--- .../openequivariance/jax/jvp/conv_prim.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 3686c3eb..aa1b041d 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -173,13 +173,7 @@ def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): # ============================================================================== def conv_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) + tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dx, dy, dw, r, c, ws, sp): return conv_fwd_jvp_impl( @@ -255,13 +249,7 @@ def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa # ============================================================================== def conv_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) + tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp): return conv_bwd_jvp_impl( From 65b5e616edfdf9f42fd8c873517a77b3c1b82320 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 29 Jan 2026 20:24:14 -0800 Subject: [PATCH 12/18] Added VJP version for debugging. --- .../openequivariance/jax/TensorProductConv.py | 12 +- .../openequivariance/jax/vjp/conv_func.py | 162 ++++++++++++++++++ 2 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 openequivariance/openequivariance/jax/vjp/conv_func.py diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index e3814fc9..5074fe68 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -11,7 +11,8 @@ from openequivariance.jax.utils import reorder_jax from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.jax.jvp.conv_prim import conv_fwd_p +from openequivariance.jax.jvp import conv_prim +from openequivariance.jax.vjp import conv_func logger = getLogger() @@ -19,9 +20,10 @@ class TensorProductConv(LoopUnrollConv): def __init__( - self, config: TPProblem, deterministic: bool = False, kahan: bool = False + self, config: TPProblem, deterministic: bool = False, kahan: bool = False, requires_jvp: bool = True ): dp = extlib.DeviceProp(0) + self.requires_jvp = requires_jvp super().__init__( config, dp, @@ -66,7 +68,11 @@ def forward( "Must provide sender_perm for deterministic convolutions." 
) - return conv_fwd_p.bind( + func = conv_func.forward + if self.requires_jvp: + func = conv_prim.conv_fwd_p.bind + + return func( X, Y, W, diff --git a/openequivariance/openequivariance/jax/vjp/conv_func.py b/openequivariance/openequivariance/jax/vjp/conv_func.py new file mode 100644 index 00000000..5d19df9d --- /dev/null +++ b/openequivariance/openequivariance/jax/vjp/conv_func.py @@ -0,0 +1,162 @@ +import jax +import jax.numpy as jnp +from jax.extend import core +from functools import partial +from jax.interpreters import mlir, ad + +def zeros_like(x): + return jnp.zeros_like(x) + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash): + forward_call = jax.ffi.ffi_call( + "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), X.dtype) + ) + return forward_call(X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + + +def forward_fwd( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash +): + out = forward( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash + ) + return out, (X, Y, W, rows, cols) + + +def forward_bwd(workspace, sender_perm, L3_dim, kernel, hash, res, dZ): + X, Y, W, rows, cols = res + dX, dY, dW = backward( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + return dX, dY, dW, None, None + + +forward.defvjp(forward_fwd, forward_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9)) +def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): + backward_call = jax.ffi.ffi_call( + "conv_backward", + ( + jax.ShapeDtypeStruct(X.shape, X.dtype), + jax.ShapeDtypeStruct(Y.shape, Y.dtype), + jax.ShapeDtypeStruct(W.shape, W.dtype), + ), + ) + return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + + +def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): + out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash) + return out, (X, Y, W, dZ, rows, cols) + + +def backward_bwd(workspace, sender_perm, kernel, hash, res, derivatives): + X, Y, W, dZ, rows, cols = res + ddX, ddY, ddW = derivatives + + gX, gY, gW, gdZ = double_backward( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel, + hash, + ) + + return gX, gY, gW, gdZ, None, None + + +backward.defvjp(backward_fwd, backward_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) +def double_backward( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash +): + double_backward_call = jax.ffi.ffi_call( + "conv_double_backward", + ( + jax.ShapeDtypeStruct(X.shape, X.dtype), + jax.ShapeDtypeStruct(Y.shape, Y.dtype), + jax.ShapeDtypeStruct(W.shape, W.dtype), + jax.ShapeDtypeStruct(dZ.shape, dZ.dtype), + ), + ) + return double_backward_call( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + + +def double_backward_fwd( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash +): + out = double_backward( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel, + hash + ) + return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) + + +def triple_backward( + workspace, + sender_perm, + kernel, + hash, + residuals, + tangent_outputs, +): + X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals + t_dX, t_dY, t_dW, t_ddZ = tangent_outputs + + common_args = (rows, cols, workspace, 
sender_perm, kernel, hash) + + op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W)) + g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args) + + op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW)) + g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args) + + op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW) + g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args) + + op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW) + g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args) + + g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args) + g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args) + g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args) + + grad_X = g2_X + g4_X + g6_X + g7_X + grad_Y = g2_Y + g3_Y + g5_Y + g7_Y + grad_W = g1_W + g3_W + g4_W + g5_W + g6_W + grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ + + grad_ddX = g1_ddX + g3_ddX + g5_ddX + grad_ddY = g1_ddY + g4_ddY + g6_ddY + grad_ddW = g2_ddW + g7_ddW + + return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None + + +double_backward.defvjp(double_backward_fwd, triple_backward) \ No newline at end of file From d2e749aeff2aa06963ce8cca456515e574ed881f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 Jan 2026 21:11:13 -0800 Subject: [PATCH 13/18] Fixed zero buffers. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 142 ++++++++++-------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index cb975329..60aa50d0 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -8,11 +8,9 @@ #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" -#include "json11/json11.hpp" namespace nb = nanobind; namespace ffi = xla::ffi; -using json = json11::Json; #ifdef CUDA_BACKEND #include @@ -133,14 +131,6 @@ void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { } #endif -std::unordered_map parse_json_config(const json &j_obj) { - std::unordered_map result; - for (const auto &kv : j_obj.object_items()) { - result[kv.first] = static_cast(kv.second.number_value()); - } - return result; -} - struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -185,8 +175,39 @@ std::unordered_map> conv_cache; std::mutex mut; +std::vector launch_config_keys = { + "num_blocks", + "num_threads", + "smem"}; +std::vector kernel_prop_keys = { + "L1_dim", + "L2_dim", + "L3_dim", + "weight_numel", + "shared_weights", + "opt_level", + "irrep_dtype", + "weight_dtype", + + // Convolution only + "workspace_size", + "deterministic", + "idx_dtype"}; + +std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { + std::unordered_map result; + for (const auto &key : keys) { + result[key] = dict.get(key).value(); + } + return result; +} + std::pair*, KernelProp> - compile_tp_with_caching(std::string_view json_payload, + compile_tp_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, int64_t hash, bool is_convolution) { @@ -194,21 +215,12 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = tp_cache.find(hash); if (it == tp_cache.end()) { - std::string err; - json root = json::parse(std::string(json_payload), err); - if (!err.empty()) throw 
std::runtime_error("JSON Parse Error: " + err); - - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); - + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); auto jit_tp_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, + std::string(kernel), + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + parse_ffi_dict(double_backward_config, launch_config_keys), kernel_prop_map); tp_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), @@ -220,7 +232,11 @@ std::pair*, KernelProp> } std::pair*, KernelProp> - compile_conv_with_caching(std::string_view json_payload, + compile_conv_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, int64_t hash, bool is_convolution) { @@ -228,21 +244,12 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = conv_cache.find(hash); if (it == conv_cache.end()) { - std::string err; - json root = json::parse(std::string(json_payload), err); - if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); - - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); - + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); auto jit_conv_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, + std::string(kernel), + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + parse_ffi_dict(double_backward_config, launch_config_keys), kernel_prop_map); conv_cache.insert({hash, std::make_pair(std::move(jit_conv_impl), @@ -293,13 +300,12 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer L2_in, ffi::AnyBuffer W, ffi::Result L3_out, - stream_t stream, - int64_t L3_dim, - std::string_view kernel_json, + stream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel_json, hash, false); + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -330,11 +336,11 @@ ffi::Error tp_backward_impl( ffi::Result L2_grad, ffi::Result W_grad, stream_t stream, - std::string_view kernel_json, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel_json, hash, false); + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, 
{num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -380,11 +386,11 @@ ffi::Error tp_double_backward_impl( ffi::Result W_grad, ffi::Result L3_dgrad, stream_t stream, - std::string_view kernel_json, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel_json, hash, false); + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -429,8 +435,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("L3_dim") - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled @@ -445,7 +450,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -464,7 +469,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -479,11 +484,11 @@ ffi::Error conv_forward_impl( ffi::AnyBuffer transpose_perm, ffi::Result L3_out, stream_t stream, - std::string_view kernel_json, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel_json, hash, true); + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -534,11 +539,11 @@ ffi::Error conv_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel_json, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel_json, hash, true); + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -557,6 +562,8 @@ ffi::Error conv_backward_impl( workspace_ptr = nullptr; } zero_buffer(*L1_grad, stream); + zero_buffer(*L2_grad, stream); + zero_buffer(*W_grad, stream); if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -603,11 +610,11 @@ ffi::Error conv_double_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel_json, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel_json, hash, true); + kernel, forward_config, 
backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -628,8 +635,9 @@ ffi::Error conv_double_backward_impl( workspace_ptr = nullptr; } zero_buffer(*L1_grad, stream); + zero_buffer(*L2_grad, stream); zero_buffer(*L3_dgrad, stream); - + zero_buffer(*W_grad, stream); if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -638,8 +646,6 @@ ffi::Error conv_double_backward_impl( check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); } - if(k.shared_weights) - zero_buffer(*W_grad, stream); jit_kernel->double_backward( data_ptr(L1_in), @@ -684,7 +690,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -703,7 +709,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -726,7 +732,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -759,4 +765,10 @@ NB_MODULE(openequivariance_extjax, m) { .def("start", &GPUTimer::start) .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) .def("clear_L2_cache", &GPUTimer::clear_L2_cache); -} \ No newline at end of file + + /*nb::class_>(m, "DeviceBuffer") + .def(nb::init()) + .def(nb::init()) + .def("copy_to_host", &PyDeviceBuffer::copy_to_host) + .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ +} From 950de7ac601616a7a79fc73a824474bdc20c0b00 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 Jan 2026 21:20:55 -0800 Subject: [PATCH 14/18] Zero'd some more buffers. 
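
This commit backs out the ffi::Dictionary attribute plumbing introduced in the
previous commit and restores the single JSON payload parsed with json11 on the
C++ side, while keeping the unconditional zeroing of the gradient output
buffers before the convolution kernels run. A hedged sketch of the payload
shape the extension expects; the key names match what the caching compile
helpers read, but the values and the exact construction inside the Python
wrapper classes are illustrative assumptions:

    import json

    payload = json.dumps({
        "kernel": "/* generated GPU kernel source */",
        "forward_config": {"num_blocks": 128, "num_threads": 256, "smem": 0},
        "backward_config": {"num_blocks": 128, "num_threads": 256, "smem": 0},
        "double_backward_config": {"num_blocks": 128, "num_threads": 256, "smem": 0},
        "kernel_prop": {"L1_dim": 32, "L2_dim": 16, "L3_dim": 32, "weight_numel": 64},
    })

    # The int64 "hash" attribute serves as the cache key on the C++ side.
    payload_hash = hash(payload)
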
--- openequivariance_extjax/src/libjax_tp_jit.cpp | 139 ++++++++---------- 1 file changed, 63 insertions(+), 76 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 60aa50d0..7c8bf999 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -8,9 +8,11 @@ #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" +#include "json11/json11.hpp" namespace nb = nanobind; namespace ffi = xla::ffi; +using json = json11::Json; #ifdef CUDA_BACKEND #include @@ -131,6 +133,14 @@ void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { } #endif +std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -175,39 +185,8 @@ std::unordered_map> conv_cache; std::mutex mut; -std::vector launch_config_keys = { - "num_blocks", - "num_threads", - "smem"}; -std::vector kernel_prop_keys = { - "L1_dim", - "L2_dim", - "L3_dim", - "weight_numel", - "shared_weights", - "opt_level", - "irrep_dtype", - "weight_dtype", - - // Convolution only - "workspace_size", - "deterministic", - "idx_dtype"}; - -std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { - std::unordered_map result; - for (const auto &key : keys) { - result[key] = dict.get(key).value(); - } - return result; -} - std::pair*, KernelProp> - compile_tp_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_tp_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -215,12 +194,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = tp_cache.find(hash); if (it == tp_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; + json root = json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_tp_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); tp_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), @@ -232,11 +220,7 @@ std::pair*, KernelProp> } std::pair*, KernelProp> - compile_conv_with_caching(std::string_view kernel, - ffi::Dictionary forward_config, - ffi::Dictionary backward_config, - ffi::Dictionary double_backward_config, - ffi::Dictionary kernel_prop, + compile_conv_with_caching(std::string_view json_payload, int64_t hash, bool is_convolution) { @@ -244,12 +228,21 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = conv_cache.find(hash); if (it == conv_cache.end()) { - auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + std::string err; 
+ json root = json::parse(std::string(json_payload), err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + auto jit_conv_impl = std::make_unique>( - std::string(kernel), - parse_ffi_dict(forward_config, launch_config_keys), - parse_ffi_dict(backward_config, launch_config_keys), - parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, kernel_prop_map); conv_cache.insert({hash, std::make_pair(std::move(jit_conv_impl), @@ -300,12 +293,13 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer L2_in, ffi::AnyBuffer W, ffi::Result L3_out, - stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + stream_t stream, + int64_t L3_dim, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -336,11 +330,11 @@ ffi::Error tp_backward_impl( ffi::Result L2_grad, ffi::Result W_grad, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -386,11 +380,11 @@ ffi::Error tp_double_backward_impl( ffi::Result W_grad, ffi::Result L3_dgrad, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_tp_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + kernel_json, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -435,7 +429,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("L3_dim") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled @@ -450,7 +445,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -469,7 +464,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ctx>() - 
.Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -484,11 +479,11 @@ ffi::Error conv_forward_impl( ffi::AnyBuffer transpose_perm, ffi::Result L3_out, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -539,11 +534,11 @@ ffi::Error conv_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -573,8 +568,6 @@ ffi::Error conv_backward_impl( check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad"); } - if(k.shared_weights) - zero_buffer(*W_grad, stream); jit_kernel->backward( data_ptr(L1_in), @@ -610,11 +603,11 @@ ffi::Error conv_double_backward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, stream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + std::string_view kernel_json, int64_t hash) { auto [jit_kernel, k] = compile_conv_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + kernel_json, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; void* workspace_ptr = data_ptr(workspace); @@ -636,8 +629,8 @@ ffi::Error conv_double_backward_impl( } zero_buffer(*L1_grad, stream); zero_buffer(*L2_grad, stream); - zero_buffer(*L3_dgrad, stream); zero_buffer(*W_grad, stream); + zero_buffer(*L3_dgrad, stream); if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -690,7 +683,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Ret() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -709,7 +702,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -732,7 +725,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("kernel") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -765,10 +758,4 @@ NB_MODULE(openequivariance_extjax, m) { 
.def("start", &GPUTimer::start) .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) .def("clear_L2_cache", &GPUTimer::clear_L2_cache); - - /*nb::class_>(m, "DeviceBuffer") - .def(nb::init()) - .def(nb::init()) - .def("copy_to_host", &PyDeviceBuffer::copy_to_host) - .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ -} +} \ No newline at end of file From 7ec251f259b6bde3523da4d5ed9f21948931227f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 Jan 2026 23:45:41 -0800 Subject: [PATCH 15/18] More things are working. --- .../openequivariance/jax/TensorProductConv.py | 7 +- .../openequivariance/jax/jvp/conv_prim.py | 44 ++----- .../openequivariance/jax/jvp/tp_prim.py | 119 ++++++++---------- .../openequivariance/jax/utils.py | 77 ++---------- openequivariance_extjax/pyproject.toml | 2 +- openequivariance_extjax/src/libjax_tp_jit.cpp | 2 - 6 files changed, 72 insertions(+), 179 deletions(-) diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 5074fe68..5d9f2586 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -68,9 +68,10 @@ def forward( "Must provide sender_perm for deterministic convolutions." ) - func = conv_func.forward - if self.requires_jvp: - func = conv_prim.conv_fwd_p.bind + func = conv_prim.conv_fwd_p.bind + + if not self.requires_jvp: + func = conv_func.forward return func( X, diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index aa1b041d..16c0570b 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -2,24 +2,7 @@ import jax.numpy as jnp from jax.extend import core from jax.interpreters import mlir, ad - -# ============================================================================== -# 0. Helpers -# ============================================================================== - -def ensure_array(tan, primal): - if type(tan) is ad.Zero: - return jnp.zeros_like(primal) - return tan - -def clean_tensors(*tensors): - tensors_clean = [] - for t in tensors: - result = t - if type(t) is ad.Zero or ad.is_undefined_primal(t): - result = jnp.zeros(t.aval.shape, t.aval.dtype) - tensors_clean.append(result) - return tensors_clean +from openequivariance.jax.utils import clean_tensors # ============================================================================== # 1. 
Forward Primitive @@ -133,9 +116,7 @@ def conv_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, rows, cols, workspace, sende # ============================================================================== def conv_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) + X, Y, W = clean_tensors(X, Y, W) grad_X, grad_Y, grad_W = conv_bwd_p.bind( X, Y, W, ct, @@ -156,10 +137,7 @@ def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): X, Y, W, rows, cols, workspace, sender_perm = primals dX, dY, dW, drows, dcols, dworkspace, dsender_perm = tangents - dX = ensure_array(dX, X) - dY = ensure_array(dY, Y) - dW = ensure_array(dW, W) - + dX, dY, dW = clean_tensors(dX, dY, dW) out_primal = conv_fwd_p.bind(X, Y, W, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) out_tangent = conv_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) @@ -270,8 +248,8 @@ def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): X, Y, W, dZ, rows, cols, workspace, sender_perm = primals tX, tY, tW, tdZ, drows, dcols, dworkspace, dsender_perm = tangents - tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip((tX, tY, tW, tdZ), (X, Y, W, dZ))] - + tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) + out_primal = conv_bwd_p.bind( X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash @@ -317,8 +295,7 @@ def conv_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_per # ============================================================================== def conv_dbwd_jvp_rule(primals, tangents, *, kernel, hash): - # Infer L3_dim from dZ (4th input) - dZ = primals[3] + dZ = primals[3] # Infer L3_dim from dZ (4th input) L3_dim = dZ.shape[1] def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): @@ -338,16 +315,9 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): # ============================================================================== def conv_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): - # Infer L3_dim from dZ L3_dim = dZ.shape[1] - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) - if ad.is_undefined_primal(ddX): ddX = jnp.zeros(ddX.aval.shape, ddX.aval.dtype) - if ad.is_undefined_primal(ddY): ddY = jnp.zeros(ddY.aval.shape, ddY.aval.dtype) - if ad.is_undefined_primal(ddW): ddW = jnp.zeros(ddW.aval.shape, ddW.aval.dtype) + X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): return conv_dbwd_slow( diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index ac914b54..764fcb87 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -1,11 +1,8 @@ -import json import jax import jax.numpy as jnp from jax.extend import core from jax.interpreters import mlir, ad - -# Implements the ladder of derivatives for tensor product -# via 
primitives, JVP, and transpose instead of custom_vjp. +from openequivariance.jax.utils import clean_tensors # ============================================================================== # 1. Forward Primitive @@ -17,7 +14,7 @@ def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): irrep_dtype = X.dtype out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) call = jax.ffi.ffi_call("tp_forward", out_shape) - return call(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) + return call(X, Y, W, kernel=kernel, hash=hash) def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) @@ -42,10 +39,8 @@ def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): jax.ShapeDtypeStruct(Y.shape, irrep_dtype), jax.ShapeDtypeStruct(W.shape, irrep_dtype), ) - call = jax.ffi.ffi_call("tp_backward", out_shapes) - result = call(X, Y, W, dZ, kernel=kernel, hash=hash) - return result + return call(X, Y, W, dZ, kernel=kernel, hash=hash) def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): irrep_dtype = X.dtype @@ -93,6 +88,7 @@ def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda") mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm") + # ============================================================================== # 4. Forward JVP Primitive Definition # ============================================================================== @@ -100,9 +96,11 @@ def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - term1 = tp_fwd_p.bind(dX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - term2 = tp_fwd_p.bind(X, dY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - term3 = tp_fwd_p.bind(X, Y, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) + + term1 = tp_fwd_p.bind(dX, Y, W, **kwargs) + term2 = tp_fwd_p.bind(X, dY, W, **kwargs) + term3 = tp_fwd_p.bind(X, Y, dW, **kwargs) return term1 + term2 + term3 def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): @@ -117,12 +115,7 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): # ============================================================================== def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): - if ad.is_undefined_primal(X): - X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): - Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): - W = jnp.zeros(W.aval.shape, W.aval.dtype) + X, Y, W = clean_tensors(X, Y, W) grad_X, grad_Y, grad_W = tp_bwd_p.bind(X, Y, W, ct, kernel=kernel, hash=hash) @@ -130,10 +123,6 @@ def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose -def ensure_array(tan, primal): - if type(tan) is ad.Zero: - return jnp.zeros_like(primal) - return tan # ============================================================================== # 6. 
JVP Rule for Original Forward Primitive @@ -143,9 +132,7 @@ def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): X, Y, W = primals dX, dY, dW = tangents - dX = ensure_array(dX, X) - dY = ensure_array(dY, Y) - dW = ensure_array(dW, W) + dX, dY, dW = clean_tensors(dX, dY, dW) out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) out_tangent = tp_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) @@ -160,13 +147,7 @@ def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): # ============================================================================== def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) + tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dx, dy, dw): return tp_fwd_jvp_impl(x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash) @@ -184,10 +165,12 @@ def func(x, y, w, dx, dy, dw): tp_bwd_jvp_p.multiple_results = True def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): - term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, kernel=kernel, hash=hash) - term_X = tp_bwd_p.bind(tX, Y, W, dZ, kernel=kernel, hash=hash) - term_Y = tp_bwd_p.bind(X, tY, W, dZ, kernel=kernel, hash=hash) - term_W = tp_bwd_p.bind(X, Y, tW, dZ, kernel=kernel, hash=hash) + kwargs = dict(kernel=kernel, hash=hash) + + term_dZ = tp_bwd_p.bind(X, Y, W, tdZ, **kwargs) + term_X = tp_bwd_p.bind(tX, Y, W, dZ, **kwargs) + term_Y = tp_bwd_p.bind(X, tY, W, dZ, **kwargs) + term_W = tp_bwd_p.bind(X, Y, tW, dZ, **kwargs) out_dX = term_dZ[0] + term_Y[0] + term_W[0] out_dY = term_dZ[1] + term_X[1] + term_W[1] @@ -219,7 +202,11 @@ def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) - g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) + + g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind( + *tensors_clean, kernel=kernel, hash=hash + ) return (None, None, None, None, g_X, g_Y, g_W, g_dZ) @@ -231,13 +218,7 @@ def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): # ============================================================================== def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) + tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dz, tx, ty, tw, tdz): return tp_bwd_jvp_impl(x, y, w, dz, tx, ty, tw, tdz, kernel=kernel, hash=hash) @@ -255,7 +236,8 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): X, Y, W, dZ = primals tX, tY, tW, tdZ = tangents - tX, tY, tW, tdZ = [ensure_array(t, p) for t, p in zip(tangents, primals)] + tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) + out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) @@ -263,18 +245,23 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule + # 
============================================================================== # 12. Slow Double Backward Implementation (Reference) # ============================================================================== def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): - op1 = tp_bwd_p.bind(ddX, ddY, W, dZ, kernel=kernel, hash=hash) - op2 = tp_bwd_p.bind(X, Y, ddW, dZ, kernel=kernel, hash=hash) - op3 = tp_fwd_p.bind(ddX, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - op4 = tp_bwd_p.bind(ddX, Y, W, dZ, kernel=kernel, hash=hash) - op5 = tp_bwd_p.bind(X, ddY, W, dZ, kernel=kernel, hash=hash) - op6 = tp_fwd_p.bind(X, ddY, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - op7 = tp_fwd_p.bind(X, Y, ddW, L3_dim=L3_dim, kernel=kernel, hash=hash) + kwargs = dict(kernel=kernel, hash=hash) + + op1 = tp_bwd_p.bind(ddX, ddY, W, dZ, **kwargs) + op2 = tp_bwd_p.bind(X, Y, ddW, dZ, **kwargs) + + op3 = tp_fwd_p.bind(ddX, Y, W, L3_dim=L3_dim, **kwargs) + op4 = tp_bwd_p.bind(ddX, Y, W, dZ, **kwargs) + op5 = tp_bwd_p.bind(X, ddY, W, dZ, **kwargs) + + op6 = tp_fwd_p.bind(X, ddY, W, L3_dim=L3_dim, **kwargs) + op7 = tp_fwd_p.bind(X, Y, ddW, L3_dim=L3_dim, **kwargs) grad_X = op1[0] + op2[0] grad_Y = op1[1] + op2[1] @@ -283,38 +270,32 @@ def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): return grad_X, grad_Y, grad_W, grad_dZ + # ============================================================================== -# 12. JVP rule for double backward (implicit) +# 13. JVP rule for double backward (implicit) # ============================================================================== -def tp_dbwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): - tangents_clean = [] - for t, p in zip(tangents, primals): - if type(t) is ad.Zero: - tangents_clean.append(jnp.zeros_like(p)) - else: - tangents_clean.append(t) - tangents_clean = tuple(tangents_clean) +def tp_dbwd_jvp_rule(primals, tangents, *, kernel, hash): + dZ = primals[3] # Infer L3_dim from dZ (4th input) + L3_dim = dZ.shape[1] def func(x, y, w, dz, ddx, ddy, ddw): return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) + tangents_clean = tuple(clean_tensors(*tangents)) return jax.jvp(func, primals, tangents_clean) ad.primitive_jvps[tp_dbwd_p] = tp_dbwd_jvp_rule + # ============================================================================== -# 12. Transpose rule for double backward +# 14. 
Transpose rule for double backward # ============================================================================== -def tp_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) - if ad.is_undefined_primal(ddX): ddX = jnp.zeros(ddX.aval.shape, ddX.aval.dtype) - if ad.is_undefined_primal(ddY): ddY = jnp.zeros(ddY.aval.shape, ddY.aval.dtype) - if ad.is_undefined_primal(ddW): ddW = jnp.zeros(ddW.aval.shape, ddW.aval.dtype) +def tp_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): + L3_dim = dZ.shape[1] + + X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) def func(x, y, w, dz, ddx, ddy, ddw): return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py index 00bcee7a..02b1b7a9 100644 --- a/openequivariance/openequivariance/jax/utils.py +++ b/openequivariance/openequivariance/jax/utils.py @@ -1,8 +1,8 @@ import jax import jax.numpy as jnp import numpy as np -import functools -import traceback +from jax.interpreters import ad + def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): assert direction in ["forward", "backward"] @@ -64,68 +64,11 @@ def reorder_jax(schedule, weights_in, direction, has_batch_dim): return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) -_indentation = 0 -def _trace(msg=None): - """Print a message at current indentation.""" - if msg is not None: - print(" " * _indentation + msg) - -def _trace_indent(msg=None): - """Print a message and then indent the rest.""" - global _indentation - _trace(msg) - _indentation = 1 + _indentation - -def _trace_unindent(msg=None): - """Unindent then print a message.""" - global _indentation - _indentation = _indentation - 1 - _trace(msg) - -def trace(name): - """A decorator for functions to trace arguments and results.""" - - def trace_func(func): # pylint: disable=missing-docstring - def pp(v): - """Print certain values more succinctly""" - vtype = str(type(v)) - if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: - return "" - elif "jaxlib._jax_.XlaOp" in vtype: - return "".format(id(v)) - elif ("partial_eval.JaxprTracer" in vtype or - "batching.BatchTracer" in vtype or - "ad.JVPTracer" in vtype): - return "Traced<{}>".format(v.aval) - elif isinstance(v, tuple): - return "({})".format(pp_values(v)) - else: - return str(v) - def pp_values(args): - return ", ".join([pp(arg) for arg in args]) - - @functools.wraps(func) - def func_wrapper(*args): - _trace_indent("call {}({})".format(name, pp_values(args))) - res = func(*args) - _trace_unindent("|<- {} = {}".format(name, pp(res))) - return res - - return func_wrapper - - return trace_func - -class expectNotImplementedError(object): - """Context manager to check for NotImplementedError.""" - def __enter__(self): pass - def __exit__(self, type, value, tb): - global _indentation - _indentation = 0 - if type is NotImplementedError: - print("\nFound expected exception:") - traceback.print_exc(limit=3) - return True - elif type is None: # No exception - assert False, "Expected NotImplementedError" - else: - return False \ No newline at end of file +def 
clean_tensors(*tensors):
+    tensors_clean = []
+    for t in tensors: 
+        result = t
+        if type(t) is ad.Zero or ad.is_undefined_primal(t):
+            result = jnp.zeros(t.aval.shape, t.aval.dtype)
+        tensors_clean.append(result)
+    return tensors_clean
diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml
index c9b2856f..def67e12 100644
--- a/openequivariance_extjax/pyproject.toml
+++ b/openequivariance_extjax/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "openequivariance_extjax"
-version = "0.1.0"
+version = "0.2.0"
 authors = [
     { name="Austin Glover" },
     { name="Vivek Bharadwaj" },
diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp
index 7c8bf999..bf3fd69e 100644
--- a/openequivariance_extjax/src/libjax_tp_jit.cpp
+++ b/openequivariance_extjax/src/libjax_tp_jit.cpp
@@ -294,7 +294,6 @@ ffi::Error tp_forward_impl(
     ffi::AnyBuffer W,
     ffi::Result L3_out,
     stream_t stream,
-    int64_t L3_dim,
     std::string_view kernel_json,
     int64_t hash) {
 
@@ -429,7 +428,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     .Arg()
     .Ret()
     .Ctx>()
-    .Attr("L3_dim")
     .Attr("kernel")
     .Attr("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled

From 067d2d6100b4400ad1617f28f400cb76a0bb523d Mon Sep 17 00:00:00 2001
From: Vivek Bharadwaj
Date: Fri, 30 Jan 2026 23:57:07 -0800
Subject: [PATCH 16/18] Fixed things up.

---
 CHANGELOG.md                                  | 17 +++++++++++++++++
 .../openequivariance/jax/TensorProductConv.py | 11 +++++++++++
 2 files changed, 28 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9bd29e1..410608be 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,22 @@
 ## Latest Changes
 
+### v0.5.4 (2025-02-01)
+Minor update, improvements to JAX.
+
+**Added**:
+- Jacobian Vector Products (JVP)
+  for both `TensorProduct` and `TensorProductConv` via custom primitives, in addition to VJP.
+- Arbitrary higher-order derivatives in JAX.
+- JAX JIT support; in particular, support for
+  Phonon Fine Tuning in [Nequix](https://github.com/atomicarchitects/nequix).
+
+**Fixed**:
+- Zeroed all output buffers in the backward and double-backward implementations of convolution
+  before calling kernels.
+
+### v0.5.1-0.5.3 (2025-02-01)
+Minor bugfixes related to packaging and JAX.
+
 ### v0.5.0 (2025-12-25)
 
 JAX support is now available in OpenEquivariance for BOTH NVIDIA and
diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py
index 5d9f2586..464a326d 100644
--- a/openequivariance/openequivariance/jax/TensorProductConv.py
+++ b/openequivariance/openequivariance/jax/TensorProductConv.py
@@ -19,6 +19,17 @@
 
 
 class TensorProductConv(LoopUnrollConv):
+    r"""
+    Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one
+    key difference: integer arrays passed to this function must have dtype
+    ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
+
+    :param problem: Specification of the tensor product.
+    :param deterministic: If ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
+        fixup-based algorithm. *Default*: ``False``.
+    :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
+        the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
+ """ def __init__( self, config: TPProblem, deterministic: bool = False, kahan: bool = False, requires_jvp: bool = True ): From 6cffa9db32c1ba659d2cee93ee051c495d63de7b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 Jan 2026 00:32:01 -0800 Subject: [PATCH 17/18] Precommit. --- .../openequivariance/core/utils.py | 1 - .../openequivariance/jax/TensorProduct.py | 69 +++-- .../openequivariance/jax/TensorProductConv.py | 88 +++--- .../openequivariance/jax/__init__.py | 4 +- .../openequivariance/jax/jvp/conv_prim.py | 286 ++++++++++++++---- .../openequivariance/jax/jvp/tp_prim.py | 118 ++++++-- .../openequivariance/jax/utils.py | 2 +- .../openequivariance/jax/vjp/conv_func.py | 44 +-- tests/conftest.py | 1 + tests/example_test.py | 2 +- 10 files changed, 435 insertions(+), 180 deletions(-) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 15a5ca25..50f35bd4 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -7,7 +7,6 @@ import json import tempfile -import hashlib from enum import IntEnum diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index ae8617d4..bf7a1445 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -1,14 +1,13 @@ import jax -import jax.numpy as jnp import numpy as np -from functools import partial from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.jax.utils import reorder_jax +from openequivariance.jax.utils import reorder_jax from openequivariance.jax.jvp.tp_prim import tp_fwd_p import json + class TensorProduct(LoopUnrollTP): r""" Identical to ``oeq.torch.TensorProduct`` with functionality in JAX. 
@@ -20,14 +19,18 @@ def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) - self.kernel = json.dumps({ - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernelProp, - }) - self.hash = self.kernel.__hash__() + self.kernel = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernelProp, + } + ) + self.hash = self.kernel.__hash__() self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -35,7 +38,9 @@ def __init__(self, problem: TPProblem): def forward( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray ) -> jax.numpy.ndarray: - return tp_fwd_p.bind(X, Y, W, L3_dim=self.L3_dim, kernel=self.kernel, hash=self.hash) + return tp_fwd_p.bind( + X, Y, W, L3_dim=self.L3_dim, kernel=self.kernel, hash=self.hash + ) def __call__( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray @@ -69,12 +74,14 @@ def backward_cpu( weights = self.reorder_weights_from_e3nn( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.jit(jax.vjp( - lambda X, Y, W: self.forward(X, Y, W), - jax.numpy.asarray(L1_in), - jax.numpy.asarray(L2_in), - jax.numpy.asarray(weights), - )[1]) + backward_fn = jax.jit( + jax.vjp( + lambda X, Y, W: self.forward(X, Y, W), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + ) L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) ) @@ -96,16 +103,20 @@ def double_backward_cpu( in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) - dbwd_func = jax.jit(jax.vjp( - lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[ - 1 - ](o), - in1_jax, - in2_jax, - weights_jax, - out_grad_jax, - )[1]) - - in1_grad, in2_grad, weights_grad, out_dgrad = dbwd_func((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + dbwd_func = jax.jit( + jax.vjp( + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward(a, b, c), x, y, w + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1] + ) + + in1_grad, in2_grad, weights_grad, out_dgrad = dbwd_func( + (in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax) + ) return in1_grad, in2_grad, weights_grad, out_dgrad diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 464a326d..52102294 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -2,7 +2,6 @@ import json import jax.numpy as jnp import numpy as np -from functools import partial from typing import Optional from openequivariance.jax import extlib @@ -30,8 +29,13 @@ class TensorProductConv(LoopUnrollConv): :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. 
""" + def __init__( - self, config: TPProblem, deterministic: bool = False, kahan: bool = False, requires_jvp: bool = True + self, + config: TPProblem, + deterministic: bool = False, + kahan: bool = False, + requires_jvp: bool = True, ): dp = extlib.DeviceProp(0) self.requires_jvp = requires_jvp @@ -45,14 +49,18 @@ def __init__( kahan=kahan, ) - self.kernel = json.dumps({ - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernel_prop, - }) - self.hash = self.kernel.__hash__() + self.kernel = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernel_prop, + } + ) + self.hash = self.kernel.__hash__() self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim @@ -94,7 +102,7 @@ def forward( sender_perm, L3_dim=self.L3_dim, kernel=self.kernel, - hash=self.hash + hash=self.hash, ) def __call__( @@ -151,19 +159,21 @@ def backward_cpu( weights, has_batch_dim=not self.config.shared_weights ) - backward_fn = jax.jit(jax.vjp( - lambda X, Y, W: self.forward( - X, - Y, - W, - jax.numpy.asarray(rows), - jax.numpy.asarray(cols), - jax.numpy.asarray(sender_perm), - ), - jax.numpy.asarray(L1_in), - jax.numpy.asarray(L2_in), - jax.numpy.asarray(weights), - )[1]) + backward_fn = jax.jit( + jax.vjp( + lambda X, Y, W: self.forward( + X, + Y, + W, + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + ) L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( jax.numpy.asarray(L3_grad) @@ -190,20 +200,22 @@ def double_backward_cpu( cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype)) sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) - in1_grad, in2_grad, weights_grad, out_dgrad = jax.jit(jax.vjp( - lambda x, y, w, o: jax.vjp( - lambda a, b, c: self.forward( - a, b, c, rows_jax, cols_jax, sender_perm_jax - ), - x, - y, - w, - )[1](o), - in1_jax, - in2_jax, - weights_jax, - out_grad_jax, - )[1])((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + in1_grad, in2_grad, weights_grad, out_dgrad = jax.jit( + jax.vjp( + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward( + a, b, c, rows_jax, cols_jax, sender_perm_jax + ), + x, + y, + w, + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, + )[1] + )((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) return ( np.asarray(in1_grad), diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index b4e14b24..410e5dbf 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,4 +1,6 @@ from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct -from openequivariance.jax.TensorProductConv import TensorProductConv as TensorProductConv +from openequivariance.jax.TensorProductConv import ( + TensorProductConv as TensorProductConv, +) __all__ = ["TensorProduct", "TensorProductConv"] diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py 
b/openequivariance/openequivariance/jax/jvp/conv_prim.py index 16c0570b..9324b785 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -10,19 +10,28 @@ conv_fwd_p = core.Primitive("conv_fwd") + def conv_fwd_impl(X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): irrep_dtype = X.dtype out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) call = jax.ffi.ffi_call("conv_forward", out_shape) return call(X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) -def conv_fwd_abstract_eval(X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + +def conv_fwd_abstract_eval( + X, Y, W, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + conv_fwd_p.def_impl(conv_fwd_impl) conv_fwd_p.def_abstract_eval(conv_fwd_abstract_eval) -mlir.register_lowering(conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="cuda") -mlir.register_lowering(conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="rocm") +mlir.register_lowering( + conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="cuda" +) +mlir.register_lowering( + conv_fwd_p, mlir.lower_fun(conv_fwd_impl, multiple_results=False), platform="rocm" +) # ============================================================================== @@ -32,6 +41,7 @@ def conv_fwd_abstract_eval(X, Y, W, rows, cols, workspace, sender_perm, *, L3_di conv_bwd_p = core.Primitive("conv_bwd") conv_bwd_p.multiple_results = True + def conv_bwd_impl(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): irrep_dtype = X.dtype out_shapes = ( @@ -40,9 +50,14 @@ def conv_bwd_impl(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, ha jax.ShapeDtypeStruct(W.shape, irrep_dtype), ) call = jax.ffi.ffi_call("conv_backward", out_shapes) - return call(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + return call( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) + -def conv_bwd_abstract_eval(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash): +def conv_bwd_abstract_eval( + X, Y, W, dZ, rows, cols, workspace, sender_perm, *, kernel, hash +): irrep_dtype = X.dtype return ( jax.core.ShapedArray(X.shape, irrep_dtype), @@ -50,10 +65,15 @@ def conv_bwd_abstract_eval(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, k jax.core.ShapedArray(W.shape, irrep_dtype), ) + conv_bwd_p.def_impl(conv_bwd_impl) conv_bwd_p.def_abstract_eval(conv_bwd_abstract_eval) -mlir.register_lowering(conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="rocm") +mlir.register_lowering( + conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + conv_bwd_p, mlir.lower_fun(conv_bwd_impl, multiple_results=True), platform="rocm" +) # ============================================================================== @@ -63,7 +83,10 @@ def conv_bwd_abstract_eval(X, Y, W, dZ, rows, cols, workspace, sender_perm, *, k conv_dbwd_p = core.Primitive("conv_dbwd") conv_dbwd_p.multiple_results = True -def conv_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_dbwd_impl( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, 
hash +): irrep_dtype = X.dtype out_shapes = ( jax.ShapeDtypeStruct(X.shape, irrep_dtype), @@ -72,9 +95,26 @@ def conv_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_per jax.ShapeDtypeStruct(dZ.shape, irrep_dtype), ) call = jax.ffi.ffi_call("conv_double_backward", out_shapes) - return call(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + return call( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + kernel=kernel, + hash=hash, + ) -def conv_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_dbwd_abstract_eval( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash +): irrep_dtype = X.dtype return ( jax.core.ShapedArray(X.shape, irrep_dtype), @@ -83,10 +123,15 @@ def conv_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, s jax.core.ShapedArray(dZ.shape, irrep_dtype), ) + conv_dbwd_p.def_impl(conv_dbwd_impl) conv_dbwd_p.def_abstract_eval(conv_dbwd_abstract_eval) -mlir.register_lowering(conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="rocm") +mlir.register_lowering( + conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + conv_dbwd_p, mlir.lower_fun(conv_dbwd_impl, multiple_results=True), platform="rocm" +) # ============================================================================== @@ -95,18 +140,25 @@ def conv_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, s conv_fwd_jvp_p = core.Primitive("conv_fwd_jvp") -def conv_fwd_jvp_impl(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + +def conv_fwd_jvp_impl( + X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) args_meta = (rows, cols, workspace, sender_perm) - + term1 = conv_fwd_p.bind(dX, Y, W, *args_meta, **kwargs) term2 = conv_fwd_p.bind(X, dY, W, *args_meta, **kwargs) term3 = conv_fwd_p.bind(X, Y, dW, *args_meta, **kwargs) return term1 + term2 + term3 -def conv_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + +def conv_fwd_jvp_abstract_eval( + X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) @@ -115,17 +167,19 @@ def conv_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, rows, cols, workspace, sende # 5. 
Transpose Rule (Implicit VJP) # ============================================================================== -def conv_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + +def conv_fwd_jvp_transpose( + ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash +): X, Y, W = clean_tensors(X, Y, W) grad_X, grad_Y, grad_W = conv_bwd_p.bind( - X, Y, W, ct, - rows, cols, workspace, sender_perm, - kernel=kernel, hash=hash + X, Y, W, ct, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash ) return (None, None, None, grad_X, grad_Y, grad_W, None, None, None, None) + ad.primitive_transposes[conv_fwd_jvp_p] = conv_fwd_jvp_transpose @@ -133,16 +187,43 @@ def conv_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, rows, cols, workspace, sende # 6. JVP Rule for Original Forward Primitive # ============================================================================== + def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): X, Y, W, rows, cols, workspace, sender_perm = primals dX, dY, dW, drows, dcols, dworkspace, dsender_perm = tangents dX, dY, dW = clean_tensors(dX, dY, dW) - out_primal = conv_fwd_p.bind(X, Y, W, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) - out_tangent = conv_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, rows, cols, workspace, sender_perm, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_primal = conv_fwd_p.bind( + X, + Y, + W, + rows, + cols, + workspace, + sender_perm, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) + out_tangent = conv_fwd_jvp_p.bind( + X, + Y, + W, + dX, + dY, + dW, + rows, + cols, + workspace, + sender_perm, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, + ) return out_primal, out_tangent + ad.primitive_jvps[conv_fwd_p] = conv_fwd_jvp_rule @@ -150,28 +231,32 @@ def conv_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): # 7. JVP Rule for Forward JVP Primitive (Higher Order) # ============================================================================== + def conv_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dx, dy, dw, r, c, ws, sp): return conv_fwd_jvp_impl( - x, y, w, dx, dy, dw, r, c, ws, sp, - L3_dim=L3_dim, kernel=kernel, hash=hash + x, y, w, dx, dy, dw, r, c, ws, sp, L3_dim=L3_dim, kernel=kernel, hash=hash ) return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[conv_fwd_jvp_p] = conv_fwd_jvp_jvp_rule # ============================================================================== -# 8. Backward JVP Primitive Definition +# 8. 
Backward JVP Primitive Definition # ============================================================================== conv_bwd_jvp_p = core.Primitive("conv_bwd_jvp") conv_bwd_jvp_p.multiple_results = True -def conv_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_bwd_jvp_impl( + X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash +): kwargs = dict(kernel=kernel, hash=hash) args_meta = (rows, cols, workspace, sender_perm) @@ -179,14 +264,17 @@ def conv_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sende term_X = conv_bwd_p.bind(tX, Y, W, dZ, *args_meta, **kwargs) term_Y = conv_bwd_p.bind(X, tY, W, dZ, *args_meta, **kwargs) term_W = conv_bwd_p.bind(X, Y, tW, dZ, *args_meta, **kwargs) - + out_dX = term_dZ[0] + term_Y[0] + term_W[0] out_dY = term_dZ[1] + term_X[1] + term_W[1] out_dW = term_dZ[2] + term_X[2] + term_Y[2] - + return out_dX, out_dY, out_dW -def conv_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_bwd_jvp_abstract_eval( + X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash +): irrep_dtype = X.dtype return ( jax.core.ShapedArray(X.shape, irrep_dtype), @@ -194,6 +282,7 @@ def conv_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa jax.core.ShapedArray(W.shape, irrep_dtype), ) + conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) @@ -202,23 +291,45 @@ def conv_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa # 9. Transpose Rule for Backward JVP # ============================================================================== -def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_bwd_jvp_transpose( + ct, + X, + Y, + W, + dZ, + tX, + tY, + tW, + tdZ, + rows, + cols, + workspace, + sender_perm, + *, + kernel, + hash, +): ddX, ddY, ddW = ct - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): + dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) g_X, g_Y, g_W, g_dZ = conv_dbwd_p.bind( - *tensors_clean, rows, cols, workspace, sender_perm, - kernel=kernel, hash=hash + *tensors_clean, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash ) return (None, None, None, None, g_X, g_Y, g_W, g_dZ, None, None, None, None) + ad.primitive_transposes[conv_bwd_jvp_p] = conv_bwd_jvp_transpose @@ -226,17 +337,18 @@ def conv_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspa # 10. 
JVP Rule for Backward JVP Primitive (Higher Order) # ============================================================================== + def conv_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp): return conv_bwd_jvp_impl( - x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp, - kernel=kernel, hash=hash + x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp, kernel=kernel, hash=hash ) return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[conv_bwd_jvp_p] = conv_bwd_jvp_jvp_rule @@ -244,6 +356,7 @@ def func(x, y, w, dz, tx, ty, tw, tdz, r, c, ws, sp): # 11. JVP Rule for Original Backward Primitive # ============================================================================== + def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): X, Y, W, dZ, rows, cols, workspace, sender_perm = primals tX, tY, tW, tdZ, drows, dcols, dworkspace, dsender_perm = tangents @@ -251,15 +364,27 @@ def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) out_primal = conv_bwd_p.bind( - X, Y, W, dZ, rows, cols, workspace, sender_perm, - kernel=kernel, hash=hash + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash ) out_tangent = conv_bwd_jvp_p.bind( - X, Y, W, dZ, tX, tY, tW, tdZ, rows, cols, workspace, sender_perm, - kernel=kernel, hash=hash + X, + Y, + W, + dZ, + tX, + tY, + tW, + tdZ, + rows, + cols, + workspace, + sender_perm, + kernel=kernel, + hash=hash, ) - return out_primal, out_tangent + return out_primal, out_tangent + ad.primitive_jvps[conv_bwd_p] = conv_bwd_jvp_rule @@ -268,17 +393,34 @@ def conv_bwd_jvp_rule(primals, tangents, *, kernel, hash): # 12. Slow Double Backward Implementation (Reference) # ============================================================================== -def conv_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, L3_dim, kernel, hash): + +def conv_dbwd_slow( + X, + Y, + W, + dZ, + ddX, + ddY, + ddW, + rows, + cols, + workspace, + sender_perm, + *, + L3_dim, + kernel, + hash, +): kwargs = dict(kernel=kernel, hash=hash) args_meta = (rows, cols, workspace, sender_perm) - + op1 = conv_bwd_p.bind(ddX, ddY, W, dZ, *args_meta, **kwargs) op2 = conv_bwd_p.bind(X, Y, ddW, dZ, *args_meta, **kwargs) - + op3 = conv_fwd_p.bind(ddX, Y, W, *args_meta, L3_dim=L3_dim, **kwargs) op4 = conv_bwd_p.bind(ddX, Y, W, dZ, *args_meta, **kwargs) op5 = conv_bwd_p.bind(X, ddY, W, dZ, *args_meta, **kwargs) - + op6 = conv_fwd_p.bind(X, ddY, W, *args_meta, L3_dim=L3_dim, **kwargs) op7 = conv_fwd_p.bind(X, Y, ddW, *args_meta, L3_dim=L3_dim, **kwargs) @@ -291,22 +433,36 @@ def conv_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_per # ============================================================================== -# 13. JVP rule for double backward (implicit) +# 13. 
JVP rule for double backward (implicit) # ============================================================================== + def conv_dbwd_jvp_rule(primals, tangents, *, kernel, hash): - dZ = primals[3] # Infer L3_dim from dZ (4th input) + dZ = primals[3] # Infer L3_dim from dZ (4th input) L3_dim = dZ.shape[1] def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): return conv_dbwd_slow( - x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp, - L3_dim=L3_dim, kernel=kernel, hash=hash + x, + y, + w, + dz, + ddx, + ddy, + ddw, + r, + c, + ws, + sp, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, ) tangents_clean = tuple(clean_tensors(*tangents)) return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[conv_dbwd_p] = conv_dbwd_jvp_rule @@ -314,20 +470,38 @@ def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): # 14. Transpose rule for double backward # ============================================================================== -def conv_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash): + +def conv_dbwd_transpose( + ct, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, *, kernel, hash +): L3_dim = dZ.shape[1] X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) def func(x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp): return conv_dbwd_slow( - x, y, w, dz, ddx, ddy, ddw, r, c, ws, sp, - L3_dim=L3_dim, kernel=kernel, hash=hash + x, + y, + w, + dz, + ddx, + ddy, + ddw, + r, + c, + ws, + sp, + L3_dim=L3_dim, + kernel=kernel, + hash=hash, ) - _, vjp_fun = jax.vjp(func, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm) + _, vjp_fun = jax.vjp( + func, X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm + ) input_grads = vjp_fun(ct) - + return input_grads -ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose \ No newline at end of file + +ad.primitive_transposes[conv_dbwd_p] = conv_dbwd_transpose diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index 764fcb87..c31c3ec0 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -10,19 +10,26 @@ tp_fwd_p = core.Primitive("tp_fwd") + def tp_fwd_impl(X, Y, W, *, L3_dim, kernel, hash): irrep_dtype = X.dtype out_shape = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) call = jax.ffi.ffi_call("tp_forward", out_shape) return call(X, Y, W, kernel=kernel, hash=hash) + def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + tp_fwd_p.def_impl(tp_fwd_impl) tp_fwd_p.def_abstract_eval(tp_fwd_abstract_eval) -mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda") -mlir.register_lowering(tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm") +mlir.register_lowering( + tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="cuda" +) +mlir.register_lowering( + tp_fwd_p, mlir.lower_fun(tp_fwd_impl, multiple_results=False), platform="rocm" +) # ============================================================================== @@ -30,7 +37,8 @@ def tp_fwd_abstract_eval(X, Y, W, *, L3_dim, kernel, hash): # ============================================================================== tp_bwd_p = core.Primitive("tp_bwd") -tp_bwd_p.multiple_results = True +tp_bwd_p.multiple_results = True + def tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): irrep_dtype = X.dtype @@ -42,6 +50,7 @@ def 
tp_bwd_impl(X, Y, W, dZ, *, kernel, hash): call = jax.ffi.ffi_call("tp_backward", out_shapes) return call(X, Y, W, dZ, kernel=kernel, hash=hash) + def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): irrep_dtype = X.dtype return ( @@ -50,10 +59,15 @@ def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): jax.core.ShapedArray(W.shape, irrep_dtype), ) + tp_bwd_p.def_impl(tp_bwd_impl) tp_bwd_p.def_abstract_eval(tp_bwd_abstract_eval) -mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm") +mlir.register_lowering( + tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + tp_bwd_p, mlir.lower_fun(tp_bwd_impl, multiple_results=True), platform="rocm" +) # ============================================================================== @@ -63,6 +77,7 @@ def tp_bwd_abstract_eval(X, Y, W, dZ, *, kernel, hash): tp_dbwd_p = core.Primitive("tp_dbwd") tp_dbwd_p.multiple_results = True + def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): irrep_dtype = X.dtype out_shapes = ( @@ -74,6 +89,7 @@ def tp_dbwd_impl(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): call = jax.ffi.ffi_call("tp_double_backward", out_shapes) return call(X, Y, W, dZ, ddX, ddY, ddW, kernel=kernel, hash=hash) + def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): irrep_dtype = X.dtype return ( @@ -83,10 +99,15 @@ def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): jax.core.ShapedArray(dZ.shape, irrep_dtype), ) + tp_dbwd_p.def_impl(tp_dbwd_impl) tp_dbwd_p.def_abstract_eval(tp_dbwd_abstract_eval) -mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda") -mlir.register_lowering(tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm") +mlir.register_lowering( + tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="cuda" +) +mlir.register_lowering( + tp_dbwd_p, mlir.lower_fun(tp_dbwd_impl, multiple_results=True), platform="rocm" +) # ============================================================================== @@ -95,17 +116,20 @@ def tp_dbwd_abstract_eval(X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): tp_fwd_jvp_p = core.Primitive("tp_fwd_jvp") + def tp_fwd_jvp_impl(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): kwargs = dict(L3_dim=L3_dim, kernel=kernel, hash=hash) - + term1 = tp_fwd_p.bind(dX, Y, W, **kwargs) term2 = tp_fwd_p.bind(X, dY, W, **kwargs) term3 = tp_fwd_p.bind(X, Y, dW, **kwargs) return term1 + term2 + term3 + def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): return jax.core.ShapedArray((X.shape[0], L3_dim), X.dtype) + tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) @@ -114,6 +138,7 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): # 5. Transpose Rule (Implicit VJP) # ============================================================================== + def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): X, Y, W = clean_tensors(X, Y, W) @@ -121,6 +146,7 @@ def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): return (None, None, None, grad_X, grad_Y, grad_W) + ad.primitive_transposes[tp_fwd_jvp_p] = tp_fwd_jvp_transpose @@ -128,17 +154,21 @@ def tp_fwd_jvp_transpose(ct, X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): # 6. 
JVP Rule for Original Forward Primitive # ============================================================================== + def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): X, Y, W = primals dX, dY, dW = tangents - + dX, dY, dW = clean_tensors(dX, dY, dW) out_primal = tp_fwd_p.bind(X, Y, W, L3_dim=L3_dim, kernel=kernel, hash=hash) - out_tangent = tp_fwd_jvp_p.bind(X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash) + out_tangent = tp_fwd_jvp_p.bind( + X, Y, W, dX, dY, dW, L3_dim=L3_dim, kernel=kernel, hash=hash + ) return out_primal, out_tangent + ad.primitive_jvps[tp_fwd_p] = tp_fwd_jvp_rule @@ -146,24 +176,29 @@ def tp_fwd_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): # 7. JVP Rule for Forward JVP Primitive (Higher Order) # ============================================================================== + def tp_fwd_jvp_jvp_rule(primals, tangents, *, L3_dim, kernel, hash): tangents_clean = tuple(clean_tensors(*tangents)) def func(x, y, w, dx, dy, dw): - return tp_fwd_jvp_impl(x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash) + return tp_fwd_jvp_impl( + x, y, w, dx, dy, dw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[tp_fwd_jvp_p] = tp_fwd_jvp_jvp_rule # ============================================================================== -# 8. Backward JVP Primitive Definition +# 8. Backward JVP Primitive Definition # ============================================================================== tp_bwd_jvp_p = core.Primitive("tp_bwd_jvp") tp_bwd_jvp_p.multiple_results = True + def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): kwargs = dict(kernel=kernel, hash=hash) @@ -171,13 +206,14 @@ def tp_bwd_jvp_impl(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): term_X = tp_bwd_p.bind(tX, Y, W, dZ, **kwargs) term_Y = tp_bwd_p.bind(X, tY, W, dZ, **kwargs) term_W = tp_bwd_p.bind(X, Y, tW, dZ, **kwargs) - + out_dX = term_dZ[0] + term_Y[0] + term_W[0] out_dY = term_dZ[1] + term_X[1] + term_W[1] out_dW = term_dZ[2] + term_X[2] + term_Y[2] - + return out_dX, out_dY, out_dW + def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): irrep_dtype = X.dtype return ( @@ -186,6 +222,7 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): jax.core.ShapedArray(W.shape, irrep_dtype), ) + tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) @@ -194,22 +231,26 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): # 9. 
Transpose Rule for Backward JVP # ============================================================================== + def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): ddX, ddY, ddW = ct - if ad.is_undefined_primal(X): X = jnp.zeros(X.aval.shape, X.aval.dtype) - if ad.is_undefined_primal(Y): Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) - if ad.is_undefined_primal(W): W = jnp.zeros(W.aval.shape, W.aval.dtype) - if ad.is_undefined_primal(dZ): dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) + if ad.is_undefined_primal(X): + X = jnp.zeros(X.aval.shape, X.aval.dtype) + if ad.is_undefined_primal(Y): + Y = jnp.zeros(Y.aval.shape, Y.aval.dtype) + if ad.is_undefined_primal(W): + W = jnp.zeros(W.aval.shape, W.aval.dtype) + if ad.is_undefined_primal(dZ): + dZ = jnp.zeros(dZ.aval.shape, dZ.aval.dtype) tensors_clean = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) - g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind( - *tensors_clean, kernel=kernel, hash=hash - ) + g_X, g_Y, g_W, g_dZ = tp_dbwd_p.bind(*tensors_clean, kernel=kernel, hash=hash) return (None, None, None, None, g_X, g_Y, g_W, g_dZ) + ad.primitive_transposes[tp_bwd_jvp_p] = tp_bwd_jvp_transpose @@ -217,6 +258,7 @@ def tp_bwd_jvp_transpose(ct, X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): # 10. JVP Rule for Backward JVP Primitive (Higher Order) # ============================================================================== + def tp_bwd_jvp_jvp_rule(primals, tangents, *, kernel, hash): tangents_clean = tuple(clean_tensors(*tangents)) @@ -225,6 +267,7 @@ def func(x, y, w, dz, tx, ty, tw, tdz): return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[tp_bwd_jvp_p] = tp_bwd_jvp_jvp_rule @@ -232,6 +275,7 @@ def func(x, y, w, dz, tx, ty, tw, tdz): # 11. JVP Rule for Original Backward Primitive # ============================================================================== + def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): X, Y, W, dZ = primals tX, tY, tW, tdZ = tangents @@ -239,9 +283,12 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): tX, tY, tW, tdZ = clean_tensors(tX, tY, tW, tdZ) out_primal = tp_bwd_p.bind(X, Y, W, dZ, kernel=kernel, hash=hash) - out_tangent = tp_bwd_jvp_p.bind(X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash) + out_tangent = tp_bwd_jvp_p.bind( + X, Y, W, dZ, tX, tY, tW, tdZ, kernel=kernel, hash=hash + ) + + return out_primal, out_tangent - return out_primal, out_tangent ad.primitive_jvps[tp_bwd_p] = tp_bwd_jvp_rule @@ -250,16 +297,17 @@ def tp_bwd_jvp_rule(primals, tangents, *, kernel, hash): # 12. Slow Double Backward Implementation (Reference) # ============================================================================== + def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): kwargs = dict(kernel=kernel, hash=hash) - + op1 = tp_bwd_p.bind(ddX, ddY, W, dZ, **kwargs) op2 = tp_bwd_p.bind(X, Y, ddW, dZ, **kwargs) - + op3 = tp_fwd_p.bind(ddX, Y, W, L3_dim=L3_dim, **kwargs) op4 = tp_bwd_p.bind(ddX, Y, W, dZ, **kwargs) op5 = tp_bwd_p.bind(X, ddY, W, dZ, **kwargs) - + op6 = tp_fwd_p.bind(X, ddY, W, L3_dim=L3_dim, **kwargs) op7 = tp_fwd_p.bind(X, Y, ddW, L3_dim=L3_dim, **kwargs) @@ -272,19 +320,23 @@ def tp_dbwd_slow(X, Y, W, dZ, ddX, ddY, ddW, *, L3_dim, kernel, hash): # ============================================================================== -# 13. JVP rule for double backward (implicit) +# 13. 
JVP rule for double backward (implicit) # ============================================================================== + def tp_dbwd_jvp_rule(primals, tangents, *, kernel, hash): - dZ = primals[3] # Infer L3_dim from dZ (4th input) + dZ = primals[3] # Infer L3_dim from dZ (4th input) L3_dim = dZ.shape[1] def func(x, y, w, dz, ddx, ddy, ddw): - return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) + return tp_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) tangents_clean = tuple(clean_tensors(*tangents)) return jax.jvp(func, primals, tangents_clean) + ad.primitive_jvps[tp_dbwd_p] = tp_dbwd_jvp_rule @@ -292,17 +344,21 @@ def func(x, y, w, dz, ddx, ddy, ddw): # 14. Transpose rule for double backward # ============================================================================== + def tp_dbwd_transpose(ct, X, Y, W, dZ, ddX, ddY, ddW, *, kernel, hash): L3_dim = dZ.shape[1] X, Y, W, dZ, ddX, ddY, ddW = clean_tensors(X, Y, W, dZ, ddX, ddY, ddW) def func(x, y, w, dz, ddx, ddy, ddw): - return tp_dbwd_slow(x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash) + return tp_dbwd_slow( + x, y, w, dz, ddx, ddy, ddw, L3_dim=L3_dim, kernel=kernel, hash=hash + ) _, vjp_fun = jax.vjp(func, X, Y, W, dZ, ddX, ddY, ddW) input_grads = vjp_fun(ct) return input_grads -ad.primitive_transposes[tp_dbwd_p] = tp_dbwd_transpose \ No newline at end of file + +ad.primitive_transposes[tp_dbwd_p] = tp_dbwd_transpose diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py index 02b1b7a9..371b0ae5 100644 --- a/openequivariance/openequivariance/jax/utils.py +++ b/openequivariance/openequivariance/jax/utils.py @@ -66,7 +66,7 @@ def reorder_jax(schedule, weights_in, direction, has_batch_dim): def clean_tensors(*tensors): tensors_clean = [] - for t in tensors: + for t in tensors: result = t if type(t) is ad.Zero or ad.is_undefined_primal(t): result = jnp.zeros(t.aval.shape, t.aval.dtype) diff --git a/openequivariance/openequivariance/jax/vjp/conv_func.py b/openequivariance/openequivariance/jax/vjp/conv_func.py index 5d19df9d..20d7ccc2 100644 --- a/openequivariance/openequivariance/jax/vjp/conv_func.py +++ b/openequivariance/openequivariance/jax/vjp/conv_func.py @@ -1,26 +1,24 @@ import jax import jax.numpy as jnp -from jax.extend import core from functools import partial -from jax.interpreters import mlir, ad + def zeros_like(x): return jnp.zeros_like(x) + @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash): forward_call = jax.ffi.ffi_call( "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), X.dtype) ) - return forward_call(X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + return forward_call( + X, Y, W, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) -def forward_fwd( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash -): - out = forward( - X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash - ) +def forward_fwd(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash): + out = forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, kernel, hash) return out, (X, Y, W, rows, cols) @@ -45,7 +43,9 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): jax.ShapeDtypeStruct(W.shape, W.dtype), ), ) - return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash) + 
return backward_call( + X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash + ) def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, kernel, hash): @@ -81,7 +81,7 @@ def backward_bwd(workspace, sender_perm, kernel, hash, res, derivatives): @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12)) def double_backward( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash ): double_backward_call = jax.ffi.ffi_call( "conv_double_backward", @@ -93,14 +93,6 @@ def double_backward( ), ) return double_backward_call( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel=kernel, hash=hash - ) - - -def double_backward_fwd( - X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash -): - out = double_backward( X, Y, W, @@ -112,8 +104,16 @@ def double_backward_fwd( cols, workspace, sender_perm, - kernel, - hash + kernel=kernel, + hash=hash, + ) + + +def double_backward_fwd( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash +): + out = double_backward( + X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, kernel, hash ) return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols) @@ -159,4 +159,4 @@ def triple_backward( return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None -double_backward.defvjp(double_backward_fwd, triple_backward) \ No newline at end of file +double_backward.defvjp(double_backward_fwd, triple_backward) diff --git a/tests/conftest.py b/tests/conftest.py index ad6cb9fa..323de863 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" os.environ["JAX_TRACEBACK_FILTERING"] = "off" + def pytest_addoption(parser): parser.addoption( "--jax", diff --git a/tests/example_test.py b/tests/example_test.py index 2ab6f6ab..3a0828ac 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -135,4 +135,4 @@ def test_tutorial_jax(with_jax): result = jax.vjp(lambda X, Y, W: tp_fast(X, Y, W), X, Y, W)[1]( jax.numpy.ones_like(Z) ) - print(result) \ No newline at end of file + print(result) From a27e293e4e375a96c0935582a3fd51b90aa1c4d7 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 Jan 2026 00:36:34 -0800 Subject: [PATCH 18/18] Merge changes. --- CHANGELOG.md | 2 +- tests/example_test.py | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 410608be..14cd1a72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ ## Latest Changes ### v0.5.4 (2025-02-01) -Minor update, improvements to JAX. +Improvements to JAX frontend. 
**Added**: - Jacobian Vector Products (JVP) diff --git a/tests/example_test.py b/tests/example_test.py index 3a0828ac..e8d23cb7 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -132,7 +132,35 @@ def test_tutorial_jax(with_jax): Z = tp_fast(X, Y, W) print(jax.numpy.linalg.norm(Z)) - result = jax.vjp(lambda X, Y, W: tp_fast(X, Y, W), X, Y, W)[1]( - jax.numpy.ones_like(Z) + edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This int32, not int64 + ) + + node_ct, nonzero_ct = 3, 4 + X = jax.random.uniform( + key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32 + ) + Y = jax.random.uniform( + key, + shape=(nonzero_ct, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + W = jax.random.uniform( + key, + shape=(nonzero_ct, problem.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, ) - print(result) + tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) + Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) + print(jax.numpy.linalg.norm(Z)) + + jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2)) + print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
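
As a quick illustration of the JVP, higher-order derivative, and JIT support described in the
v0.5.4 changelog entry above, the following is a minimal usage sketch (not part of the patch
series) showing how forward-mode derivatives might be taken through ``oeq.jax.TensorProduct``.
It assumes ``problem``, ``X``, ``Y``, and ``W`` are set up as in ``test_tutorial_jax``, i.e.
``X`` and ``Y`` have shape ``(batch, irreps.dim)`` and ``W`` has shape
``(batch, problem.weight_numel)``; those names are taken from the tests, not defined here.

import jax
import jax.numpy as jnp
import openequivariance as oeq

tp = oeq.jax.TensorProduct(problem)  # `problem` assumed from the tutorial test above

# Forward-mode derivative (JVP): push the tangents (dX, dY, dW) through the tensor product.
dX, dY, dW = jnp.ones_like(X), jnp.ones_like(Y), jnp.ones_like(W)
Z, dZ = jax.jvp(lambda x, y, w: tp(x, y, w), (X, Y, W), (dX, dY, dW))

# Higher-order derivatives compose, e.g. forward-over-reverse:
# a JVP of the VJP with respect to X.
def grad_X(x, y, w):
    out, pullback = jax.vjp(lambda a, b, c: tp(a, b, c), x, y, w)
    return pullback(jnp.ones_like(out))[0]

_, second_order = jax.jvp(grad_X, (X, Y, W), (dX, dY, dW))

# The same computations can be wrapped in jax.jit.
Z_jit = jax.jit(lambda x, y, w: tp(x, y, w))(X, Y, W)

Reverse-over-forward compositions are covered as well, since the ``*_jvp`` primitives in
``tp_prim.py`` and ``conv_prim.py`` register their own transpose rules.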