diff --git a/src/python/client.c b/src/python/client.c index e5b2662..7f97f97 100644 --- a/src/python/client.c +++ b/src/python/client.c @@ -127,6 +127,31 @@ sequence_to_str_array(PyObject *sequence) return res; } +static gchar * +parse_arch_list(PyObject *arches) +{ + gchar *arch_list = NULL; + + if (1 != PySequence_Check(arches)) { + PyErr_SetString(PyExc_TypeError, "arches must be an iterable"); + return NULL; + } + + const gchar **arch_array = sequence_to_str_array(arches); + if (NULL == arch_array) { + return NULL; + } + + arch_list = g_strjoinv(" ", (gchar **)arch_array); + g_free(arch_array); + if (NULL == arch_list) { + PyErr_NoMemory(); + return NULL; + } + + return arch_list; +} + static PyObject * client_add(ClientObject *self, PyObject *args) { @@ -141,18 +166,10 @@ client_add(ClientObject *self, PyObject *args) } if (arches != NULL && arches != Py_None) { - if (1 != PySequence_Check(arches)) { - PyErr_SetString(PyExc_TypeError, "arches must be an iterable"); - return NULL; - } - - const gchar **arch_array = sequence_to_str_array(arches); - if (NULL == arch_array) { + arch_list = parse_arch_list(arches); + if (NULL == arch_list) { return NULL; } - - arch_list = g_strjoinv(" ", (gchar **)arch_array); - g_free(arch_array); } cmd = g_strjoin(" ", "ADD", package, arch_list, NULL); @@ -268,6 +285,49 @@ client_set_invalidate_family(ClientObject *self, PyObject *args) return set_option(self, "invalidate_family", invalidate_family); } +static PyObject * +client_sync(ClientObject *self, PyObject *args, PyObject *kwargs) +{ + char *base_url = NULL; + char *pattern = NULL; + PyObject *arches = NULL; + gchar *arch_list = NULL; + gchar *cmd; + PyObject *ret; + + static char * keywords[] = { + "base_url", + "pattern", + "arches", + NULL, + }; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|zO", keywords, &base_url, &pattern, &arches)) { + return NULL; + } + + if (arches != NULL && arches != Py_None) { + arch_list = parse_arch_list(arches); + if (NULL == arch_list) { + return NULL; + } + } + + if (NULL == pattern) { + cmd = g_strjoin(" ", "SYNC", base_url, arch_list, NULL); + } else { + cmd = g_strjoin(" ", "SYNC_PATTERN", base_url, pattern, arch_list, NULL); + } + g_free(arch_list); + if (!cmd) { + return PyErr_NoMemory(); + } + + ret = execute_transaction(self, cmd); + g_free(cmd); + return ret; +} + static PyObject * client_enter(ClientObject *self, PyObject *args) { @@ -299,6 +359,7 @@ static struct PyMethodDef client_methods[] = { {"disconnect", (PyCFunction)client_disconnect, METH_NOARGS, NULL}, {"set_invalidate_dependants", (PyCFunction)client_set_invalidate_dependants, METH_VARARGS, NULL}, {"set_invalidate_family", (PyCFunction)client_set_invalidate_family, METH_VARARGS, NULL}, + {"sync", (PyCFunction)(void(*)(void))client_sync, METH_VARARGS | METH_KEYWORDS, NULL}, {"__enter__", (PyCFunction)client_enter, METH_NOARGS, NULL}, {"__exit__", (PyCFunction)client_disconnect, METH_VARARGS, NULL}, {NULL, NULL, 0, NULL} diff --git a/test/test_smoke.py b/test/test_smoke.py index 68512a4..8f05a4b 100644 --- a/test/test_smoke.py +++ b/test/test_smoke.py @@ -7,6 +7,14 @@ import pytest FIXTURES_DIR = Path(__file__).parent / 'fixtures' +POPULATED_REPO = FIXTURES_DIR / 'populated' +POPULATED_RPM = Path( + POPULATED_REPO, + 'x86_64', + 'Packages', + 'r', + 'ros-dev-tools-1.0.1-1.el9.noarch.rpm', +) def test_version(): @@ -14,21 +22,26 @@ def test_version(): def test_add(tmp_path): - packages_path = FIXTURES_DIR / 'populated' / 'x86_64' / 'Packages' - rpm_path = packages_path / 'r' / 'ros-dev-tools-1.0.1-1.el9.noarch.rpm' + rpm_path = str(POPULATED_RPM) with createrepo_agent.Server(str(tmp_path)): with createrepo_agent.Client(str(tmp_path)) as c: with pytest.raises(TypeError): - c.add(str(rpm_path), 1) + c.add(None) with pytest.raises(TypeError): - c.add(str(rpm_path), (1,)) + c.add(rpm_path, 1) + with pytest.raises(TypeError): + c.add(rpm_path, (1,)) c.set_invalidate_dependants(True) c.set_invalidate_family(True) - c.add(str(rpm_path), ('x86_64',)) + c.add(rpm_path, ('x86_64',)) c.commit() - assert (tmp_path / 'x86_64' / 'repodata' / 'repomd.xml').is_file() + arch_path = tmp_path / 'x86_64' + repomd_path = arch_path / 'repodata' / 'repomd.xml' + + assert repomd_path.is_file() + assert (arch_path / 'Packages' / 'r' / POPULATED_RPM.name).is_file() def test_commit_nothing(tmp_path): @@ -45,3 +58,63 @@ def test_server_socket_collision(tmp_path): with pytest.raises(OSError): with createrepo_agent.Server(str(tmp_path)): pass + + +def test_sync_all(tmp_path): + base_url = POPULATED_REPO.as_uri() + with createrepo_agent.Server(str(tmp_path)): + with createrepo_agent.Client(str(tmp_path)) as c: + with pytest.raises(TypeError): + c.sync(None) + with pytest.raises(TypeError): + c.sync(base_url, arches=1) + with pytest.raises(TypeError): + c.sync(base_url, arches=(1,)) + c.sync(base_url, arches=('x86_64',)) + c.commit() + + arch_path = tmp_path / 'x86_64' + repomd_path = arch_path / 'repodata' / 'repomd.xml' + + assert repomd_path.is_file() + assert (arch_path / 'Packages' / 'r' / POPULATED_RPM.name).is_file() + + # Performing the same operation again results in no changes, so CRA shouldn't + # make any changes to the metadata at all. + + old_repomd_contents = repomd_path.read_text() + + with createrepo_agent.Server(str(tmp_path)): + with createrepo_agent.Client(str(tmp_path)) as c: + c.sync(base_url, arches=('x86_64',)) + c.commit() + + assert old_repomd_contents == repomd_path.read_text() + + +def test_sync_pattern_hit(tmp_path): + base_url = POPULATED_REPO.as_uri() + pattern = POPULATED_RPM.name[:3] + '.*' + with createrepo_agent.Server(str(tmp_path)): + with createrepo_agent.Client(str(tmp_path)) as c: + c.sync(base_url, pattern, ('x86_64',)) + c.commit() + + arch_path = tmp_path / 'x86_64' + repomd_path = arch_path / 'repodata' / 'repomd.xml' + + assert repomd_path.is_file() + assert (arch_path / 'Packages' / 'r' / POPULATED_RPM.name).is_file() + + +def test_sync_pattern_miss(tmp_path): + base_url = POPULATED_REPO.as_uri() + pattern = 'does-not-match' + with createrepo_agent.Server(str(tmp_path)): + with createrepo_agent.Client(str(tmp_path)) as c: + c.sync(base_url, pattern, ('x86_64',)) + c.commit() + + arch_path = tmp_path / 'x86_64' + + assert not (arch_path / 'Packages' / 'r' / POPULATED_RPM.name).is_file()